inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/clip/clip_onnx.py
@@ -0,0 +1,148 @@
+ from threading import Lock
+ from typing import List, Optional, Union
+
+ import clip
+ import numpy as np
+ import torch
+
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import (
+     EnvironmentConfigurationError,
+     MissingDependencyError,
+ )
+ from inference_models.models.base.embeddings import TextImageEmbeddingModel
+ from inference_models.models.clip.preprocessing import create_clip_preprocessor
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.common.onnx import (
+     run_session_with_batch_size_limit,
+     set_execution_provider_defaults,
+ )
+ from inference_models.utils.onnx_introspection import (
+     get_selected_onnx_execution_providers,
+ )
+
+ try:
+     import onnxruntime
+ except ImportError as import_error:
+     raise MissingDependencyError(
+         message=f"Could not import CLIP model with ONNX backend - this error means that some additional dependencies "
+         f"are not installed in the environment. If you run the `inference-models` library directly in your Python "
+         f"program, make sure the following extras of the package are installed: \n"
+         f"\t* `onnx-cpu` - when you wish to use the library with CPU support only\n"
+         f"\t* `onnx-cu12` - for running on GPU with CUDA 12 installed\n"
+         f"\t* `onnx-cu118` - for running on GPU with CUDA 11.8 installed\n"
+         f"\t* `onnx-jp6-cu126` - for running on Jetson with Jetpack 6\n"
+         f"If you see this error using Roboflow infrastructure, make sure the service you use supports the model. "
+         f"You can also contact Roboflow to get support.",
+         help_url="https://todo",
+     ) from import_error
+
+
+ MEAN = (0.48145466, 0.4578275, 0.40821073)
+ STD = (0.26862954, 0.26130258, 0.27577711)
+
+
+ class ClipOnnx(TextImageEmbeddingModel):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+         default_onnx_trt_options: bool = True,
+         device: torch.device = DEFAULT_DEVICE,
+         max_batch_size: int = 32,
+         **kwargs,
+     ) -> "ClipOnnx":
+         if onnx_execution_providers is None:
+             onnx_execution_providers = get_selected_onnx_execution_providers()
+         if not onnx_execution_providers:
+             raise EnvironmentConfigurationError(
+                 message=f"Could not initialize model - selected backend is ONNX which requires an execution provider to "
+                 f"be specified - explicitly in the `from_pretrained(...)` method or via the env variable "
+                 f"`ONNXRUNTIME_EXECUTION_PROVIDERS`. If you run the model locally - adjust your setup, otherwise "
+                 f"contact the platform support.",
+                 help_url="https://todo",
+             )
+         onnx_execution_providers = set_execution_provider_defaults(
+             providers=onnx_execution_providers,
+             model_package_path=model_name_or_path,
+             device=device,
+             default_onnx_trt_options=default_onnx_trt_options,
+         )
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "textual.onnx",
+                 "visual.onnx",
+             ],
+         )
+         visual_onnx_session = onnxruntime.InferenceSession(
+             path_or_bytes=model_package_content["visual.onnx"],
+             providers=onnx_execution_providers,
+         )
+         textual_onnx_session = onnxruntime.InferenceSession(
+             path_or_bytes=model_package_content["textual.onnx"],
+             providers=onnx_execution_providers,
+         )
+         image_size = visual_onnx_session.get_inputs()[0].shape[2]
+         visual_input_name = visual_onnx_session.get_inputs()[0].name
+         textual_input_name = textual_onnx_session.get_inputs()[0].name
+         return cls(
+             visual_onnx_session=visual_onnx_session,
+             textual_onnx_session=textual_onnx_session,
+             image_size=image_size,
+             visual_input_name=visual_input_name,
+             textual_input_name=textual_input_name,
+             device=device,
+             max_batch_size=max_batch_size,
+         )
+
+     def __init__(
+         self,
+         visual_onnx_session: onnxruntime.InferenceSession,
+         textual_onnx_session: onnxruntime.InferenceSession,
+         image_size: int,
+         visual_input_name: str,
+         textual_input_name: str,
+         device: torch.device,
+         max_batch_size: int,
+     ):
+         self._visual_onnx_session = visual_onnx_session
+         self._textual_onnx_session = textual_onnx_session
+         self._image_size = image_size
+         self._visual_input_name = visual_input_name
+         self._textual_input_name = textual_input_name
+         self._device = device
+         self._max_batch_size = max_batch_size
+         self._visual_session_thread_lock = Lock()
+         self._textual_session_thread_lock = Lock()
+         self._preprocessor = create_clip_preprocessor(image_size=image_size)
+
+     def embed_images(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         pre_processed_images = self._preprocessor(
+             images, input_color_format, self._device
+         )
+         with self._visual_session_thread_lock:
+             return run_session_with_batch_size_limit(
+                 session=self._visual_onnx_session,
+                 inputs={self._visual_input_name: pre_processed_images},
+                 max_batch_size=self._max_batch_size,
+             )[0]
+
+     def embed_text(self, texts: Union[str, List[str]], **kwargs) -> torch.Tensor:
+         if not isinstance(texts, list):
+             texts = [texts]
+         tokenized_batch = clip.tokenize(texts)
+         with self._textual_session_thread_lock:
+             return run_session_with_batch_size_limit(
+                 session=self._textual_onnx_session,
+                 inputs={self._textual_input_name: tokenized_batch},
+                 max_batch_size=self._max_batch_size,
+             )[0]
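For context, a minimal usage sketch of the ONNX backend added above, assuming a locally available model package directory (here called `./clip-package`) containing the `visual.onnx` and `textual.onnx` files the loader expects, and assuming CPU execution via ONNX Runtime's `CPUExecutionProvider`:

import numpy as np

from inference_models.models.clip.clip_onnx import ClipOnnx

# `./clip-package` is a hypothetical local directory holding `visual.onnx` and `textual.onnx`.
model = ClipOnnx.from_pretrained(
    model_name_or_path="./clip-package",
    onnx_execution_providers=["CPUExecutionProvider"],
)
image = np.zeros((224, 224, 3), dtype=np.uint8)  # HWC uint8, treated as BGR by default
image_embeddings = model.embed_images([image])   # -> torch.Tensor of shape (1, embedding_dim)
text_embeddings = model.embed_text("a photo of a cat")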
inference_models/models/clip/clip_pytorch.py
@@ -0,0 +1,104 @@
+ from typing import Callable, List, Optional, Union
+
+ import clip
+ import numpy as np
+ import torch
+ from clip.model import CLIP, build_model
+
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import CorruptedModelPackageError
+ from inference_models.models.base.embeddings import TextImageEmbeddingModel
+ from inference_models.models.clip.preprocessing import create_clip_preprocessor
+ from inference_models.models.common.model_packages import get_model_package_contents
+
+
+ class ClipTorch(TextImageEmbeddingModel):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         max_batch_size: int = 32,
+         **kwargs,
+     ) -> "ClipTorch":
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=["model.pt"],
+         )
+         model_weights_file = model_package_content["model.pt"]
+         model = build_clip_model(model_weights_file=model_weights_file, device=device)
+         model.eval()
+         return cls(
+             model=model,
+             tokenizer=clip.tokenize,
+             device=device,
+             max_batch_size=max_batch_size,
+         )
+
+     def __init__(
+         self,
+         model: CLIP,
+         tokenizer: Callable,
+         device: torch.device,
+         max_batch_size: int,
+     ):
+         self._model = model
+         self._tokenizer = tokenizer
+         self._device = device
+         self._preprocessor = create_clip_preprocessor(
+             image_size=model.visual.input_resolution
+         )
+         self._max_batch_size = max_batch_size
+
+     @torch.no_grad()
+     def embed_images(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         tensor_batch = self._preprocessor(images, input_color_format, self._device)
+         if tensor_batch.shape[0] <= self._max_batch_size:
+             return self._model.encode_image(tensor_batch.to(self._device))
+         results = []
+         for i in range(0, tensor_batch.shape[0], self._max_batch_size):
+             batch_input = tensor_batch[i : i + self._max_batch_size].contiguous()
+             batch_results = self._model.encode_image(batch_input.to(self._device))
+             results.append(batch_results)
+         return torch.cat(results, dim=0)
+
+     @torch.no_grad()
+     def embed_text(
+         self,
+         texts: Union[str, List[str]],
+         **kwargs,
+     ) -> torch.Tensor:
+         if isinstance(texts, str):
+             texts = [texts]
+         text_tokens = self._tokenizer(texts).to(self._device)
+         if text_tokens.shape[0] <= self._max_batch_size:
+             return self._model.encode_text(text_tokens)
+         results = []
+         for i in range(0, text_tokens.shape[0], self._max_batch_size):
+             batch_input = text_tokens[i : i + self._max_batch_size].contiguous()
+             batch_results = self._model.encode_text(batch_input)
+             results.append(batch_results)
+         return torch.cat(results, dim=0)
+
+
+ def build_clip_model(model_weights_file: str, device: torch.device) -> CLIP:
+     try:
+         # The model file is a JIT archive, so we load it as such
+         # and then build a new model from its state dict.
+         jit_model = torch.jit.load(model_weights_file, map_location="cpu").eval()
+         state_dict = jit_model.state_dict()
+         model = build_model(state_dict).to(device)
+         if device.type == "cpu":
+             model.float()
+         return model
+     except Exception as e:
+         raise CorruptedModelPackageError(
+             f"Could not load TorchScript model from {model_weights_file}. Details: {e}"
+         ) from e
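And a similar sketch for the PyTorch backend, assuming a package directory containing the `model.pt` JIT archive that `from_pretrained` looks for; the cosine-similarity comparison at the end is illustrative, not part of the class API:

import numpy as np
import torch
import torch.nn.functional as F

from inference_models.models.clip.clip_pytorch import ClipTorch

# `./clip-package` is a hypothetical local directory holding `model.pt`.
model = ClipTorch.from_pretrained("./clip-package", device=torch.device("cpu"))
image_embedding = model.embed_images(np.zeros((224, 224, 3), dtype=np.uint8))
text_embeddings = model.embed_text(["a photo of a cat", "a photo of a dog"])
# Rank the prompts against the single image by cosine similarity.
scores = F.cosine_similarity(image_embedding, text_embeddings)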
inference_models/models/clip/preprocessing.py
@@ -0,0 +1,162 @@
+ from functools import partial
+ from typing import Callable, List, Optional, Union
+
+ import numpy as np
+ import torch
+ from torchvision.transforms import (
+     CenterCrop,
+     Compose,
+     InterpolationMode,
+     Normalize,
+     Resize,
+ )
+
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import ModelRuntimeError
+
+ MEAN = (0.48145466, 0.4578275, 0.40821073)
+ STD = (0.26862954, 0.26130258, 0.27577711)
+
+ PreprocessorFun = Callable[
+     [
+         Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         Optional[ColorFormat],
+         torch.device,
+     ],
+     torch.Tensor,
+ ]
+
+
+ def create_clip_preprocessor(image_size: int) -> PreprocessorFun:
+     """
+     Creates a preprocessor for CLIP models that operates on tensors.
+
+     This implementation replicates the logic of the original CLIP preprocessing pipeline
+     but is designed to work directly with torch.Tensors and np.ndarrays, avoiding
+     the need to convert to and from PIL.Image objects.
+
+     Note: Due to differences in the underlying resizing algorithms (torchvision vs. PIL),
+     the output of this preprocessor may have minor numerical differences compared to
+     the original. These differences have been tested and are known to produce
+     embeddings with very high cosine similarity, making them functionally equivalent.
+
+     Args:
+         image_size (int): The target size for the input images.
+
+     Returns:
+         A callable that takes images, an optional input color format, and a target
+         device, and returns a batch of preprocessed image tensors.
+     """
+     # This pre-processing pipeline matches the original CLIP implementation.
+     # 1. Resize to `image_size`
+     # 2. Center crop to `image_size`
+     # 3. Scale pixel values to [0, 1]
+     # 4. Normalize with CLIP's specific mean and standard deviation.
+     transforms = Compose(
+         [
+             Resize(image_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
+             CenterCrop(image_size),
+             lambda x: x.to(torch.float32) / 255.0,
+             Normalize(MEAN, STD),
+         ]
+     )
+     return partial(pre_process_image, transforms=transforms)
+
+
+ def pre_process_image(
+     images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+     input_color_format: Optional[ColorFormat],
+     device: torch.device,
+     transforms: Compose,
+ ) -> torch.Tensor:
+     images = inputs_to_tensor(
+         images=images, device=device, input_color_format=input_color_format
+     )
+     if isinstance(images, torch.Tensor):
+         return transforms(images)
+     return torch.cat([transforms(i) for i in images], dim=0).contiguous()
+
+
+ def inputs_to_tensor(
+     images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+     device: torch.device,
+     input_color_format: Optional[ColorFormat] = None,
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
+     if not isinstance(images, (list, np.ndarray, torch.Tensor)):
+         raise ModelRuntimeError(
+             f"Unsupported input type: {type(images)}. Must be one of list, np.ndarray, or torch.Tensor."
+         )
+     if isinstance(images, list):
+         if not images:
+             raise ModelRuntimeError("Input image list cannot be empty.")
+         return [
+             input_to_tensor(
+                 image=image,
+                 device=device,
+                 input_color_format=input_color_format,
+                 batched_tensors_allowed=False,
+             )
+             for image in images
+         ]
+     return input_to_tensor(
+         image=images,
+         device=device,
+         input_color_format=input_color_format,
+     ).contiguous()
+
+
+ def input_to_tensor(
+     image: Union[torch.Tensor, np.ndarray],
+     device: torch.device,
+     input_color_format: Optional[ColorFormat] = None,
+     batched_tensors_allowed: bool = True,
+ ) -> torch.Tensor:
+     if not isinstance(image, (np.ndarray, torch.Tensor)):
+         raise ModelRuntimeError(
+             f"Unsupported input type: {type(image)}. Each element must be np.ndarray or torch.Tensor."
+         )
+     is_numpy = isinstance(image, np.ndarray)
+     if is_numpy:
+         if len(image.shape) != 3:
+             raise ModelRuntimeError(
+                 f"Unsupported input type: detected np.ndarray image of shape {image.shape} which has "
+                 f"a number of dimensions different than 3. This input is invalid."
+             )
+         if image.shape[-1] != 3:
+             raise ModelRuntimeError(
+                 f"Unsupported input type: detected np.ndarray image of shape {image.shape} which has "
+                 f"an incorrect number of color channels (expected: 3)."
+             )
+         # HWC -> CHW
+         tensor_image = torch.from_numpy(image).to(device).permute(2, 0, 1).unsqueeze(0)
+     else:
+         expected_dimensions_str = (
+             "expected: 3 or 4" if batched_tensors_allowed else "expected: 3"
+         )
+         if len(image.shape) == 4 and not batched_tensors_allowed:
+             raise ModelRuntimeError(
+                 f"Unsupported input type: detected torch.Tensor image of shape {image.shape} which has "
+                 f"an incorrect number of dimensions ({expected_dimensions_str})."
+             )
+         if len(image.shape) != 3 and len(image.shape) != 4:
+             raise ModelRuntimeError(
+                 f"Unsupported input type: detected torch.Tensor image of shape {image.shape} which has "
+                 f"an incorrect number of dimensions ({expected_dimensions_str})."
+             )
+         if (len(image.shape) == 3 and image.shape[0] != 3) or (
+             len(image.shape) == 4 and image.shape[1] != 3
+         ):
+             raise ModelRuntimeError(
+                 f"Unsupported input type: detected torch.Tensor image of shape {image.shape} which has "
+                 f"an incorrect number of color channels (expected: 3)."
+             )
+         if len(image.shape) == 3:
+             image = image.unsqueeze(0)
+         tensor_image = image.to(device)
+     effective_color_format = input_color_format
+     if effective_color_format is None:
+         effective_color_format = "bgr" if is_numpy else "rgb"
+     if effective_color_format == "bgr":
+         # BGR -> RGB
+         tensor_image = tensor_image[:, [2, 1, 0], :, :]
+     return tensor_image
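The shape and color-format conventions enforced above can be summarised with a short sketch (the image sizes and the 224-pixel target are illustrative values, not defaults taken from the package):

import numpy as np
import torch

from inference_models.models.clip.preprocessing import create_clip_preprocessor

preprocess = create_clip_preprocessor(image_size=224)
device = torch.device("cpu")

# np.ndarray inputs must be HWC with 3 channels and are assumed BGR unless a color format is given.
numpy_image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
batch = preprocess([numpy_image, numpy_image], None, device)  # -> (2, 3, 224, 224)

# torch.Tensor inputs are CHW (or NCHW) with 3 channels and default to RGB.
tensor_image = torch.randint(0, 255, (3, 480, 640), dtype=torch.uint8)
batch = preprocess(tensor_image, "rgb", device)               # -> (1, 3, 224, 224)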
inference_models/models/common/__init__.py
File without changes
inference_models/models/common/cuda.py
@@ -0,0 +1,30 @@
+ import contextlib
+ from typing import Generator
+
+ from inference_models.errors import MissingDependencyError
+
+ try:
+     import pycuda.driver as cuda
+ except ImportError as import_error:
+     raise MissingDependencyError(
+         message="TODO",
+         help_url="https://todo",
+     ) from import_error
+
+
+ @contextlib.contextmanager
+ def use_primary_cuda_context(
+     cuda_device: cuda.Device,
+ ) -> Generator[cuda.Context, None, None]:
+     context = cuda_device.retain_primary_context()
+     with use_cuda_context(context) as ctx:
+         yield ctx
+
+
+ @contextlib.contextmanager
+ def use_cuda_context(context: cuda.Context) -> Generator[cuda.Context, None, None]:
+     context.push()
+     try:
+         yield context
+     finally:
+         context.pop()
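A brief sketch of how these context managers might be used with pycuda (device index 0 is an assumption; any CUDA-capable device would do):

import pycuda.driver as cuda

from inference_models.models.common.cuda import use_primary_cuda_context

cuda.init()              # initialise the CUDA driver API
device = cuda.Device(0)  # assumes at least one CUDA-capable GPU is present
with use_primary_cuda_context(device) as context:
    # The device's primary context is pushed onto this thread for the duration
    # of the block and popped again on exit, even if an exception is raised.
    free_bytes, total_bytes = cuda.mem_get_info()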
inference_models/models/common/model_packages.py
@@ -0,0 +1,25 @@
+ import os.path
+ from typing import Dict, List
+
+ from inference_models.errors import CorruptedModelPackageError
+
+
+ def get_model_package_contents(
+     model_package_dir: str,
+     elements: List[str],
+ ) -> Dict[str, str]:
+     result = {}
+     for element in elements:
+         element_path = os.path.join(model_package_dir, element)
+         if not os.path.exists(element_path):
+             raise CorruptedModelPackageError(
+                 message=f"Model package is incomplete. Could not find element {element}. "
+                 f"If you attempt to run `inference-models` locally - inspect the contents of the local directory to check "
+                 f"the completeness of the model package download - missing files may indicate network issues. Verification "
+                 f"of connectivity may be a good first step. If you prepared the model package manually - examine the "
+                 f"correctness of the setup. If you run on managed serving - contact support if the issue is "
+                 f"not ephemeral.",
+                 help_url="https://todo",
+             )
+         result[element] = element_path
+     return result
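A small sketch of this helper in use, with a hypothetical package directory matching the layout expected by the CLIP ONNX loader above:

from inference_models.models.common.model_packages import get_model_package_contents

# Expected layout on disk: ./clip-package/visual.onnx and ./clip-package/textual.onnx
contents = get_model_package_contents(
    model_package_dir="./clip-package",
    elements=["visual.onnx", "textual.onnx"],
)
visual_path = contents["visual.onnx"]  # path to the file, verified to exist
# Any missing element raises CorruptedModelPackageError rather than returning a partial mapping.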