inference-models 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,52 @@
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ from pydantic import BaseModel, ConfigDict, Field, ValidationError
5
+
6
+ from inference_models.configuration import DEFAULT_DEVICE
7
+ from inference_models.errors import DependencyModelParametersValidationError
8
+ from inference_models.models.auto_loaders.entities import BackendType
9
+ from inference_models.weights_providers.entities import Quantization
10
+
11
+
12
class DependencyModelParameters(BaseModel):
    """Options used to load a model that another model depends on.

    Only ``model_id_or_path`` is mandatory; every other option falls back to a
    default. Because the model config sets ``extra="allow"``, any additional
    keys given at construction are kept and exposed through ``kwargs``.
    ``arbitrary_types_allowed=True`` is needed for the ``torch.device`` field.
    """

    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    model_id_or_path: str
    model_package_id: Optional[str] = None
    # A single backend or a list of backends; plain strings and enum members are both accepted.
    backend: Optional[Union[str, BackendType, List[Union[str, BackendType]]]] = None
    # Either a single int or an int pair.
    batch_size: Optional[Union[int, Tuple[int, int]]] = None
    quantization: Optional[Union[str, Quantization, List[Union[str, Quantization]]]] = None
    onnx_execution_providers: Optional[List[Union[str, tuple]]] = None
    device: torch.device = DEFAULT_DEVICE
    default_onnx_trt_options: bool = True
    nms_fusion_preferences: Optional[Union[bool, dict]] = None
    model_type: Optional[str] = None
    task_type: Optional[str] = None

    @property
    def kwargs(self) -> Dict[str, Any]:
        """Return the extra (undeclared) fields, or an empty dict when there are none."""
        extra_fields = self.model_extra
        if not extra_fields:
            return {}
        return extra_fields
34
+
35
+
36
def prepare_dependency_model_parameters(
    model_parameters: Union[str, dict, DependencyModelParameters]
) -> DependencyModelParameters:
    """Normalise a dependency-model specification into ``DependencyModelParameters``.

    Args:
        model_parameters: One of:
            * ``dict`` - validated with ``DependencyModelParameters.model_validate``;
            * ``str`` - used as ``model_id_or_path``, with every other option left
              at its default;
            * ``DependencyModelParameters`` - returned unchanged.

    Returns:
        The normalised ``DependencyModelParameters`` instance. Values of any other
        type are passed through untouched.

    Raises:
        DependencyModelParametersValidationError: if a ``dict`` fails pydantic
            validation. The original ``ValidationError`` is chained as the cause.
    """
    if isinstance(model_parameters, dict):
        try:
            return DependencyModelParameters.model_validate(model_parameters)
        except ValidationError as error:
            raise DependencyModelParametersValidationError(
                message="Could not validate parameters to initialise dependent model - if you run locally, make sure "
                "that you initialise model properly, as at least one parameter specified in "
                "dictionary with model options is invalid. If you use Roboflow hosted offering, contact us to "
                "get help.",
                # NOTE(review): placeholder URL - replace with the real documentation link.
                help_url="https://todo",
            ) from error
    if isinstance(model_parameters, str):
        return DependencyModelParameters(model_id_or_path=model_parameters)
    return model_parameters
@@ -0,0 +1,57 @@
1
+ from dataclasses import dataclass
2
+ from enum import Enum
3
+ from typing import Optional, Union
4
+
5
+ from inference_models.models.base.classification import (
6
+ ClassificationModel,
7
+ MultiLabelClassificationModel,
8
+ )
9
+ from inference_models.models.base.depth_estimation import DepthEstimationModel
10
+ from inference_models.models.base.documents_parsing import StructuredOCRModel
11
+ from inference_models.models.base.embeddings import TextImageEmbeddingModel
12
+ from inference_models.models.base.instance_segmentation import InstanceSegmentationModel
13
+ from inference_models.models.base.keypoints_detection import KeyPointsDetectionModel
14
+ from inference_models.models.base.object_detection import (
15
+ ObjectDetectionModel,
16
+ OpenVocabularyObjectDetectionModel,
17
+ )
18
+
19
# Architecture identifier of a model (a plain string such as "yolov8").
ModelArchitecture = str
# Task identifier; ``None`` is allowed when the task is not specified.
TaskType = Union[str, None]
# Name of the model configuration file.
MODEL_CONFIG_FILE_NAME = "model_config.json"
22
+
23
+
24
class BackendType(str, Enum):
    """Identifier of the inference backend a model implementation runs on.

    The enum mixes in ``str``, so members compare equal to their string values
    (e.g. ``BackendType.ONNX == "onnx"``). ``BackendType("onnx")`` resolves a
    member from its value. Member order is also the iteration order.
    """

    TORCH = "torch"
    TORCH_SCRIPT = "torch-script"
    ONNX = "onnx"
    TRT = "trt"  # presumably TensorRT engines - confirm
    HF = "hugging-face"
    ULTRALYTICS = "ultralytics"
    MEDIAPIPE = "mediapipe"
    CUSTOM = "custom"  # NOTE(review): semantics not visible here - confirm what "custom" backends cover
33
+
34
+
35
# Type alias: the union of the task-specific base model classes imported above.
# NOTE(review): the package also ships base/semantic_segmentation.py, but no
# semantic-segmentation base class is imported or listed here - confirm whether
# it should be part of this union.
AnyModel = Union[
    ClassificationModel,
    MultiLabelClassificationModel,
    DepthEstimationModel,
    StructuredOCRModel,
    TextImageEmbeddingModel,
    InstanceSegmentationModel,
    KeyPointsDetectionModel,
    ObjectDetectionModel,
    OpenVocabularyObjectDetectionModel,
]
46
+
47
+
48
@dataclass(frozen=True)
class InferenceModelConfig:
    """Immutable description of how a model implementation should be located.

    Attributes:
        model_architecture: Architecture identifier, or ``None`` if unknown.
        task_type: Task identifier (may be ``None``).
        backend_type: Backend identifier, or ``None`` if unknown.
        model_module: Optional module path of the implementation.
        model_class: Optional class name of the implementation.
    """

    model_architecture: Optional[ModelArchitecture]
    task_type: TaskType
    backend_type: Optional[BackendType]
    model_module: Optional[str]
    model_class: Optional[str]

    def is_library_model(self) -> bool:
        """Return ``True`` only when both architecture and backend are set.

        NOTE(review): presumably "library model" means the implementation can be
        looked up by (architecture, task, backend) rather than via
        ``model_module``/``model_class`` - confirm against callers.
        """
        architecture_missing = self.model_architecture is None
        backend_missing = self.backend_type is None
        return not (architecture_missing or backend_missing)
@@ -0,0 +1,497 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, Optional, Set, Tuple, Union
3
+
4
+ from inference_models.errors import ModelImplementationLoaderError
5
+ from inference_models.models.auto_loaders.entities import (
6
+ BackendType,
7
+ ModelArchitecture,
8
+ TaskType,
9
+ )
10
+ from inference_models.utils.imports import LazyClass
11
+
12
# Task-type identifiers. They are used as the second element of the
# (model_architecture, task_type, backend_type) keys of REGISTERED_MODELS.
OBJECT_DETECTION_TASK = "object-detection"
INSTANCE_SEGMENTATION_TASK = "instance-segmentation"
SEMANTIC_SEGMENTATION_TASK = "semantic-segmentation"
KEYPOINT_DETECTION_TASK = "keypoint-detection"
VLM_TASK = "vlm"
EMBEDDING_TASK = "embedding"
CLASSIFICATION_TASK = "classification"
MULTI_LABEL_CLASSIFICATION_TASK = "multi-label-classification"
DEPTH_ESTIMATION_TASK = "depth-estimation"
STRUCTURED_OCR_TASK = "structured-ocr"
TEXT_ONLY_OCR_TASK = "text-only-ocr"
GAZE_DETECTION_TASK = "gaze-detection"
OPEN_VOCABULARY_OBJECT_DETECTION_TASK = "open-vocabulary-object-detection"
INTERACTIVE_INSTANCE_SEGMENTATION_TASK = "interactive-instance-segmentation"
26
+
27
+
28
@dataclass(frozen=True)
class RegistryEntry:
    """Registry value: a lazily imported model class plus optional feature flags.

    Attributes:
        model_class: Lazy reference (module name + class name) to the implementation.
        supported_model_features: Optional set of feature names the implementation
            supports (e.g. ``"nms_fused"``); ``None`` when nothing is declared.

    Note: the dataclass is frozen, so a ``__hash__`` is generated from the fields;
    hashing an entry whose ``supported_model_features`` is a ``set`` raises
    ``TypeError``.
    """

    model_class: LazyClass
    supported_model_features: Optional[Set[str]] = None
32
+
33
+
34
+ REGISTERED_MODELS: Dict[
35
+ Tuple[ModelArchitecture, TaskType, BackendType], Union[LazyClass, RegistryEntry]
36
+ ] = {
37
+ ("yolonas", OBJECT_DETECTION_TASK, BackendType.ONNX): LazyClass(
38
+ module_name="inference_models.models.yolonas.yolonas_object_detection_onnx",
39
+ class_name="YOLONasForObjectDetectionOnnx",
40
+ ),
41
+ ("yolonas", OBJECT_DETECTION_TASK, BackendType.TRT): LazyClass(
42
+ module_name="inference_models.models.yolonas.yolonas_object_detection_trt",
43
+ class_name="YOLONasForObjectDetectionTRT",
44
+ ),
45
+ ("yolov5", OBJECT_DETECTION_TASK, BackendType.ONNX): LazyClass(
46
+ module_name="inference_models.models.yolov5.yolov5_object_detection_onnx",
47
+ class_name="YOLOv5ForObjectDetectionOnnx",
48
+ ),
49
+ ("yolov5", OBJECT_DETECTION_TASK, BackendType.TRT): LazyClass(
50
+ module_name="inference_models.models.yolov5.yolov5_object_detection_trt",
51
+ class_name="YOLOv5ForObjectDetectionTRT",
52
+ ),
53
+ ("yolov5", INSTANCE_SEGMENTATION_TASK, BackendType.ONNX): LazyClass(
54
+ module_name="inference_models.models.yolov5.yolov5_instance_segmentation_onnx",
55
+ class_name="YOLOv5ForInstanceSegmentationOnnx",
56
+ ),
57
+ ("yolov5", INSTANCE_SEGMENTATION_TASK, BackendType.TRT): LazyClass(
58
+ module_name="inference_models.models.yolov5.yolov5_instance_segmentation_trt",
59
+ class_name="YOLOv5ForInstanceSegmentationTRT",
60
+ ),
61
+ ("yolov7", INSTANCE_SEGMENTATION_TASK, BackendType.ONNX): LazyClass(
62
+ module_name="inference_models.models.yolov7.yolov7_instance_segmentation_onnx",
63
+ class_name="YOLOv7ForInstanceSegmentationOnnx",
64
+ ),
65
+ ("yolov7", INSTANCE_SEGMENTATION_TASK, BackendType.TRT): LazyClass(
66
+ module_name="inference_models.models.yolov7.yolov7_instance_segmentation_trt",
67
+ class_name="YOLOv7ForInstanceSegmentationTRT",
68
+ ),
69
+ ("yolov8", CLASSIFICATION_TASK, BackendType.ONNX): RegistryEntry(
70
+ model_class=LazyClass(
71
+ module_name="inference_models.models.yolov8.yolov8_classification_onnx",
72
+ class_name="YOLOv8ForClassificationOnnx",
73
+ ),
74
+ ),
75
+ ("yolov8", OBJECT_DETECTION_TASK, BackendType.ONNX): RegistryEntry(
76
+ model_class=LazyClass(
77
+ module_name="inference_models.models.yolov8.yolov8_object_detection_onnx",
78
+ class_name="YOLOv8ForObjectDetectionOnnx",
79
+ ),
80
+ supported_model_features={"nms_fused"},
81
+ ),
82
+ ("yolov8", OBJECT_DETECTION_TASK, BackendType.TORCH_SCRIPT): RegistryEntry(
83
+ model_class=LazyClass(
84
+ module_name="inference_models.models.yolov8.yolov8_object_detection_torch_script",
85
+ class_name="YOLOv8ForObjectDetectionTorchScript",
86
+ ),
87
+ supported_model_features={"nms_fused"},
88
+ ),
89
+ ("yolov8", OBJECT_DETECTION_TASK, BackendType.TRT): LazyClass(
90
+ module_name="inference_models.models.yolov8.yolov8_object_detection_trt",
91
+ class_name="YOLOv8ForObjectDetectionTRT",
92
+ ),
93
+ ("yolov8", KEYPOINT_DETECTION_TASK, BackendType.ONNX): RegistryEntry(
94
+ model_class=LazyClass(
95
+ module_name="inference_models.models.yolov8.yolov8_key_points_detection_onnx",
96
+ class_name="YOLOv8ForKeyPointsDetectionOnnx",
97
+ ),
98
+ supported_model_features={"nms_fused"},
99
+ ),
100
+ ("yolov8", KEYPOINT_DETECTION_TASK, BackendType.TORCH_SCRIPT): RegistryEntry(
101
+ model_class=LazyClass(
102
+ module_name="inference_models.models.yolov8.yolov8_key_points_detection_torch_script",
103
+ class_name="YOLOv8ForKeyPointsDetectionTorchScript",
104
+ ),
105
+ supported_model_features={"nms_fused"},
106
+ ),
107
+ ("yolov8", KEYPOINT_DETECTION_TASK, BackendType.TRT): LazyClass(
108
+ module_name="inference_models.models.yolov8.yolov8_key_points_detection_trt",
109
+ class_name="YOLOv8ForKeyPointsDetectionTRT",
110
+ ),
111
+ ("yolov8", INSTANCE_SEGMENTATION_TASK, BackendType.ONNX): RegistryEntry(
112
+ model_class=LazyClass(
113
+ module_name="inference_models.models.yolov8.yolov8_instance_segmentation_onnx",
114
+ class_name="YOLOv8ForInstanceSegmentationOnnx",
115
+ ),
116
+ supported_model_features={"nms_fused"},
117
+ ),
118
+ ("yolov8", INSTANCE_SEGMENTATION_TASK, BackendType.TORCH_SCRIPT): RegistryEntry(
119
+ model_class=LazyClass(
120
+ module_name="inference_models.models.yolov8.yolov8_instance_segmentation_torch_script",
121
+ class_name="YOLOv8ForInstanceSegmentationTorchScript",
122
+ ),
123
+ supported_model_features={"nms_fused"},
124
+ ),
125
+ ("yolov8", INSTANCE_SEGMENTATION_TASK, BackendType.TRT): LazyClass(
126
+ module_name="inference_models.models.yolov8.yolov8_instance_segmentation_trt",
127
+ class_name="YOLOv8ForInstanceSegmentationTRT",
128
+ ),
129
+ ("yolov9", OBJECT_DETECTION_TASK, BackendType.ONNX): RegistryEntry(
130
+ model_class=LazyClass(
131
+ module_name="inference_models.models.yolov9.yolov9_onnx",
132
+ class_name="YOLOv9ForObjectDetectionOnnx",
133
+ ),
134
+ supported_model_features={"nms_fused"},
135
+ ),
136
+ ("yolov9", OBJECT_DETECTION_TASK, BackendType.TORCH_SCRIPT): RegistryEntry(
137
+ model_class=LazyClass(
138
+ module_name="inference_models.models.yolov9.yolov9_torch_script",
139
+ class_name="YOLOv9ForObjectDetectionTorchScript",
140
+ ),
141
+ supported_model_features={"nms_fused"},
142
+ ),
143
+ ("yolov9", OBJECT_DETECTION_TASK, BackendType.TRT): LazyClass(
144
+ module_name="inference_models.models.yolov9.yolov9_trt",
145
+ class_name="YOLOv9ForObjectDetectionTRT",
146
+ ),
147
+ ("yolov10", OBJECT_DETECTION_TASK, BackendType.ONNX): LazyClass(
148
+ module_name="inference_models.models.yolov10.yolov10_object_detection_onnx",
149
+ class_name="YOLOv10ForObjectDetectionOnnx",
150
+ ),
151
+ ("yolov10", OBJECT_DETECTION_TASK, BackendType.TRT): LazyClass(
152
+ module_name="inference_models.models.yolov10.yolov10_object_detection_trt",
153
+ class_name="YOLOv10ForObjectDetectionTRT",
154
+ ),
155
+ ("yolov11", CLASSIFICATION_TASK, BackendType.ONNX): RegistryEntry(
156
+ model_class=LazyClass(
157
+ module_name="inference_models.models.yolov11.yolov11_onnx",
158
+ class_name="YOLOv11ForClassificationOnnx",
159
+ ),
160
+ ),
161
+ ("yolov11", OBJECT_DETECTION_TASK, BackendType.ONNX): RegistryEntry(
162
+ model_class=LazyClass(
163
+ module_name="inference_models.models.yolov11.yolov11_onnx",
164
+ class_name="YOLOv11ForObjectDetectionOnnx",
165
+ ),
166
+ supported_model_features={"nms_fused"},
167
+ ),
168
+ ("yolov11", OBJECT_DETECTION_TASK, BackendType.TORCH_SCRIPT): RegistryEntry(
169
+ model_class=LazyClass(
170
+ module_name="inference_models.models.yolov11.yolov11_torch_script",
171
+ class_name="YOLOv11ForObjectDetectionTorchScript",
172
+ ),
173
+ supported_model_features={"nms_fused"},
174
+ ),
175
+ ("yolov11", OBJECT_DETECTION_TASK, BackendType.TRT): LazyClass(
176
+ module_name="inference_models.models.yolov11.yolov11_trt",
177
+ class_name="YOLOv11ForObjectDetectionTRT",
178
+ ),
179
+ ("yolov11", KEYPOINT_DETECTION_TASK, BackendType.ONNX): RegistryEntry(
180
+ model_class=LazyClass(
181
+ module_name="inference_models.models.yolov11.yolov11_onnx",
182
+ class_name="YOLOv11ForForKeyPointsDetectionOnnx",
183
+ ),
184
+ supported_model_features={"nms_fused"},
185
+ ),
186
+ ("yolov11", KEYPOINT_DETECTION_TASK, BackendType.TORCH_SCRIPT): RegistryEntry(
187
+ model_class=LazyClass(
188
+ module_name="inference_models.models.yolov11.yolov11_torch_script",
189
+ class_name="YOLOv11ForForKeyPointsDetectionTorchScript",
190
+ ),
191
+ supported_model_features={"nms_fused"},
192
+ ),
193
+ ("yolov11", KEYPOINT_DETECTION_TASK, BackendType.TRT): LazyClass(
194
+ module_name="inference_models.models.yolov11.yolov11_trt",
195
+ class_name="YOLOv11ForForKeyPointsDetectionTRT",
196
+ ),
197
+ ("yolov11", INSTANCE_SEGMENTATION_TASK, BackendType.ONNX): RegistryEntry(
198
+ model_class=LazyClass(
199
+ module_name="inference_models.models.yolov11.yolov11_onnx",
200
+ class_name="YOLOv11ForInstanceSegmentationOnnx",
201
+ ),
202
+ supported_model_features={"nms_fused"},
203
+ ),
204
+ ("yolov11", INSTANCE_SEGMENTATION_TASK, BackendType.TORCH_SCRIPT): RegistryEntry(
205
+ model_class=LazyClass(
206
+ module_name="inference_models.models.yolov11.yolov11_torch_script",
207
+ class_name="YOLOv11ForInstanceSegmentationTorchScript",
208
+ ),
209
+ supported_model_features={"nms_fused"},
210
+ ),
211
+ ("yolov11", INSTANCE_SEGMENTATION_TASK, BackendType.TRT): LazyClass(
212
+ module_name="inference_models.models.yolov11.yolov11_trt",
213
+ class_name="YOLOv11ForInstanceSegmentationTRT",
214
+ ),
215
+ ("yolov12", OBJECT_DETECTION_TASK, BackendType.ONNX): RegistryEntry(
216
+ model_class=LazyClass(
217
+ module_name="inference_models.models.yolov12.yolov12_onnx",
218
+ class_name="YOLOv12ForObjectDetectionOnnx",
219
+ ),
220
+ supported_model_features={"nms_fused"},
221
+ ),
222
+ ("yolov12", OBJECT_DETECTION_TASK, BackendType.TORCH_SCRIPT): RegistryEntry(
223
+ model_class=LazyClass(
224
+ module_name="inference_models.models.yolov12.yolov12_torch_script",
225
+ class_name="YOLOv12ForObjectDetectionTorchScript",
226
+ ),
227
+ supported_model_features={"nms_fused"},
228
+ ),
229
+ ("yolov12", OBJECT_DETECTION_TASK, BackendType.TRT): LazyClass(
230
+ module_name="inference_models.models.yolov12.yolov12_trt",
231
+ class_name="YOLOv12ForObjectDetectionTRT",
232
+ ),
233
+ ("paligemma-2", VLM_TASK, BackendType.HF): LazyClass(
234
+ module_name="inference_models.models.paligemma.paligemma_hf",
235
+ class_name="PaliGemmaHF",
236
+ ),
237
+ ("paligemma", VLM_TASK, BackendType.HF): LazyClass(
238
+ module_name="inference_models.models.paligemma.paligemma_hf",
239
+ class_name="PaliGemmaHF",
240
+ ),
241
+ ("smolvlm-v2", VLM_TASK, BackendType.HF): LazyClass(
242
+ module_name="inference_models.models.smolvlm.smolvlm_hf",
243
+ class_name="SmolVLMHF",
244
+ ),
245
+ ("qwen25vl", VLM_TASK, BackendType.HF): LazyClass(
246
+ module_name="inference_models.models.qwen25vl.qwen25vl_hf",
247
+ class_name="Qwen25VLHF",
248
+ ),
249
+ ("florence-2", VLM_TASK, BackendType.HF): LazyClass(
250
+ module_name="inference_models.models.florence2.florence2_hf",
251
+ class_name="Florence2HF",
252
+ ),
253
+ ("clip", EMBEDDING_TASK, BackendType.TORCH): LazyClass(
254
+ module_name="inference_models.models.clip.clip_pytorch",
255
+ class_name="ClipTorch",
256
+ ),
257
+ ("clip", EMBEDDING_TASK, BackendType.ONNX): LazyClass(
258
+ module_name="inference_models.models.clip.clip_onnx",
259
+ class_name="ClipOnnx",
260
+ ),
261
+ ("perception-encoder", EMBEDDING_TASK, BackendType.TORCH): LazyClass(
262
+ module_name="inference_models.models.perception_encoder.perception_encoder_pytorch",
263
+ class_name="PerceptionEncoderTorch",
264
+ ),
265
+ ("rfdetr", OBJECT_DETECTION_TASK, BackendType.TRT): LazyClass(
266
+ module_name="inference_models.models.rfdetr.rfdetr_object_detection_trt",
267
+ class_name="RFDetrForObjectDetectionTRT",
268
+ ),
269
+ ("rfdetr", OBJECT_DETECTION_TASK, BackendType.TORCH): LazyClass(
270
+ module_name="inference_models.models.rfdetr.rfdetr_object_detection_pytorch",
271
+ class_name="RFDetrForObjectDetectionTorch",
272
+ ),
273
+ ("rfdetr", OBJECT_DETECTION_TASK, BackendType.ONNX): LazyClass(
274
+ module_name="inference_models.models.rfdetr.rfdetr_object_detection_onnx",
275
+ class_name="RFDetrForObjectDetectionONNX",
276
+ ),
277
+ ("rfdetr", INSTANCE_SEGMENTATION_TASK, BackendType.TORCH): LazyClass(
278
+ module_name="inference_models.models.rfdetr.rfdetr_instance_segmentation_pytorch",
279
+ class_name="RFDetrForInstanceSegmentationTorch",
280
+ ),
281
+ ("rfdetr", INSTANCE_SEGMENTATION_TASK, BackendType.ONNX): LazyClass(
282
+ module_name="inference_models.models.rfdetr.rfdetr_instance_segmentation_onnx",
283
+ class_name="RFDetrForInstanceSegmentationOnnx",
284
+ ),
285
+ ("rfdetr", INSTANCE_SEGMENTATION_TASK, BackendType.TRT): LazyClass(
286
+ module_name="inference_models.models.rfdetr.rfdetr_instance_segmentation_trt",
287
+ class_name="RFDetrForInstanceSegmentationTRT",
288
+ ),
289
+ ("moondream2", VLM_TASK, BackendType.HF): LazyClass(
290
+ module_name="inference_models.models.moondream2.moondream2_hf",
291
+ class_name="MoonDream2HF",
292
+ ),
293
+ ("vit", CLASSIFICATION_TASK, BackendType.ONNX): LazyClass(
294
+ module_name="inference_models.models.vit.vit_classification_onnx",
295
+ class_name="VITForClassificationOnnx",
296
+ ),
297
+ ("vit", MULTI_LABEL_CLASSIFICATION_TASK, BackendType.ONNX): LazyClass(
298
+ module_name="inference_models.models.vit.vit_classification_onnx",
299
+ class_name="VITForMultiLabelClassificationOnnx",
300
+ ),
301
+ ("vit", CLASSIFICATION_TASK, BackendType.HF): LazyClass(
302
+ module_name="inference_models.models.vit.vit_classification_huggingface",
303
+ class_name="VITForClassificationHF",
304
+ ),
305
+ ("vit", MULTI_LABEL_CLASSIFICATION_TASK, BackendType.HF): LazyClass(
306
+ module_name="inference_models.models.vit.vit_classification_huggingface",
307
+ class_name="VITForMultiLabelClassificationHF",
308
+ ),
309
+ ("vit", CLASSIFICATION_TASK, BackendType.TRT): LazyClass(
310
+ module_name="inference_models.models.vit.vit_classification_trt",
311
+ class_name="VITForClassificationTRT",
312
+ ),
313
+ ("vit", MULTI_LABEL_CLASSIFICATION_TASK, BackendType.TRT): LazyClass(
314
+ module_name="inference_models.models.vit.vit_classification_trt",
315
+ class_name="VITForMultiLabelClassificationTRT",
316
+ ),
317
+ ("resnet", CLASSIFICATION_TASK, BackendType.ONNX): LazyClass(
318
+ module_name="inference_models.models.resnet.resnet_classification_onnx",
319
+ class_name="ResNetForClassificationOnnx",
320
+ ),
321
+ ("resnet", MULTI_LABEL_CLASSIFICATION_TASK, BackendType.ONNX): LazyClass(
322
+ module_name="inference_models.models.resnet.resnet_classification_onnx",
323
+ class_name="ResNetForMultiLabelClassificationOnnx",
324
+ ),
325
+ ("resnet", CLASSIFICATION_TASK, BackendType.TORCH): LazyClass(
326
+ module_name="inference_models.models.resnet.resnet_classification_torch",
327
+ class_name="ResNetForClassificationTorch",
328
+ ),
329
+ ("resnet", MULTI_LABEL_CLASSIFICATION_TASK, BackendType.TORCH): LazyClass(
330
+ module_name="inference_models.models.resnet.resnet_classification_torch",
331
+ class_name="ResNetForMultiLabelClassificationTorch",
332
+ ),
333
+ ("resnet", CLASSIFICATION_TASK, BackendType.TRT): LazyClass(
334
+ module_name="inference_models.models.resnet.resnet_classification_trt",
335
+ class_name="ResNetForClassificationTRT",
336
+ ),
337
+ ("resnet", MULTI_LABEL_CLASSIFICATION_TASK, BackendType.TRT): LazyClass(
338
+ module_name="inference_models.models.resnet.resnet_classification_trt",
339
+ class_name="ResNetForMultiLabelClassificationTRT",
340
+ ),
341
+ ("segment-anything-2-rt", INSTANCE_SEGMENTATION_TASK, BackendType.TORCH): LazyClass(
342
+ module_name="inference_models.models.sam2_rt.sam2_pytorch",
343
+ class_name="SAM2ForStream",
344
+ ),
345
+ ("deep-lab-v3-plus", SEMANTIC_SEGMENTATION_TASK, BackendType.TORCH): LazyClass(
346
+ module_name="inference_models.models.deep_lab_v3_plus.deep_lab_v3_plus_segmentation_torch",
347
+ class_name="DeepLabV3PlusForSemanticSegmentationTorch",
348
+ ),
349
+ ("deep-lab-v3-plus", SEMANTIC_SEGMENTATION_TASK, BackendType.ONNX): LazyClass(
350
+ module_name="inference_models.models.deep_lab_v3_plus.deep_lab_v3_plus_segmentation_onnx",
351
+ class_name="DeepLabV3PlusForSemanticSegmentationOnnx",
352
+ ),
353
+ ("deep-lab-v3-plus", SEMANTIC_SEGMENTATION_TASK, BackendType.TRT): LazyClass(
354
+ module_name="inference_models.models.deep_lab_v3_plus.deep_lab_v3_plus_segmentation_trt",
355
+ class_name="DeepLabV3PlusForSemanticSegmentationTRT",
356
+ ),
357
+ ("yolact", INSTANCE_SEGMENTATION_TASK, BackendType.ONNX): LazyClass(
358
+ module_name="inference_models.models.yolact.yolact_instance_segmentation_onnx",
359
+ class_name="YOLOACTForInstanceSegmentationOnnx",
360
+ ),
361
+ ("yolact", INSTANCE_SEGMENTATION_TASK, BackendType.TRT): LazyClass(
362
+ module_name="inference_models.models.yolact.yolact_instance_segmentation_trt",
363
+ class_name="YOLOACTForInstanceSegmentationTRT",
364
+ ),
365
+ ("depth-anything-v2", DEPTH_ESTIMATION_TASK, BackendType.HF): LazyClass(
366
+ module_name="inference_models.models.depth_anything_v2.depth_anything_v2_hf",
367
+ class_name="DepthAnythingV2HF",
368
+ ),
369
+ ("doctr", STRUCTURED_OCR_TASK, BackendType.TORCH): LazyClass(
370
+ module_name="inference_models.models.doctr.doctr_torch", class_name="DocTR"
371
+ ),
372
+ ("easy-ocr", STRUCTURED_OCR_TASK, BackendType.TORCH): LazyClass(
373
+ module_name="inference_models.models.easy_ocr.easy_ocr_torch",
374
+ class_name="EasyOCRTorch",
375
+ ),
376
+ ("tr-ocr", TEXT_ONLY_OCR_TASK, BackendType.HF): LazyClass(
377
+ module_name="inference_models.models.trocr.trocr_hf",
378
+ class_name="TROcrHF",
379
+ ),
380
+ (
381
+ "mediapipe-face-detector",
382
+ KEYPOINT_DETECTION_TASK,
383
+ BackendType.MEDIAPIPE,
384
+ ): LazyClass(
385
+ module_name="inference_models.models.mediapipe_face_detection.face_detection",
386
+ class_name="MediaPipeFaceDetector",
387
+ ),
388
+ ("l2cs-net", GAZE_DETECTION_TASK, BackendType.ONNX): LazyClass(
389
+ module_name="inference_models.models.l2cs.l2cs_onnx",
390
+ class_name="L2CSNetOnnx",
391
+ ),
392
+ (
393
+ "grounding-dino",
394
+ OPEN_VOCABULARY_OBJECT_DETECTION_TASK,
395
+ BackendType.TORCH,
396
+ ): LazyClass(
397
+ module_name="inference_models.models.grounding_dino.grounding_dino_torch",
398
+ class_name="GroundingDinoForObjectDetectionTorch",
399
+ ),
400
+ (
401
+ "dinov3_probe",
402
+ MULTI_LABEL_CLASSIFICATION_TASK,
403
+ BackendType.ONNX,
404
+ ): LazyClass(
405
+ module_name="inference_models.models.dinov3.dinov3_classification_onnx",
406
+ class_name="DinoV3ForMultiLabelClassificationOnnx",
407
+ ),
408
+ (
409
+ "dinov3_probe",
410
+ CLASSIFICATION_TASK,
411
+ BackendType.ONNX,
412
+ ): LazyClass(
413
+ module_name="inference_models.models.dinov3.dinov3_classification_onnx",
414
+ class_name="DinoV3ForClassificationOnnx",
415
+ ),
416
+ (
417
+ "dinov3_probe",
418
+ MULTI_LABEL_CLASSIFICATION_TASK,
419
+ BackendType.TORCH,
420
+ ): LazyClass(
421
+ module_name="inference_models.models.dinov3.dinov3_classification_torch",
422
+ class_name="DinoV3ForMultiLabelClassificationTorch",
423
+ ),
424
+ (
425
+ "dinov3_probe",
426
+ CLASSIFICATION_TASK,
427
+ BackendType.TORCH,
428
+ ): LazyClass(
429
+ module_name="inference_models.models.dinov3.dinov3_classification_torch",
430
+ class_name="DinoV3ForClassificationTorch",
431
+ ),
432
+ (
433
+ "owlv2",
434
+ OPEN_VOCABULARY_OBJECT_DETECTION_TASK,
435
+ BackendType.HF,
436
+ ): LazyClass(
437
+ module_name="inference_models.models.owlv2.owlv2_hf",
438
+ class_name="OWLv2HF",
439
+ ),
440
+ (
441
+ "roboflow-instant",
442
+ OBJECT_DETECTION_TASK,
443
+ BackendType.HF,
444
+ ): LazyClass(
445
+ module_name="inference_models.models.roboflow_instant.roboflow_instant_hf",
446
+ class_name="RoboflowInstantHF",
447
+ ),
448
+ ("sam", INTERACTIVE_INSTANCE_SEGMENTATION_TASK, BackendType.TORCH): LazyClass(
449
+ module_name="inference_models.models.sam.sam_torch",
450
+ class_name="SAMTorch",
451
+ ),
452
+ ("sam2", INTERACTIVE_INSTANCE_SEGMENTATION_TASK, BackendType.TORCH): LazyClass(
453
+ module_name="inference_models.models.sam2.sam2_torch",
454
+ class_name="SAM2Torch",
455
+ ),
456
+ }
457
+
458
+
459
def resolve_model_class(
    model_architecture: ModelArchitecture,
    task_type: TaskType,
    backend: BackendType,
    model_features: Optional[Set[str]] = None,
) -> type:
    """Look up and resolve the model class registered for a combination.

    The registry is keyed by ``(model_architecture, task_type, backend)``.
    Availability is decided by ``model_implementation_exists`` with the same
    arguments, so a non-empty ``model_features`` must be fully covered by the
    matched entry's declared features.

    Args:
        model_architecture: Architecture identifier used as the first key part.
        task_type: Task identifier used as the second key part.
        backend: Backend identifier used as the third key part.
        model_features: Optional set of features the implementation must
            support.

    Returns:
        The result of calling ``.resolve()`` on the registered lazy class
        reference. For a ``RegistryEntry``, that reference is its
        ``model_class`` attribute.

    Raises:
        ModelImplementationLoaderError: If no matching implementation is
            registered.
    """
    implementation_available = model_implementation_exists(
        model_architecture=model_architecture,
        task_type=task_type,
        backend=backend,
        model_features=model_features,
    )
    if not implementation_available:
        raise ModelImplementationLoaderError(
            message=f"Did not find implementation for model with architecture: {model_architecture}, "
            f"task type: {task_type} backend: {backend} and model features: {model_features}",
            help_url="https://todo",
        )
    registered = REGISTERED_MODELS[(model_architecture, task_type, backend)]
    # Entries are either bare lazy class references or RegistryEntry wrappers
    # that carry the lazy reference in ``model_class``.
    lazy_reference = (
        registered.model_class if isinstance(registered, RegistryEntry) else registered
    )
    return lazy_reference.resolve()
480
+
481
+
482
def model_implementation_exists(
    model_architecture: ModelArchitecture,
    task_type: TaskType,
    backend: BackendType,
    model_features: Optional[Set[str]] = None,
) -> bool:
    """Report whether the registry holds an implementation for the request.

    Args:
        model_architecture: Architecture identifier used as the first key part.
        task_type: Task identifier used as the second key part.
        backend: Backend identifier used as the third key part.
        model_features: Optional set of features that must all be supported.
            An empty or ``None`` value skips the feature check.

    Returns:
        ``False`` if the key is not registered. ``True`` if it is registered
        and no features were requested. Otherwise, ``True`` only when the
        entry is a ``RegistryEntry`` whose ``supported_model_features``
        contains every requested feature.
    """
    registry_key = (model_architecture, task_type, backend)
    if registry_key in REGISTERED_MODELS:
        if not model_features:
            return True
        candidate = REGISTERED_MODELS[registry_key]
        if isinstance(candidate, RegistryEntry):
            return all(
                feature in candidate.supported_model_features
                for feature in model_features
            )
        # Features were requested, but a bare entry declares no supported
        # features at all, so the request cannot be satisfied.
        return False
    return False