inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/grounding_dino/grounding_dino_torch.py
@@ -0,0 +1,227 @@
+ import os.path
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ import torchvision
+ from groundingdino.util.inference import load_model, predict
+ from torch import nn
+ from torchvision import transforms
+ from torchvision.ops import box_convert
+
+ from inference_models import Detections
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat, ImageDimensions
+ from inference_models.errors import ModelRuntimeError
+ from inference_models.models.base.object_detection import (
+     OpenVocabularyObjectDetectionModel,
+ )
+ from inference_models.models.common.model_packages import get_model_package_contents
+
+
+ class GroundingDinoForObjectDetectionTorch(
+     OpenVocabularyObjectDetectionModel[
+         torch.Tensor,
+         List[ImageDimensions],
+         Tuple[List[torch.Tensor], List[torch.Tensor], List[List[str]], List[str]],
+     ]
+ ):
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         **kwargs,
+     ) -> "GroundingDinoForObjectDetectionTorch":
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=["weights.pth", "config.py"],
+         )
+         text_encoder_dir = os.path.join(model_name_or_path, "text_encoder")
+         loader_kwargs = {}
+         if os.path.isdir(text_encoder_dir):
+             loader_kwargs["text_encoder_type"] = text_encoder_dir
+         model = load_model(
+             model_config_path=model_package_content["config.py"],
+             model_checkpoint_path=model_package_content["weights.pth"],
+             **loader_kwargs,
+         ).to(device)
+         return cls(model=model, device=device)
+
+     def __init__(
+         self,
+         model: nn.Module,
+         device: torch.device,
+     ):
+         self._model = model
+         self._device = device
+         self._numpy_transformations = transforms.Compose(
+             [
+                 transforms.ToTensor(),
+                 transforms.Resize([800, 800]),
+                 transforms.Normalize(
+                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                 ),
+             ]
+         )
+         self._tensors_transformations = transforms.Compose(
+             [
+                 lambda x: x / 255.0,
+                 transforms.Resize([800, 800]),
+                 transforms.Normalize(
+                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                 ),
+             ]
+         )
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, List[ImageDimensions]]:
+         if isinstance(images, np.ndarray):
+             input_color_format = input_color_format or "bgr"
+             if input_color_format != "rgb":
+                 images = np.ascontiguousarray(images[:, :, ::-1])
+             pre_processed = self._numpy_transformations(images)
+             return (
+                 torch.unsqueeze(pre_processed, dim=0).to(self._device),
+                 [ImageDimensions(height=images.shape[0], width=images.shape[1])],
+             )
+         if isinstance(images, torch.Tensor):
+             input_color_format = input_color_format or "rgb"
+             if len(images.shape) == 3:
+                 images = torch.unsqueeze(images, dim=0)
+             image_dimensions = ImageDimensions(
+                 height=images.shape[2], width=images.shape[3]
+             )
+             images = images.to(self._device)
+             if input_color_format != "rgb":
+                 images = images[:, [2, 1, 0], :, :]
+             return (
+                 self._tensors_transformations(images.float()),
+                 [image_dimensions] * images.shape[0],
+             )
+         if not isinstance(images, list):
+             raise ModelRuntimeError(
+                 message="Pre-processing supports only np.array or torch.Tensor or list of above.",
+                 help_url="https://todo",
+             )
+         if not len(images):
+             raise ModelRuntimeError(
+                 message="Detected empty input to the model",
+                 help_url="https://todo",
+             )
+         if isinstance(images[0], np.ndarray):
+             input_color_format = input_color_format or "bgr"
+             pre_processed, image_dimensions = [], []
+             for image in images:
+                 if input_color_format != "rgb":
+                     image = np.ascontiguousarray(image[:, :, ::-1])
+                 image_dimensions.append(
+                     ImageDimensions(height=image.shape[0], width=image.shape[1])
+                 )
+                 pre_processed.append(self._numpy_transformations(image))
+             return torch.stack(pre_processed, dim=0).to(self._device), image_dimensions
+         if isinstance(images[0], torch.Tensor):
+             input_color_format = input_color_format or "rgb"
+             pre_processed, image_dimensions = [], []
+             for image in images:
+                 if len(image.shape) == 3:
+                     image = torch.unsqueeze(image, dim=0)
+                 if input_color_format != "rgb":
+                     image = image[:, [2, 1, 0], :, :]
+                 image_dimensions.append(
+                     ImageDimensions(height=image.shape[2], width=image.shape[3])
+                 )
+                 pre_processed.append(self._tensors_transformations(image.float()))
+             return torch.cat(pre_processed, dim=0).to(self._device), image_dimensions
+         raise ModelRuntimeError(
+             message=f"Detected unknown input batch element: {type(images[0])}",
+             help_url="https://todo",
+         )
+
+     def forward(
+         self,
+         pre_processed_images: torch.Tensor,
+         classes: List[str],
+         conf_thresh: float = 0.5,
+         text_threshold: Optional[float] = None,
+         **kwargs,
+     ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[str]], List[str]]:
+         if text_threshold is None:
+             text_threshold = conf_thresh
+         caption = ". ".join(classes)
+         all_boxes, all_logits, all_phrases = [], [], []
+         with torch.inference_mode():
+             for image in pre_processed_images:
+                 boxes, logits, phrases = predict(
+                     model=self._model,
+                     image=image,
+                     caption=caption,
+                     box_threshold=conf_thresh,
+                     text_threshold=text_threshold,
+                     device=self._device,
+                     remove_combined=True,
+                 )
+                 all_boxes.append(boxes)
+                 all_logits.append(logits)
+                 all_phrases.append(phrases)
+         return all_boxes, all_logits, all_phrases, classes
+
+     def post_process(
+         self,
+         model_results: Tuple[
+             List[torch.Tensor], List[torch.Tensor], List[List[str]], List[str]
+         ],
+         pre_processing_meta: List[ImageDimensions],
+         iou_thresh: float = 0.45,
+         max_detections: int = 100,
+         class_agnostic: bool = False,
+         **kwargs,
+     ) -> List[Detections]:
+         all_boxes, all_logits, all_phrases, classes = model_results
+         results = []
+         for boxes, logits, phrases, origin_size in zip(
+             all_boxes, all_logits, all_phrases, pre_processing_meta
+         ):
+             boxes = boxes * torch.Tensor(
+                 [
+                     origin_size.width,
+                     origin_size.height,
+                     origin_size.width,
+                     origin_size.height,
+                 ],
+                 device=boxes.device,
+             )
+             xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy")
+             class_id = map_phrases_to_classes(
+                 phrases=phrases,
+                 classes=classes,
+             ).to(boxes.device)
+             nms_class_ids = torch.zeros_like(class_id) if class_agnostic else class_id
+             keep = torchvision.ops.batched_nms(xyxy, logits, nms_class_ids, iou_thresh)
+             if keep.numel() > max_detections:
+                 keep = keep[:max_detections]
+             results.append(
+                 Detections(
+                     xyxy=xyxy[keep].round().int(),
+                     confidence=logits[keep],
+                     class_id=class_id[keep].int(),
+                 ),
+             )
+         return results
+
+
+ def map_phrases_to_classes(phrases: List[str], classes: List[str]) -> torch.Tensor:
+     class_ids = []
+     for phrase in phrases:
+         for class_ in classes:
+             if class_ in phrase:
+                 class_ids.append(classes.index(class_))
+                 break
+         else:
+             # TODO: figure out how to mark additional classes
+             class_ids.append(len(classes))
+     return torch.tensor(class_ids)
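
Below is a minimal usage sketch for the class above (not part of the package diff). It assumes a local model package directory at the hypothetical path "./grounding_dino_package" containing the expected weights.pth and config.py, and it simply chains the pre_process, forward, and post_process methods defined in the file; the class list and thresholds are illustrative.

# Hedged usage sketch: wiring the three stages shown above.
import numpy as np
import torch

from inference_models.models.grounding_dino.grounding_dino_torch import (
    GroundingDinoForObjectDetectionTorch,
)

# "./grounding_dino_package" is a hypothetical local model package directory.
model = GroundingDinoForObjectDetectionTorch.from_pretrained(
    "./grounding_dino_package",
    device=torch.device("cpu"),
)

# Stand-in BGR frame (the default color format for numpy inputs in pre_process).
image = np.zeros((480, 640, 3), dtype=np.uint8)

# Explicitly chain the stages defined above: pre-processing, forward pass, NMS post-processing.
tensors, dimensions = model.pre_process(image)
raw_results = model.forward(tensors, classes=["person", "dog"], conf_thresh=0.35)
detections = model.post_process(raw_results, dimensions, iou_thresh=0.45)
print(detections[0].xyxy, detections[0].class_id, detections[0].confidence)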
inference_models/models/l2cs/l2cs_onnx.py
@@ -0,0 +1,216 @@
+ from dataclasses import dataclass
+ from threading import Lock
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from torchvision import transforms
+
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import (
+     EnvironmentConfigurationError,
+     MissingDependencyError,
+     ModelRuntimeError,
+ )
+ from inference_models.models.base.types import PreprocessedInputs
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.common.onnx import (
+     run_session_via_iobinding,
+     run_session_with_batch_size_limit,
+     set_execution_provider_defaults,
+ )
+ from inference_models.utils.onnx_introspection import (
+     get_selected_onnx_execution_providers,
+ )
+
+ try:
+     import onnxruntime
+ except ImportError as import_error:
+     raise MissingDependencyError(
+         message=f"Could not import L2CS model with ONNX backend - this error means that some additional dependencies "
+         f"are not installed in the environment. If you run the `inference-models` library directly in your Python "
+         f"program, make sure the following extras of the package are installed: \n"
+         f"\t* `onnx-cpu` - when you wish to use library with CPU support only\n"
+         f"\t* `onnx-cu12` - for running on GPU with Cuda 12 installed\n"
+         f"\t* `onnx-cu118` - for running on GPU with Cuda 11.8 installed\n"
+         f"\t* `onnx-jp6-cu126` - for running on Jetson with Jetpack 6\n"
+         f"If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
+         f"You can also contact Roboflow to get support.",
+         help_url="https://todo",
+     ) from import_error
+
+
+ DEFAULT_GAZE_MAX_BATCH_SIZE = 8
+
+
+ @dataclass
+ class L2CSGazeDetection:
+     yaw: torch.Tensor
+     pitch: torch.Tensor
+
+
+ class L2CSNetOnnx:
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+         default_onnx_trt_options: bool = True,
+         device: torch.device = DEFAULT_DEVICE,
+         max_batch_size: int = DEFAULT_GAZE_MAX_BATCH_SIZE,
+         **kwargs,
+     ):
+         if onnx_execution_providers is None:
+             onnx_execution_providers = get_selected_onnx_execution_providers()
+         if not onnx_execution_providers:
+             raise EnvironmentConfigurationError(
+                 message=f"Could not initialize model - selected backend is ONNX which requires execution provider to "
+                 f"be specified - explicitly in `from_pretrained(...)` method or via env variable "
+                 f"`ONNXRUNTIME_EXECUTION_PROVIDERS`. If you run model locally - adjust your setup, otherwise "
+                 f"contact the platform support.",
+                 help_url="https://todo",
+             )
+         onnx_execution_providers = set_execution_provider_defaults(
+             providers=onnx_execution_providers,
+             model_package_path=model_name_or_path,
+             device=device,
+             default_onnx_trt_options=default_onnx_trt_options,
+         )
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=["weights.onnx"],
+         )
+         session = onnxruntime.InferenceSession(
+             path_or_bytes=model_package_content["weights.onnx"],
+             providers=onnx_execution_providers,
+         )
+         input_name = session.get_inputs()[0].name
+         return cls(
+             session=session,
+             max_batch_size=max_batch_size,
+             device=device,
+             input_name=input_name,
+         )
+
+     def __init__(
+         self,
+         session: onnxruntime.InferenceSession,
+         max_batch_size: int,
+         device: torch.device,
+         input_name: str,
+     ):
+         self._session = session
+         self._max_batch_size = max_batch_size
+         self._device = device
+         self._input_name = input_name
+         self._session_thread_lock = Lock()
+         self._numpy_transformations = transforms.Compose(
+             [
+                 transforms.ToTensor(),
+                 transforms.Resize([448, 448]),
+                 transforms.Normalize(
+                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                 ),
+             ]
+         )
+         self._tensors_transformations = transforms.Compose(
+             [
+                 lambda x: x / 255.0,
+                 transforms.Resize([448, 448]),
+                 transforms.Normalize(
+                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                 ),
+             ]
+         )
+
+     @property
+     def device(self) -> torch.device:
+         return self._device
+
+     def infer(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         **kwargs,
+     ) -> L2CSGazeDetection:
+         pre_processed_images = self.pre_process(images, **kwargs)
+         model_results = self.forward(pre_processed_images, **kwargs)
+         return self.post_process(model_results, **kwargs)
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         if isinstance(images, np.ndarray):
+             input_color_format = input_color_format or "bgr"
+             if input_color_format != "rgb":
+                 images = np.ascontiguousarray(images[:, :, ::-1])
+             pre_processed = self._numpy_transformations(images)
+             return torch.unsqueeze(pre_processed, dim=0).to(self._device)
+         if isinstance(images, torch.Tensor):
+             input_color_format = input_color_format or "rgb"
+             if len(images.shape) == 3:
+                 images = torch.unsqueeze(images, dim=0)
+             images = images.to(self._device)
+             if input_color_format != "rgb":
+                 images = images[:, [2, 1, 0], :, :]
+             return self._tensors_transformations(images.float())
+         if not isinstance(images, list):
+             raise ModelRuntimeError(
+                 message="Pre-processing supports only np.array or torch.Tensor or list of above.",
+                 help_url="https://todo",
+             )
+         if not len(images):
+             raise ModelRuntimeError(
+                 message="Detected empty input to the model", help_url="https://todo"
+             )
+         if isinstance(images[0], np.ndarray):
+             input_color_format = input_color_format or "bgr"
+             pre_processed = []
+             for image in images:
+                 if input_color_format != "rgb":
+                     image = np.ascontiguousarray(image[:, :, ::-1])
+                 pre_processed.append(self._numpy_transformations(image))
+             return torch.stack(pre_processed, dim=0).to(self._device)
+         if isinstance(images[0], torch.Tensor):
+             input_color_format = input_color_format or "rgb"
+             pre_processed = []
+             for image in images:
+                 if len(image.shape) == 3:
+                     image = torch.unsqueeze(image, dim=0)
+                 if input_color_format != "rgb":
+                     image = image[:, [2, 1, 0], :, :]
+                 pre_processed.append(self._tensors_transformations(image.float()))
+             return torch.cat(pre_processed, dim=0).to(self._device)
+         raise ModelRuntimeError(
+             message=f"Detected unknown input batch element: {type(images[0])}",
+             help_url="https://todo",
+         )
+
+     def forward(
+         self,
+         pre_processed_images: PreprocessedInputs,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         with self._session_thread_lock:
+             yaw, pitch = run_session_with_batch_size_limit(
+                 session=self._session, inputs={self._input_name: pre_processed_images}
+             )
+             return yaw, pitch
+
+     def post_process(
+         self,
+         model_results: Tuple[torch.Tensor, torch.Tensor],
+         **kwargs,
+     ) -> L2CSGazeDetection:
+         return L2CSGazeDetection(yaw=model_results[0], pitch=model_results[1])
+
+     def __call__(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         **kwargs,
+     ) -> L2CSGazeDetection:
+         return self.infer(images, **kwargs)
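
A minimal usage sketch for the gaze model above (not part of the package diff), assuming the `onnx-cpu` extra is installed and a hypothetical local model package directory "./l2cs_package" containing weights.onnx; the __call__ and infer entry points are the ones defined in the file.

# Hedged usage sketch: running the L2CS ONNX gaze model on CPU.
import numpy as np

from inference_models.models.l2cs.l2cs_onnx import L2CSNetOnnx

# "./l2cs_package" is a hypothetical model package directory with "weights.onnx";
# "CPUExecutionProvider" assumes a CPU-only onnxruntime install.
model = L2CSNetOnnx.from_pretrained(
    "./l2cs_package",
    onnx_execution_providers=["CPUExecutionProvider"],
)

# Stand-in BGR face crop; pre_process resizes and normalizes to 448x448 internally.
face_crop = np.zeros((448, 448, 3), dtype=np.uint8)
gaze = model(face_crop)  # __call__ delegates to infer()
print(gaze.yaw, gaze.pitch)  # fields of L2CSGazeDetection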
inference_models/models/mediapipe_face_detection/face_detection.py
@@ -0,0 +1,203 @@
+ from threading import Lock
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from inference_models import Detections, KeyPoints, KeyPointsDetectionModel
+ from inference_models.entities import ColorFormat, ImageDimensions
+ from inference_models.errors import MissingDependencyError, ModelRuntimeError
+ from inference_models.models.common.model_packages import get_model_package_contents
+
+ try:
+     import mediapipe as mp
+     from mediapipe.tasks.python.components.containers import Detection
+ except ImportError as import_error:
+     raise MissingDependencyError(
+         message=f"Could not import face detection model from MediaPipe - this error means that some additional "
+         f"dependencies are not installed in the environment. If you run the `inference-models` library directly in your Python "
+         f"program, make sure the following extras of the package are installed: `mediapipe`."
+         f"If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
+         f"You can also contact Roboflow to get support.",
+         help_url="https://todo",
+     ) from import_error
+
+
+ class MediaPipeFaceDetector(
+     KeyPointsDetectionModel[List[mp.Image], ImageDimensions, List[Detection]]
+ ):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         **kwargs,
+     ) -> "MediaPipeFaceDetector":
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=["mediapipe_face_detector.tflite"],
+         )
+         face_detector = mp.tasks.vision.FaceDetector.create_from_options(
+             mp.tasks.vision.FaceDetectorOptions(
+                 base_options=mp.tasks.BaseOptions(
+                     model_asset_path=model_package_content[
+                         "mediapipe_face_detector.tflite"
+                     ]
+                 ),
+                 running_mode=mp.tasks.vision.RunningMode.IMAGE,
+             )
+         )
+         return cls(face_detector=face_detector)
+
+     def __init__(self, face_detector: mp.tasks.vision.FaceDetector):
+         self._face_detector = face_detector
+         self._thread_lock = Lock()
+
+     @property
+     def class_names(self) -> List[str]:
+         return ["face"]
+
+     @property
+     def key_points_classes(self) -> List[List[str]]:
+         return [["right-eye", "left-eye", "nose", "mouth", "right-ear", "left-ear"]]
+
+     @property
+     def skeletons(self) -> List[List[Tuple[int, int]]]:
+         return [[(5, 1), (1, 2), (4, 0), (0, 2), (2, 3)]]
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> Tuple[List[mp.Image], List[ImageDimensions]]:
+         if isinstance(images, np.ndarray):
+             input_color_format = input_color_format or "bgr"
+             if input_color_format != "rgb":
+                 images = np.ascontiguousarray(images[:, :, ::-1])
+             preprocessed_images = mp.Image(
+                 image_format=mp.ImageFormat.SRGB, data=images.astype(np.uint8)
+             )
+             dimensions = ImageDimensions(height=images.shape[0], width=images.shape[1])
+             return [preprocessed_images], [dimensions]
+         if isinstance(images, torch.Tensor):
+             input_color_format = input_color_format or "rgb"
+             if len(images.shape) == 3:
+                 images = torch.unsqueeze(images, dim=0)
+             if input_color_format != "rgb":
+                 images = images[:, [2, 1, 0], :, :]
+             images = images.permute(0, 2, 3, 1)
+             preprocessed_images, dimensions = [], []
+             for image in images:
+                 np_image = np.ascontiguousarray(image.cpu().numpy())
+                 preprocessed_images.append(
+                     mp.Image(
+                         image_format=mp.ImageFormat.SRGB, data=np_image.astype(np.uint8)
+                     )
+                 )
+                 dimensions.append(
+                     ImageDimensions(height=np_image.shape[0], width=np_image.shape[1])
+                 )
+             return preprocessed_images, dimensions
+         if not isinstance(images, list):
+             raise ModelRuntimeError(
+                 message="Pre-processing supports only np.array or torch.Tensor or list of above.",
+                 help_url="https://todo",
+             )
+         if not len(images):
+             raise ModelRuntimeError(
+                 message="Detected empty input to the model", help_url="https://todo"
+             )
+         if isinstance(images[0], np.ndarray):
+             input_color_format = input_color_format or "bgr"
+             preprocessed_images, dimensions = [], []
+             for image in images:
+                 if input_color_format != "rgb":
+                     image = np.ascontiguousarray(image[:, :, ::-1])
+                 preprocessed_images.append(
+                     mp.Image(
+                         image_format=mp.ImageFormat.SRGB, data=image.astype(np.uint8)
+                     )
+                 )
+                 dimensions.append(
+                     ImageDimensions(height=image.shape[0], width=image.shape[1])
+                 )
+             return preprocessed_images, dimensions
+         if isinstance(images[0], torch.Tensor):
+             input_color_format = input_color_format or "rgb"
+             preprocessed_images, dimensions = [], []
+             for image in images:
+                 if input_color_format != "rgb":
+                     image = image[[2, 1, 0], :, :]
+                 np_image = image.cpu().permute(1, 2, 0).numpy()
+                 preprocessed_images.append(
+                     mp.Image(
+                         image_format=mp.ImageFormat.SRGB, data=np_image.astype(np.uint8)
+                     )
+                 )
+                 dimensions.append(
+                     ImageDimensions(height=np_image.shape[0], width=np_image.shape[1])
+                 )
+             return preprocessed_images, dimensions
+         raise ModelRuntimeError(
+             message=f"Detected unknown input batch element: {type(images[0])}",
+             help_url="https://todo",
+         )
+
+     def forward(
+         self, pre_processed_images: List[mp.Image], **kwargs
+     ) -> List[List[Detection]]:
+         results = []
+         with self._thread_lock:
+             for input_image in pre_processed_images:
+                 image_faces = self._face_detector.detect(image=input_image).detections
+                 results.append(image_faces)
+         return results
+
+     def post_process(
+         self,
+         model_results: List[List[Detection]],
+         pre_processing_meta: List[ImageDimensions],
+         conf_thresh: float = 0.25,
+         **kwargs,
+     ) -> Tuple[List[KeyPoints], List[Detections]]:
+         final_key_points, final_detections = [], []
+         for image_results, image_dimensions in zip(model_results, pre_processing_meta):
+             detections_xyxy, detections_class_id, detections_confidence = [], [], []
+             key_points_xy, key_points_class_id, key_points_confidence = [], [], []
+             for detection in image_results:
+                 if detection.categories[0].score < conf_thresh:
+                     continue
+                 xyxy = (
+                     detection.bounding_box.origin_x,
+                     detection.bounding_box.origin_y,
+                     detection.bounding_box.origin_x + detection.bounding_box.width,
+                     detection.bounding_box.origin_y + detection.bounding_box.height,
+                 )
+                 detections_xyxy.append(xyxy)
+                 detections_class_id.append(0)
+                 detections_confidence.append(detection.categories[0].score)
+                 detection_key_points = []
+                 for keypoint in detection.keypoints:
+                     detection_key_points.append(
+                         (
+                             keypoint.x * image_dimensions.width,
+                             keypoint.y * image_dimensions.height,
+                         )
+                     )
+                 key_points_xy.append(detection_key_points)
+                 key_points_class_id.append(0)
+                 key_points_confidence.append([1.0] * len(detection_key_points))
+             detections = Detections(
+                 xyxy=torch.tensor(detections_xyxy).round().int(),
+                 class_id=torch.tensor(detections_class_id).int(),
+                 confidence=torch.tensor(detections_confidence),
+             )
+             key_points = KeyPoints(
+                 xy=torch.tensor(key_points_xy).round().int(),
+                 class_id=torch.tensor(key_points_class_id).int(),
+                 confidence=torch.tensor(key_points_confidence),
+             )
+             final_key_points.append(key_points)
+             final_detections.append(detections)
+         return final_key_points, final_detections
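
A minimal usage sketch for the face detector above (not part of the package diff), assuming the `mediapipe` extra is installed and a hypothetical local model package directory "./mediapipe_face_package" containing mediapipe_face_detector.tflite; it chains the pre_process, forward, and post_process methods defined in the file.

# Hedged usage sketch: face and key-point detection with the MediaPipe-backed model.
import numpy as np

from inference_models.models.mediapipe_face_detection.face_detection import (
    MediaPipeFaceDetector,
)

# "./mediapipe_face_package" is a hypothetical model package directory holding
# the expected "mediapipe_face_detector.tflite" asset.
model = MediaPipeFaceDetector.from_pretrained("./mediapipe_face_package")

# Stand-in BGR frame (the default color format for numpy inputs in pre_process).
frame = np.zeros((720, 1280, 3), dtype=np.uint8)

mp_images, dims = model.pre_process(frame)
raw_detections = model.forward(mp_images)
key_points, detections = model.post_process(raw_detections, dims, conf_thresh=0.5)
print(detections[0].xyxy, key_points[0].xy)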