inference-models 0.18.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0
|
File without changes
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
import os.path
|
|
2
|
+
from typing import List, Optional, Tuple, Union
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
import torchvision
|
|
7
|
+
from groundingdino.util.inference import load_model, predict
|
|
8
|
+
from torch import nn
|
|
9
|
+
from torchvision import transforms
|
|
10
|
+
from torchvision.ops import box_convert
|
|
11
|
+
|
|
12
|
+
from inference_models import Detections
|
|
13
|
+
from inference_models.configuration import DEFAULT_DEVICE
|
|
14
|
+
from inference_models.entities import ColorFormat, ImageDimensions
|
|
15
|
+
from inference_models.errors import ModelRuntimeError
|
|
16
|
+
from inference_models.models.base.object_detection import (
|
|
17
|
+
OpenVocabularyObjectDetectionModel,
|
|
18
|
+
)
|
|
19
|
+
from inference_models.models.common.model_packages import get_model_package_contents
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class GroundingDinoForObjectDetectionTorch(
    OpenVocabularyObjectDetectionModel[
        torch.Tensor,
        List[ImageDimensions],
        Tuple[List[torch.Tensor], List[torch.Tensor], List[List[str]], List[str]],
    ]
):
    """Open-vocabulary object detector wrapping a Grounding DINO PyTorch model.

    Inputs are resized to 800x800 and normalised before inference. Each image of
    a batch is passed separately to ``groundingdino.util.inference.predict``,
    with all requested class names joined into a single text prompt.
    """

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        **kwargs,
    ) -> "GroundingDinoForObjectDetectionTorch":
        """Load the detector from a local model package directory.

        Args:
            model_name_or_path: Package directory; ``weights.pth`` and
                ``config.py`` are requested from it. If a ``text_encoder``
                sub-directory exists, its path is forwarded to ``load_model``
                as ``text_encoder_type``.
            device: Device the loaded model is moved to.
            **kwargs: Not used; accepted for interface compatibility.

        Returns:
            A new ``GroundingDinoForObjectDetectionTorch`` instance.
        """
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=["weights.pth", "config.py"],
        )
        text_encoder_dir = os.path.join(model_name_or_path, "text_encoder")
        loader_kwargs = {}
        if os.path.isdir(text_encoder_dir):
            # Only pass the override when the package ships its own text encoder.
            loader_kwargs["text_encoder_type"] = text_encoder_dir
        model = load_model(
            model_config_path=model_package_content["config.py"],
            model_checkpoint_path=model_package_content["weights.pth"],
            **loader_kwargs,
        ).to(device)
        return cls(model=model, device=device)

    def __init__(
        self,
        model: nn.Module,
        device: torch.device,
    ):
        """Store the model and build the input transformation pipelines.

        Args:
            model: Loaded Grounding DINO module.
            device: Device on which pre-processed inputs are placed.
        """
        self._model = model
        self._device = device
        # numpy path: HWC array -> CHW float tensor via ToTensor, then resize + normalise.
        self._numpy_transformations = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize([800, 800]),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        # tensor path: inputs are divided by 255 here, i.e. expected in [0, 255].
        self._tensors_transformations = transforms.Compose(
            [
                lambda x: x / 255.0,
                transforms.Resize([800, 800]),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, List[ImageDimensions]]:
        """Convert raw images into a normalised batch plus original sizes.

        Accepts a single HWC numpy array, a CHW / NCHW tensor, or a non-empty
        list of either. numpy inputs default to ``"bgr"`` colour order and
        tensors to ``"rgb"``; anything other than ``"rgb"`` has its channels
        reversed.

        Args:
            images: Input image(s).
            input_color_format: Channel order of the input, or ``None`` for the
                per-type default described above.
            **kwargs: Ignored.

        Returns:
            ``(batch, dimensions)``, where ``batch`` is on ``self._device`` and
            ``dimensions`` holds one ``ImageDimensions`` (original height and
            width) per image in ``batch``.

        Raises:
            ModelRuntimeError: on an unsupported input type or an empty list.
        """
        if isinstance(images, np.ndarray):
            input_color_format = input_color_format or "bgr"
            if input_color_format != "rgb":
                images = np.ascontiguousarray(images[:, :, ::-1])
            pre_processed = self._numpy_transformations(images)
            return (
                torch.unsqueeze(pre_processed, dim=0).to(self._device),
                [ImageDimensions(height=images.shape[0], width=images.shape[1])],
            )
        if isinstance(images, torch.Tensor):
            input_color_format = input_color_format or "rgb"
            if len(images.shape) == 3:
                images = torch.unsqueeze(images, dim=0)
            image_dimensions = ImageDimensions(
                height=images.shape[2], width=images.shape[3]
            )
            images = images.to(self._device)
            if input_color_format != "rgb":
                images = images[:, [2, 1, 0], :, :]
            return (
                self._tensors_transformations(images.float()),
                [image_dimensions] * images.shape[0],
            )
        if not isinstance(images, list):
            raise ModelRuntimeError(
                message="Pre-processing supports only np.array or torch.Tensor or list of above.",
                help_url="https://todo",
            )
        if not images:
            raise ModelRuntimeError(
                message="Detected empty input to the model",
                help_url="https://todo",
            )
        if isinstance(images[0], np.ndarray):
            input_color_format = input_color_format or "bgr"
            pre_processed, image_dimensions = [], []
            for image in images:
                if input_color_format != "rgb":
                    image = np.ascontiguousarray(image[:, :, ::-1])
                image_dimensions.append(
                    ImageDimensions(height=image.shape[0], width=image.shape[1])
                )
                pre_processed.append(self._numpy_transformations(image))
            return torch.stack(pre_processed, dim=0).to(self._device), image_dimensions
        if isinstance(images[0], torch.Tensor):
            input_color_format = input_color_format or "rgb"
            pre_processed, image_dimensions = [], []
            for image in images:
                if len(image.shape) == 3:
                    image = torch.unsqueeze(image, dim=0)
                if input_color_format != "rgb":
                    image = image[:, [2, 1, 0], :, :]
                # A list element may itself be an NCHW batch: record one entry per
                # image so the metadata stays aligned with the concatenated batch.
                image_dimensions.extend(
                    [ImageDimensions(height=image.shape[2], width=image.shape[3])]
                    * image.shape[0]
                )
                pre_processed.append(self._tensors_transformations(image.float()))
            return torch.cat(pre_processed, dim=0).to(self._device), image_dimensions
        raise ModelRuntimeError(
            message=f"Detected unknown input batch element: {type(images[0])}",
            help_url="https://todo",
        )

    def forward(
        self,
        pre_processed_images: torch.Tensor,
        classes: List[str],
        conf_thresh: float = 0.5,
        text_threshold: Optional[float] = None,
        **kwargs,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[str]], List[str]]:
        """Run Grounding DINO on every image of the batch.

        Args:
            pre_processed_images: Batch produced by ``pre_process``.
            classes: Class names; joined with ``". "`` into one caption.
            conf_thresh: Box threshold passed to ``predict``.
            text_threshold: Text threshold; defaults to ``conf_thresh``.
            **kwargs: Ignored.

        Returns:
            Per-image lists of boxes (treated by ``post_process`` as
            normalised cxcywh), scores and matched phrases, plus ``classes``.
        """
        if text_threshold is None:
            text_threshold = conf_thresh
        caption = ". ".join(classes)
        all_boxes, all_logits, all_phrases = [], [], []
        with torch.inference_mode():
            for image in pre_processed_images:
                boxes, logits, phrases = predict(
                    model=self._model,
                    image=image,
                    caption=caption,
                    box_threshold=conf_thresh,
                    text_threshold=text_threshold,
                    device=self._device,
                    remove_combined=True,
                )
                all_boxes.append(boxes)
                all_logits.append(logits)
                all_phrases.append(phrases)
        return all_boxes, all_logits, all_phrases, classes

    def post_process(
        self,
        model_results: Tuple[
            List[torch.Tensor], List[torch.Tensor], List[List[str]], List[str]
        ],
        pre_processing_meta: List[ImageDimensions],
        iou_thresh: float = 0.45,
        max_detections: int = 100,
        class_agnostic: bool = False,
        **kwargs,
    ) -> List[Detections]:
        """Turn raw predictions into pixel-space ``Detections`` after NMS.

        Args:
            model_results: Output of ``forward``.
            pre_processing_meta: Original image sizes from ``pre_process``.
            iou_thresh: IoU threshold for NMS.
            max_detections: Maximum detections kept per image.
            class_agnostic: If true, NMS ignores class ids.
            **kwargs: Ignored.

        Returns:
            One ``Detections`` per image with integer xyxy boxes.
        """
        all_boxes, all_logits, all_phrases, classes = model_results
        results = []
        for boxes, logits, phrases, origin_size in zip(
            all_boxes, all_logits, all_phrases, pre_processing_meta
        ):
            # torch.tensor (unlike the legacy torch.Tensor constructor) accepts
            # any device and lets us match the dtype of the predicted boxes.
            scale = torch.tensor(
                [
                    origin_size.width,
                    origin_size.height,
                    origin_size.width,
                    origin_size.height,
                ],
                dtype=boxes.dtype,
                device=boxes.device,
            )
            boxes = boxes * scale
            xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy")
            class_id = map_phrases_to_classes(
                phrases=phrases,
                classes=classes,
            ).to(boxes.device)
            nms_class_ids = torch.zeros_like(class_id) if class_agnostic else class_id
            # batched_nms returns indices sorted by decreasing score, so slicing
            # keeps the highest-confidence detections.
            keep = torchvision.ops.batched_nms(xyxy, logits, nms_class_ids, iou_thresh)
            if keep.numel() > max_detections:
                keep = keep[:max_detections]
            results.append(
                Detections(
                    xyxy=xyxy[keep].round().int(),
                    confidence=logits[keep],
                    class_id=class_id[keep].int(),
                ),
            )
        return results
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def map_phrases_to_classes(phrases: List[str], classes: List[str]) -> torch.Tensor:
    """Map phrases returned by Grounding DINO to indices into ``classes``.

    Matching is case-insensitive, because Grounding DINO lower-cases the caption
    before tokenisation, so the returned phrases are lower-case. An exact
    (whitespace-stripped) match wins. Otherwise the first class whose name occurs
    as a substring of the phrase is used. Phrases that match no class get the id
    ``len(classes)``.

    Args:
        phrases: Phrases produced for each detection.
        classes: Class names used to build the prompt.

    Returns:
        int64 tensor with one class id per phrase (possibly empty).
    """
    normalized_classes = [class_name.lower().strip() for class_name in classes]
    exact_lookup = {}
    for class_idx, class_name in enumerate(normalized_classes):
        # Keep the first index for duplicated names.
        exact_lookup.setdefault(class_name, class_idx)
    # TODO: figure out how to mark additional classes
    unknown_class_id = len(classes)
    class_ids = []
    for phrase in phrases:
        normalized_phrase = phrase.lower().strip()
        class_id = exact_lookup.get(normalized_phrase)
        if class_id is None:
            class_id = next(
                (
                    class_idx
                    for class_idx, class_name in enumerate(normalized_classes)
                    if class_name in normalized_phrase
                ),
                unknown_class_id,
            )
        class_ids.append(class_id)
    # Explicit dtype so that an empty result is int64 rather than float32.
    return torch.tensor(class_ids, dtype=torch.int64)
|
|
File without changes
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from threading import Lock
|
|
3
|
+
from typing import List, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
from torchvision import transforms
|
|
8
|
+
|
|
9
|
+
from inference_models.configuration import DEFAULT_DEVICE
|
|
10
|
+
from inference_models.entities import ColorFormat
|
|
11
|
+
from inference_models.errors import (
|
|
12
|
+
EnvironmentConfigurationError,
|
|
13
|
+
MissingDependencyError,
|
|
14
|
+
ModelRuntimeError,
|
|
15
|
+
)
|
|
16
|
+
from inference_models.models.base.types import PreprocessedInputs
|
|
17
|
+
from inference_models.models.common.model_packages import get_model_package_contents
|
|
18
|
+
from inference_models.models.common.onnx import (
|
|
19
|
+
run_session_via_iobinding,
|
|
20
|
+
run_session_with_batch_size_limit,
|
|
21
|
+
set_execution_provider_defaults,
|
|
22
|
+
)
|
|
23
|
+
from inference_models.utils.onnx_introspection import (
|
|
24
|
+
get_selected_onnx_execution_providers,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
try:
|
|
28
|
+
import onnxruntime
|
|
29
|
+
except ImportError as import_error:
|
|
30
|
+
raise MissingDependencyError(
|
|
31
|
+
message=f"Could not import L2CS model with ONNX backend - this error means that some additional dependencies "
|
|
32
|
+
f"are not installed in the environment. If you run the `inference-models` library directly in your Python "
|
|
33
|
+
f"program, make sure the following extras of the package are installed: \n"
|
|
34
|
+
f"\t* `onnx-cpu` - when you wish to use library with CPU support only\n"
|
|
35
|
+
f"\t* `onnx-cu12` - for running on GPU with Cuda 12 installed\n"
|
|
36
|
+
f"\t* `onnx-cu118` - for running on GPU with Cuda 11.8 installed\n"
|
|
37
|
+
f"\t* `onnx-jp6-cu126` - for running on Jetson with Jetpack 6\n"
|
|
38
|
+
f"If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
|
|
39
|
+
f"You can also contact Roboflow to get support.",
|
|
40
|
+
help_url="https://todo",
|
|
41
|
+
) from import_error
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Default value of the `max_batch_size` argument of `L2CSNetOnnx.from_pretrained`.
DEFAULT_GAZE_MAX_BATCH_SIZE = 8
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass()
class L2CSGazeDetection:
    """Gaze angles produced by ``L2CSNetOnnx``.

    Both fields hold the raw ONNX session outputs, unchanged. Units and shape are
    not established here; presumably there is one value per input image.
    TODO confirm.
    """

    # First session output.
    yaw: torch.Tensor
    # Second session output.
    pitch: torch.Tensor
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class L2CSNetOnnx:
    """L2CS-Net gaze estimator executed through ONNX Runtime.

    Images are resized to 448x448 and normalised, then run through an ONNX
    session whose two outputs are returned as yaw and pitch.
    """

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
        default_onnx_trt_options: bool = True,
        device: torch.device = DEFAULT_DEVICE,
        max_batch_size: int = DEFAULT_GAZE_MAX_BATCH_SIZE,
        **kwargs,
    ):
        """Create an ONNX session from a model package directory.

        Args:
            model_name_or_path: Package directory containing ``weights.onnx``.
            onnx_execution_providers: ONNX Runtime providers; if ``None``, they
                are taken from ``get_selected_onnx_execution_providers()``.
            default_onnx_trt_options: Forwarded to
                ``set_execution_provider_defaults``.
            device: Device on which pre-processed inputs are placed.
            max_batch_size: Batch-size limit forwarded to the session runner.
            **kwargs: Not used.

        Returns:
            A new ``L2CSNetOnnx`` instance.

        Raises:
            EnvironmentConfigurationError: when no execution provider is set.
        """
        if onnx_execution_providers is None:
            onnx_execution_providers = get_selected_onnx_execution_providers()
        if not onnx_execution_providers:
            raise EnvironmentConfigurationError(
                message=f"Could not initialize model - selected backend is ONNX which requires execution provider to "
                f"be specified - explicitly in `from_pretrained(...)` method or via env variable "
                f"`ONNXRUNTIME_EXECUTION_PROVIDERS`. If you run model locally - adjust your setup, otherwise "
                f"contact the platform support.",
                help_url="https://todo",
            )
        onnx_execution_providers = set_execution_provider_defaults(
            providers=onnx_execution_providers,
            model_package_path=model_name_or_path,
            device=device,
            default_onnx_trt_options=default_onnx_trt_options,
        )
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=["weights.onnx"],
        )
        session = onnxruntime.InferenceSession(
            path_or_bytes=model_package_content["weights.onnx"],
            providers=onnx_execution_providers,
        )
        # The model is assumed to have a single image input.
        input_name = session.get_inputs()[0].name
        return cls(
            session=session,
            max_batch_size=max_batch_size,
            device=device,
            input_name=input_name,
        )

    def __init__(
        self,
        session: onnxruntime.InferenceSession,
        max_batch_size: int,
        device: torch.device,
        input_name: str,
    ):
        """Store the session and build the input transformation pipelines.

        Args:
            session: ONNX Runtime session of the gaze model.
            max_batch_size: Batch-size limit used when running the session.
            device: Device on which pre-processed inputs are placed.
            input_name: Name of the session input fed with images.
        """
        self._session = session
        self._max_batch_size = max_batch_size
        self._device = device
        self._input_name = input_name
        # Serialises access to the shared session across threads.
        self._session_thread_lock = Lock()
        # numpy path: HWC array -> CHW float tensor via ToTensor, then resize + normalise.
        self._numpy_transformations = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize([448, 448]),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        # tensor path: inputs are divided by 255 here, i.e. expected in [0, 255].
        self._tensors_transformations = transforms.Compose(
            [
                lambda x: x / 255.0,
                transforms.Resize([448, 448]),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    @property
    def device(self) -> torch.device:
        """Device on which pre-processed inputs are placed."""
        return self._device

    def infer(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        **kwargs,
    ) -> L2CSGazeDetection:
        """Run pre-processing, the model and post-processing in sequence."""
        pre_processed_images = self.pre_process(images, **kwargs)
        model_results = self.forward(pre_processed_images, **kwargs)
        return self.post_process(model_results, **kwargs)

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Convert raw images into a normalised batch on ``self._device``.

        Accepts a single HWC numpy array, a CHW / NCHW tensor, or a non-empty
        list of either. numpy inputs default to ``"bgr"`` colour order and
        tensors to ``"rgb"``; anything other than ``"rgb"`` has its channels
        reversed.

        Raises:
            ModelRuntimeError: on an unsupported input type or an empty list.
        """
        if isinstance(images, np.ndarray):
            input_color_format = input_color_format or "bgr"
            if input_color_format != "rgb":
                images = np.ascontiguousarray(images[:, :, ::-1])
            pre_processed = self._numpy_transformations(images)
            return torch.unsqueeze(pre_processed, dim=0).to(self._device)
        if isinstance(images, torch.Tensor):
            input_color_format = input_color_format or "rgb"
            if len(images.shape) == 3:
                images = torch.unsqueeze(images, dim=0)
            images = images.to(self._device)
            if input_color_format != "rgb":
                images = images[:, [2, 1, 0], :, :]
            return self._tensors_transformations(images.float())
        if not isinstance(images, list):
            raise ModelRuntimeError(
                message="Pre-processing supports only np.array or torch.Tensor or list of above.",
                help_url="https://todo",
            )
        if not images:
            raise ModelRuntimeError(
                message="Detected empty input to the model", help_url="https://todo"
            )
        if isinstance(images[0], np.ndarray):
            input_color_format = input_color_format or "bgr"
            pre_processed = []
            for image in images:
                if input_color_format != "rgb":
                    image = np.ascontiguousarray(image[:, :, ::-1])
                pre_processed.append(self._numpy_transformations(image))
            return torch.stack(pre_processed, dim=0).to(self._device)
        if isinstance(images[0], torch.Tensor):
            input_color_format = input_color_format or "rgb"
            pre_processed = []
            for image in images:
                if len(image.shape) == 3:
                    image = torch.unsqueeze(image, dim=0)
                if input_color_format != "rgb":
                    image = image[:, [2, 1, 0], :, :]
                pre_processed.append(self._tensors_transformations(image.float()))
            return torch.cat(pre_processed, dim=0).to(self._device)
        raise ModelRuntimeError(
            message=f"Detected unknown input batch element: {type(images[0])}",
            help_url="https://todo",
        )

    def forward(
        self,
        pre_processed_images: PreprocessedInputs,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the ONNX session and return its two outputs as ``(yaw, pitch)``."""
        with self._session_thread_lock:
            yaw, pitch = run_session_with_batch_size_limit(
                session=self._session,
                inputs={self._input_name: pre_processed_images},
                # Forward the configured limit so `max_batch_size` from
                # `from_pretrained` actually takes effect.
                max_batch_size=self._max_batch_size,
            )
        return yaw, pitch

    def post_process(
        self,
        model_results: Tuple[torch.Tensor, torch.Tensor],
        **kwargs,
    ) -> L2CSGazeDetection:
        """Wrap the raw ``(yaw, pitch)`` outputs into ``L2CSGazeDetection``."""
        return L2CSGazeDetection(yaw=model_results[0], pitch=model_results[1])

    def __call__(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        **kwargs,
    ) -> L2CSGazeDetection:
        """Alias for ``infer``."""
        return self.infer(images, **kwargs)
|
|
File without changes
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
from threading import Lock
|
|
2
|
+
from typing import List, Optional, Tuple, Union
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from inference_models import Detections, KeyPoints, KeyPointsDetectionModel
|
|
8
|
+
from inference_models.entities import ColorFormat, ImageDimensions
|
|
9
|
+
from inference_models.errors import MissingDependencyError, ModelRuntimeError
|
|
10
|
+
from inference_models.models.common.model_packages import get_model_package_contents
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
import mediapipe as mp
|
|
14
|
+
from mediapipe.tasks.python.components.containers import Detection
|
|
15
|
+
except ImportError as import_error:
|
|
16
|
+
raise MissingDependencyError(
|
|
17
|
+
message=f"Could not import face detection model from MediaPipe - this error means that some additional "
|
|
18
|
+
f"dependencies are not installed in the environment. If you run the `inference-models` library directly in your Python "
|
|
19
|
+
f"program, make sure the following extras of the package are installed: `mediapipe`."
|
|
20
|
+
f"If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
|
|
21
|
+
f"You can also contact Roboflow to get support.",
|
|
22
|
+
help_url="https://todo",
|
|
23
|
+
) from import_error
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class MediaPipeFaceDetector(
    KeyPointsDetectionModel[List[mp.Image], ImageDimensions, List[Detection]]
):
    """Face detector wrapping MediaPipe's Tasks ``FaceDetector`` (TFLite model).

    For every input image it reports ``face`` bounding boxes together with six
    facial key points per face (see ``key_points_classes``).
    """

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        **kwargs,
    ) -> "MediaPipeFaceDetector":
        """Build the detector from a model package directory.

        Args:
            model_name_or_path: Model package directory; expected to contain
                ``mediapipe_face_detector.tflite``.
            **kwargs: Ignored.

        Returns:
            A detector configured with ``RunningMode.IMAGE`` (single images).
        """
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=["mediapipe_face_detector.tflite"],
        )
        face_detector = mp.tasks.vision.FaceDetector.create_from_options(
            mp.tasks.vision.FaceDetectorOptions(
                base_options=mp.tasks.BaseOptions(
                    model_asset_path=model_package_content[
                        "mediapipe_face_detector.tflite"
                    ]
                ),
                running_mode=mp.tasks.vision.RunningMode.IMAGE,
            )
        )
        return cls(face_detector=face_detector)

    def __init__(self, face_detector: mp.tasks.vision.FaceDetector):
        self._face_detector = face_detector
        # Serialises calls to the shared detector in ``forward``.
        self._thread_lock = Lock()

    @property
    def class_names(self) -> List[str]:
        """The single detection class produced by this model."""
        return ["face"]

    @property
    def key_points_classes(self) -> List[List[str]]:
        """Key-point names for the (only) ``face`` class, in output order."""
        return [["right-eye", "left-eye", "nose", "mouth", "right-ear", "left-ear"]]

    @property
    def skeletons(self) -> List[List[Tuple[int, int]]]:
        """Key-point index pairs forming the face skeleton.

        The indices appear to refer to positions in ``key_points_classes[0]``.
        """
        return [[(5, 1), (1, 2), (4, 0), (0, 2), (2, 3)]]

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        **kwargs,
    ) -> Tuple[List[mp.Image], List[ImageDimensions]]:
        """Convert input images into MediaPipe ``mp.Image`` objects.

        Args:
            images: A single image or a list of images. ``np.ndarray`` inputs
                are indexed as HWC and default to BGR. ``torch.Tensor`` inputs
                are indexed as CHW (or an NCHW batch) and default to RGB.
            input_color_format: Channel order of the input. Any value other
                than ``"rgb"`` is treated as BGR and reordered to RGB.
            **kwargs: Ignored.

        Returns:
            The converted images and, for each, its (height, width).

        Raises:
            ModelRuntimeError: If the input is not an array/tensor or a list of
                them, if the list is empty, or if its elements have an
                unsupported type.
        """
        if isinstance(images, np.ndarray):
            converted = [
                self._numpy_to_mp_image(images, input_color_format or "bgr")
            ]
        elif isinstance(images, torch.Tensor):
            color_format = input_color_format or "rgb"
            if len(images.shape) == 3:
                images = torch.unsqueeze(images, dim=0)
            # Iterating an NCHW tensor yields one CHW tensor per image.
            converted = [
                self._tensor_to_mp_image(image, color_format) for image in images
            ]
        else:
            if not isinstance(images, list):
                raise ModelRuntimeError(
                    message="Pre-processing supports only np.array or torch.Tensor or list of above.",
                    help_url="https://todo",
                )
            if not len(images):
                raise ModelRuntimeError(
                    message="Detected empty input to the model", help_url="https://todo"
                )
            # The type of the first element decides how the whole list is handled.
            if isinstance(images[0], np.ndarray):
                color_format = input_color_format or "bgr"
                converted = [
                    self._numpy_to_mp_image(image, color_format) for image in images
                ]
            elif isinstance(images[0], torch.Tensor):
                color_format = input_color_format or "rgb"
                converted = [
                    self._tensor_to_mp_image(image, color_format) for image in images
                ]
            else:
                raise ModelRuntimeError(
                    message=f"Detected unknown input batch element: {type(images[0])}",
                    help_url="https://todo",
                )
        preprocessed_images = [mp_image for mp_image, _ in converted]
        dimensions = [image_dimensions for _, image_dimensions in converted]
        return preprocessed_images, dimensions

    @staticmethod
    def _to_mp_image(rgb_image: np.ndarray) -> Tuple[mp.Image, ImageDimensions]:
        """Wrap an HWC RGB array as an SRGB ``mp.Image`` and report its size."""
        # ``astype`` keeps the source memory layout (numpy's default order="K").
        # Permuted inputs such as CHW tensors viewed as HWC are therefore not
        # C-contiguous after the cast. Force a C-contiguous uint8 buffer so
        # every input path hands mp.Image the same layout.
        data = np.ascontiguousarray(rgb_image.astype(np.uint8))
        dimensions = ImageDimensions(height=data.shape[0], width=data.shape[1])
        return mp.Image(image_format=mp.ImageFormat.SRGB, data=data), dimensions

    @classmethod
    def _numpy_to_mp_image(
        cls, image: np.ndarray, input_color_format: str
    ) -> Tuple[mp.Image, ImageDimensions]:
        """Convert one HWC numpy image, reversing channels unless it is RGB."""
        if input_color_format != "rgb":
            image = image[:, :, ::-1]
        return cls._to_mp_image(image)

    @classmethod
    def _tensor_to_mp_image(
        cls, image: torch.Tensor, input_color_format: str
    ) -> Tuple[mp.Image, ImageDimensions]:
        """Convert one CHW tensor, reordering channels to RGB unless it is RGB."""
        if input_color_format != "rgb":
            image = image[[2, 1, 0], :, :]
        # detach() lets tensors that require grad be converted to numpy.
        return cls._to_mp_image(image.detach().cpu().permute(1, 2, 0).numpy())

    def forward(
        self, pre_processed_images: List[mp.Image], **kwargs
    ) -> List[List[Detection]]:
        """Run the MediaPipe detector on each image in turn.

        Calls to the shared detector are serialised with ``self._thread_lock``.

        Args:
            pre_processed_images: Images produced by ``pre_process``.
            **kwargs: Ignored.

        Returns:
            One list of MediaPipe ``Detection`` objects per input image.
        """
        with self._thread_lock:
            return [
                self._face_detector.detect(image=mp_image).detections
                for mp_image in pre_processed_images
            ]

    def post_process(
        self,
        model_results: List[List[Detection]],
        pre_processing_meta: List[ImageDimensions],
        conf_thresh: float = 0.25,
        **kwargs,
    ) -> Tuple[List[KeyPoints], List[Detections]]:
        """Convert MediaPipe detections into ``KeyPoints`` and ``Detections``.

        Args:
            model_results: Per-image MediaPipe detections from ``forward``.
            pre_processing_meta: Per-image sizes from ``pre_process``, used to
                scale key-point coordinates to pixels.
            conf_thresh: Detections whose first category score is below this
                value are dropped.
            **kwargs: Ignored.

        Returns:
            ``(key_points, detections)``, each a list with one entry per image.
            Boxes and key points are rounded to integer pixel coordinates.
            Every detection gets class id 0 (``"face"``), and key-point
            confidences are fixed at 1.0. An image with no kept detections
            yields tensors with a leading dimension of 0 and the same trailing
            shape as non-empty results: ``(0, 4)`` boxes, ``(0, K, 2)`` key
            points and ``(0, K)`` key-point confidences.
        """
        num_key_points = len(self.key_points_classes[0])
        final_key_points, final_detections = [], []
        for image_results, image_dimensions in zip(model_results, pre_processing_meta):
            boxes, scores, key_points_xy = [], [], []
            for detection in image_results:
                score = detection.categories[0].score
                if score < conf_thresh:
                    continue
                box = detection.bounding_box
                boxes.append(
                    (
                        box.origin_x,
                        box.origin_y,
                        box.origin_x + box.width,
                        box.origin_y + box.height,
                    )
                )
                scores.append(score)
                # Key points are multiplied by the image size; presumably
                # MediaPipe reports them normalised to [0, 1].
                key_points_xy.append(
                    [
                        (kp.x * image_dimensions.width, kp.y * image_dimensions.height)
                        for kp in detection.keypoints
                    ]
                )
            num_detections = len(boxes)
            # Explicit shapes keep empty results at (0, 4) / (0, K, 2) instead of
            # collapsing to (0,) as ``torch.tensor([])`` would.
            xyxy = torch.tensor(boxes, dtype=torch.float32).reshape(num_detections, 4)
            if key_points_xy:
                xy = torch.tensor(key_points_xy, dtype=torch.float32)
            else:
                xy = torch.zeros((0, num_key_points, 2), dtype=torch.float32)
            class_id = torch.zeros(num_detections, dtype=torch.int32)
            final_detections.append(
                Detections(
                    xyxy=xyxy.round().int(),
                    class_id=class_id,
                    confidence=torch.tensor(scores, dtype=torch.float32),
                )
            )
            final_key_points.append(
                KeyPoints(
                    xy=xy.round().int(),
                    class_id=class_id.clone(),
                    confidence=torch.ones(xy.shape[:2], dtype=torch.float32),
                )
            )
        return final_key_points, final_detections
|
|
File without changes
|