inference-models 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/owlv2/reference_dataset.py
@@ -0,0 +1,291 @@
+ import hashlib
+ import os.path
+ import re
+ import urllib.parse
+ from typing import List, Optional, Union
+
+ import backoff
+ import cv2
+ import numpy as np
+ import pybase64
+ import requests
+ import torch
+ from requests import Timeout
+ from tldextract import tldextract
+ from tldextract.tldextract import ExtractResult
+
+ from inference_models.configuration import (
+     API_CALLS_MAX_TRIES,
+     IDEMPOTENT_API_REQUEST_CODES_TO_RETRY,
+ )
+ from inference_models.errors import ModelInputError, ModelRuntimeError, RetryError
+
+ BASE64_DATA_TYPE_PATTERN = re.compile(r"^data:image\/[a-z]+;base64,")
+
+
+ class LazyImageWrapper:
+
+     @classmethod
+     def init(
+         cls,
+         image: Union[np.ndarray, torch.Tensor, str, bytes],
+         allow_url_input: bool,
+         allow_non_https_url: bool,
+         allow_url_without_fqdn: bool,
+         whitelisted_domains: Optional[List[str]],
+         blacklisted_domains: Optional[List[str]],
+         allow_local_storage_access: bool,
+     ):
+         image_in_memory, image_reference = None, None
+         if isinstance(image, (torch.Tensor, np.ndarray)):
+             image_in_memory = image
+         else:
+             image_reference = image
+         return cls(
+             allow_url_input=allow_url_input,
+             allow_non_https_url=allow_non_https_url,
+             allow_url_without_fqdn=allow_url_without_fqdn,
+             whitelisted_domains=whitelisted_domains,
+             blacklisted_domains=blacklisted_domains,
+             allow_local_storage_access=allow_local_storage_access,
+             image_in_memory=image_in_memory,
+             image_reference=image_reference,
+         )
+
+     def __init__(
+         self,
+         allow_url_input: bool,
+         allow_non_https_url: bool,
+         allow_url_without_fqdn: bool,
+         whitelisted_domains: Optional[List[str]],
+         blacklisted_domains: Optional[List[str]],
+         allow_local_storage_access: bool,
+         image_in_memory: Optional[Union[np.ndarray, torch.Tensor]] = None,
+         image_reference: Optional[Union[str, bytes]] = None,
+         image_hash: Optional[str] = None,
+     ):
+         self._allow_url_input = allow_url_input
+         self._allow_non_https_url = allow_non_https_url
+         self._allow_url_without_fqdn = allow_url_without_fqdn
+         self._whitelisted_domains = whitelisted_domains
+         self._blacklisted_domains = blacklisted_domains
+         self._allow_local_storage_access = allow_local_storage_access
+         if image_in_memory is None and image_reference is None:
+             raise ModelRuntimeError(
+                 message="Attempted to use OWLv2 image lazy loading without providing either an image "
+                 "location or an image instance - this is invalid input. Contact Roboflow to get help.",
+                 help_url="https://todo",
+             )
+         self._image_in_memory = image_in_memory
+         self._image_reference = image_reference
+         self._image_hash = image_hash
+
+     def as_numpy(self) -> np.ndarray:
+         if self._image_in_memory is not None:
+             if isinstance(self._image_in_memory, torch.Tensor):
+                 self._image_in_memory = self._image_in_memory.cpu().numpy()
+             return self._image_in_memory
+         image = load_image_reference(
+             image_reference=self._image_reference,
+             allow_url_input=self._allow_url_input,
+             allow_non_https_url=self._allow_non_https_url,
+             allow_url_without_fqdn=self._allow_url_without_fqdn,
+             whitelisted_domains=self._whitelisted_domains,
+             blacklisted_domains=self._blacklisted_domains,
+             allow_local_storage_access=self._allow_local_storage_access,
+         )
+         self._image_in_memory = image
+         return image
+
+     def get_hash(self) -> str:
+         if self._image_hash is not None:
+             return self._image_hash
+         if self._image_reference is not None:
+             self._image_hash = hash_function(value=self._image_reference)
+         else:
+             self._image_hash = hash_function(value=self.as_numpy().tobytes())
+         return self._image_hash
+
+     def unload_image(self) -> None:
+         if self._image_in_memory is not None and self._image_reference is not None:
+             self._image_in_memory = None
+
+
+ def load_image_reference(
+     image_reference: Union[str, bytes],
+     allow_url_input: bool,
+     allow_non_https_url: bool,
+     allow_url_without_fqdn: bool,
+     whitelisted_domains: Optional[List[str]],
+     blacklisted_domains: Optional[List[str]],
+     allow_local_storage_access: bool,
+ ) -> np.ndarray:
+     if isinstance(image_reference, bytes):
+         return decode_image_from_bytes(image_bytes=image_reference)
+     if is_url(reference=image_reference):
+         return decode_image_from_url(
+             url=image_reference,
+             allow_url_input=allow_url_input,
+             allow_non_https_url=allow_non_https_url,
+             allow_url_without_fqdn=allow_url_without_fqdn,
+             whitelisted_domains=whitelisted_domains,
+             blacklisted_domains=blacklisted_domains,
+         )
+     if not allow_local_storage_access:
+         return decode_image_from_base64(value=image_reference)
+     elif os.path.isfile(image_reference):
+         return cv2.imread(image_reference)
+     else:
+         return decode_image_from_base64(value=image_reference)
+
+
+ def decode_image_from_url(
+     url: str,
+     allow_url_input: bool,
+     allow_non_https_url: bool,
+     allow_url_without_fqdn: bool,
+     whitelisted_domains: Optional[List[str]],
+     blacklisted_domains: Optional[List[str]],
+ ):
+     if not allow_url_input:
+         raise ModelInputError(
+             message="Providing images via URL is not supported in this configuration of `inference-models`.",
+             help_url="https://todo",
+         )
+     try:
+         parsed_url = urllib.parse.urlparse(url)
+     except ValueError as error:
+         raise ModelInputError(
+             message="Provided image URL is invalid.", help_url="https://todo"
+         ) from error
+     if parsed_url.scheme != "https" and not allow_non_https_url:
+         raise ModelInputError(
+             message="Providing images via non https:// URL is not supported in this configuration of `inference-models`.",
+             help_url="https://todo",
+         )
+     domain_extraction_result = tldextract.TLDExtract(suffix_list_urls=())(
+         parsed_url.netloc
+     )  # we get rid of potential ports and parse FQDNs
+     _ensure_resource_fqdn_allowed(
+         fqdn=domain_extraction_result.fqdn,
+         allow_url_without_fqdn=allow_url_without_fqdn,
+     )
+     address_parts_concatenated = _concatenate_chunks_of_network_location(
+         extraction_result=domain_extraction_result
+     )  # concatenation of chunks - even if there is no FQDN, but address
+     # it allows white-/black-list verification
+     _ensure_location_matches_destination_whitelist(
+         destination=address_parts_concatenated,
+         whitelisted_domains=whitelisted_domains,
+     )
+     _ensure_location_matches_destination_blacklist(
+         destination=address_parts_concatenated,
+         blacklisted_domains=blacklisted_domains,
+     )
+     image_content = _get_from_url(url=url)
+     return decode_image_from_bytes(image_bytes=image_content)
+
+
+ def decode_image_from_base64(value: str) -> np.ndarray:
+     try:
+         value = BASE64_DATA_TYPE_PATTERN.sub("", value)
+         decoded = pybase64.b64decode(value, validate=True)
+         return decode_image_from_bytes(image_bytes=decoded)
+     except Exception as error:
+         value_prefix = value[:16]
+         raise ModelInputError(
+             message=f"Could not decode base64 image from reference {value_prefix}.",
+             help_url="https://todo",
+         ) from error
+
+
+ def decode_image_from_bytes(image_bytes: bytes) -> np.ndarray:
+     byte_array = np.frombuffer(image_bytes, dtype=np.uint8)
+     return cv2.imdecode(byte_array, cv2.IMREAD_COLOR)
+
+
+ def is_url(reference: str) -> bool:
+     return reference.startswith("http://") or reference.startswith("https://")
+
+
+ def _ensure_resource_fqdn_allowed(fqdn: str, allow_url_without_fqdn: bool) -> None:
+     if not fqdn and not allow_url_without_fqdn:
+         raise ModelInputError(
+             message="Providing images via URL without FQDN is not supported in this configuration of `inference-models`.",
+             help_url="https://todo",
+         )
+     return None
+
+
+ def _concatenate_chunks_of_network_location(extraction_result: ExtractResult) -> str:
+     chunks = [
+         extraction_result.subdomain,
+         extraction_result.domain,
+         extraction_result.suffix,
+     ]
+     non_empty_chunks = [chunk for chunk in chunks if chunk]
+     result = ".".join(non_empty_chunks)
+     if result.startswith("[") and result.endswith("]"):
+         # dropping brackets for IPv6
+         return result[1:-1]
+     return result
+
+
+ def _ensure_location_matches_destination_whitelist(
+     destination: str, whitelisted_domains: Optional[List[str]]
+ ) -> None:
+     if whitelisted_domains is None:
+         return None
+     if destination not in whitelisted_domains:
+         raise ModelInputError(
+             message="It is not allowed to reach image URL - prohibited by whitelisted destinations",
+             help_url="https://todo",
+         )
+     return None
+
+
+ def _ensure_location_matches_destination_blacklist(
+     destination: str,
+     blacklisted_domains: Optional[List[str]],
+ ) -> None:
+     if blacklisted_domains is None:
+         return None
+     if destination in blacklisted_domains:
+         raise ModelInputError(
+             message="It is not allowed to reach image URL - prohibited by blacklisted destinations.",
+             help_url="https://todo",
+         )
+     return None
+
+
+ @backoff.on_exception(
+     backoff.constant,
+     exception=RetryError,
+     max_tries=API_CALLS_MAX_TRIES,
+     interval=1,
+ )
+ def _get_from_url(url: str, timeout: int = 5) -> bytes:
+     try:
+         with requests.get(url, stream=True, timeout=timeout) as response:
+             if response.status_code in IDEMPOTENT_API_REQUEST_CODES_TO_RETRY:
+                 raise RetryError(
+                     message=f"File hosting returned {response.status_code}",
+                     help_url="https://todo",
+                 )
+             response.raise_for_status()
+             return response.content
+     except (ConnectionError, Timeout, requests.exceptions.ConnectionError):
+         raise RetryError(
+             message="Connectivity error",
+             help_url="https://todo",
+         )
+
+
+ def compute_image_hash(image: Union[torch.Tensor, np.ndarray]) -> str:
+     if isinstance(image, torch.Tensor):
+         image = image.cpu().numpy()
+     return hash_function(value=image.tobytes())
+
+
+ def hash_function(value: Union[str, bytes]) -> str:
+     if isinstance(value, str):
+         value = value.encode("utf-8")  # str references must be encoded before hashing
+     return hashlib.sha1(value).hexdigest()
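
A minimal usage sketch of the lazy image loading helpers above (the module path is inferred from the file listing; the flag values and URL are illustrative placeholders, not package defaults):

    import numpy as np
    from inference_models.models.owlv2.reference_dataset import LazyImageWrapper

    # Nothing is downloaded or decoded until as_numpy() is called.
    lazy_image = LazyImageWrapper.init(
        image="https://example.com/image.jpg",  # placeholder URL
        allow_url_input=True,
        allow_non_https_url=False,
        allow_url_without_fqdn=False,
        whitelisted_domains=None,  # None skips the whitelist check
        blacklisted_domains=None,  # None skips the blacklist check
        allow_local_storage_access=False,
    )
    decoded: np.ndarray = lazy_image.as_numpy()  # BGR array decoded via cv2.imdecode
    cache_key = lazy_image.get_hash()            # SHA-1 of the reference string
    lazy_image.unload_image()                    # drops the decoded copy, keeps the reference
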
inference_models/models/paligemma/paligemma_hf.py
@@ -0,0 +1,209 @@
+ import os
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from peft import PeftModel
+ from transformers import (
+     AutoProcessor,
+     BitsAndBytesConfig,
+     PaliGemmaForConditionalGeneration,
+ )
+
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.models.common.roboflow.model_packages import (
+     InferenceConfig,
+     ResizeMode,
+     parse_inference_config,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     pre_process_network_input,
+ )
+
+
+ class PaliGemmaHF:
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         trust_remote_code: bool = False,
+         local_files_only: bool = True,
+         quantization_config: Optional[BitsAndBytesConfig] = None,
+         disable_quantization: bool = False,
+         **kwargs,
+     ) -> "PaliGemmaHF":
+         torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
+         inference_config_path = os.path.join(
+             model_name_or_path, "inference_config.json"
+         )
+         inference_config = None
+         if os.path.exists(inference_config_path):
+             inference_config = parse_inference_config(
+                 config_path=inference_config_path,
+                 allowed_resize_modes={
+                     ResizeMode.STRETCH_TO,
+                     ResizeMode.LETTERBOX,
+                     ResizeMode.CENTER_CROP,
+                     ResizeMode.LETTERBOX_REFLECT_EDGES,
+                 },
+             )
+         if (
+             quantization_config is None
+             and device.type == "cuda"
+             and not disable_quantization
+         ):
+             quantization_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_quant_type="nf4",
+                 bnb_4bit_compute_dtype=torch.bfloat16,
+             )
+         adapter_config_path = os.path.join(model_name_or_path, "adapter_config.json")
+         if os.path.exists(adapter_config_path):
+             base_model_path = os.path.join(model_name_or_path, "base")
+             model = PaliGemmaForConditionalGeneration.from_pretrained(
+                 base_model_path,
+                 dtype=torch_dtype,
+                 trust_remote_code=trust_remote_code,
+                 local_files_only=local_files_only,
+                 quantization_config=quantization_config,
+             )
+             model = PeftModel.from_pretrained(model, model_name_or_path)
+             if quantization_config is None:
+                 model.merge_and_unload()
+                 model.to(device)
+
+             processor = AutoProcessor.from_pretrained(
+                 base_model_path,
+                 trust_remote_code=trust_remote_code,
+                 local_files_only=local_files_only,
+                 use_fast=True,
+             )
+         else:
+             model = PaliGemmaForConditionalGeneration.from_pretrained(
+                 model_name_or_path,
+                 dtype=torch_dtype,
+                 device_map=device,
+                 trust_remote_code=trust_remote_code,
+                 local_files_only=local_files_only,
+                 quantization_config=quantization_config,
+             ).eval()
+             processor = AutoProcessor.from_pretrained(
+                 model_name_or_path,
+                 trust_remote_code=trust_remote_code,
+                 local_files_only=local_files_only,
+                 use_fast=True,
+             )
+         return cls(
+             model=model,
+             processor=processor,
+             inference_config=inference_config,
+             device=device,
+             torch_dtype=torch_dtype,
+         )
+
+     def __init__(
+         self,
+         model: PaliGemmaForConditionalGeneration,
+         processor: AutoProcessor,
+         inference_config: Optional[InferenceConfig],
+         device: torch.device,
+         torch_dtype: torch.dtype,
+     ):
+         self._model = model
+         self._processor = processor
+         self._inference_config = inference_config
+         self._device = device
+         self._torch_dtype = torch_dtype
+
+     def prompt(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         prompt: str,
+         input_color_format: Optional[ColorFormat] = None,
+         max_new_tokens: int = 400,
+         do_sample: bool = False,
+         skip_special_tokens: bool = True,
+         **kwargs,
+     ) -> List[str]:
+         inputs = self.pre_process_generation(
+             images=images, prompt=prompt, input_color_format=input_color_format
+         )
+         generated_ids = self.generate(
+             inputs=inputs,
+             max_new_tokens=max_new_tokens,
+             do_sample=do_sample,
+         )
+         return self.post_process_generation(
+             generated_ids=generated_ids,
+             skip_special_tokens=skip_special_tokens,
+         )
+
+     def pre_process_generation(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         prompt: str,
+         input_color_format: Optional[ColorFormat] = None,
+         image_size: Optional[Tuple[int, int]] = None,
+         **kwargs,
+     ) -> dict:
+         def _to_tensor(image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+             is_numpy = isinstance(image, np.ndarray)
+             if is_numpy:
+                 tensor_image = torch.from_numpy(image.copy()).permute(2, 0, 1)
+             else:
+                 tensor_image = image
+             if input_color_format == "bgr" or (is_numpy and input_color_format is None):
+                 tensor_image = tensor_image[[2, 1, 0], :, :]
+             return tensor_image
+
+         if self._inference_config is None:
+             if isinstance(images, torch.Tensor) and images.ndim > 3:
+                 image_list = [_to_tensor(img) for img in images]
+             elif not isinstance(images, list):
+                 image_list = [_to_tensor(images)]
+             else:
+                 image_list = [_to_tensor(img) for img in images]
+         else:
+             images = pre_process_network_input(
+                 images=images,
+                 image_pre_processing=self._inference_config.image_pre_processing,
+                 network_input=self._inference_config.network_input,
+                 target_device=self._device,
+                 input_color_format=input_color_format,
+                 image_size_wh=image_size,
+             )[0]
+             image_list = [e[0] for e in torch.split(images, 1, dim=0)]
+         num_images = len(image_list)
+
+         if isinstance(prompt, str) and num_images > 1:
+             prompt = [prompt] * num_images
+         return self._processor(text=prompt, images=image_list, return_tensors="pt").to(
+             self._device
+         )
+
+     def generate(
+         self,
+         inputs: dict,
+         max_new_tokens: int = 400,
+         do_sample: bool = False,
+         **kwargs,
+     ) -> torch.Tensor:
+         with torch.inference_mode():
+             generation = self._model.generate(
+                 **inputs, max_new_tokens=max_new_tokens, do_sample=do_sample
+             )
+             input_len = inputs["input_ids"].shape[-1]
+             return generation[:, input_len:]
+
+     def post_process_generation(
+         self,
+         generated_ids: torch.Tensor,
+         skip_special_tokens: bool = False,
+         **kwargs,
+     ) -> List[str]:
+         return self._processor.batch_decode(
+             generated_ids, skip_special_tokens=skip_special_tokens
+         )
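
A minimal usage sketch of the PaliGemma wrapper above (the module path is inferred from the file listing; the checkpoint directory and prompt are placeholders):

    import cv2
    import torch
    from inference_models.models.paligemma.paligemma_hf import PaliGemmaHF

    # Local model package; on CUDA, 4-bit quantization is enabled automatically
    # unless disable_quantization=True or an explicit quantization_config is passed.
    model = PaliGemmaHF.from_pretrained(
        "/path/to/paligemma-model-package",  # placeholder path
        device=torch.device("cuda:0"),
    )
    image = cv2.imread("example.jpg")  # BGR numpy array; converted to RGB internally
    answers = model.prompt(images=image, prompt="caption en", max_new_tokens=128)
    print(answers[0])
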
inference_models/models/perception_encoder/perception_encoder_pytorch.py
@@ -0,0 +1,197 @@
+ import json
+ from typing import Callable, List, Optional, Union
+
+ import numpy as np
+ import torch
+ import torchvision.transforms as T
+ from pydantic import BaseModel, ValidationError
+
+ import inference_models.models.perception_encoder.vision_encoder.pe as pe
+ import inference_models.models.perception_encoder.vision_encoder.transforms as transforms
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import CorruptedModelPackageError
+ from inference_models.models.base.embeddings import TextImageEmbeddingModel
+ from inference_models.models.common.model_packages import get_model_package_contents
+
+
+ class PerceptionEncoderConfig(BaseModel):
+     vision_encoder_config: str
+
+
+ def load_config(config_path: str) -> PerceptionEncoderConfig:
+     config_data = {}
+     try:
+         with open(config_path) as f:
+             config_data = json.load(f)
+     except (IOError, json.JSONDecodeError) as e:
+         raise CorruptedModelPackageError(
+             message=f"Could not load or parse perception encoder model package config file: {config_path}. Details: {e}",
+             help_url="https://todo",
+         ) from e
+     try:
+         config = PerceptionEncoderConfig.model_validate(config_data)
+         return config
+     except ValidationError as e:
+         raise CorruptedModelPackageError(
+             f"Failed to validate perception encoder model package config file: {config_path}. Details: {e}"
+         ) from e
+
+
+ # based on original implementation using PIL images found in vision_encoder/transforms.py
+ # but adjusted to work directly on tensors
+ def create_image_resize_transform(
+     image_size: int,
+     center_crop: bool = False,
+     interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
+ ):
+     if center_crop:
+         crop = [
+             T.Resize(image_size, interpolation=interpolation, antialias=True),
+             T.CenterCrop(image_size),
+         ]
+     else:
+         # "Squash": most versatile
+         crop = [
+             T.Resize(
+                 (image_size, image_size), interpolation=interpolation, antialias=True
+             )
+         ]
+     return T.Compose(crop)
+
+
+ def create_image_normalize_transform():
+     return T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
+
+
+ def create_preprocessor(image_size: int) -> Callable:
+     resize_transform = create_image_resize_transform(image_size)
+     normalize_transform = create_image_normalize_transform()
+
+     def _preprocess(
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+     ) -> torch.Tensor:
+         def _to_tensor(image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+             is_numpy = isinstance(image, np.ndarray)
+             if is_numpy:
+                 tensor_image = torch.from_numpy(image).permute(2, 0, 1)
+             else:
+                 tensor_image = image
+
+             # For numpy array inputs, we default to BGR -> RGB conversion for compatibility.
+             # For tensor inputs, we only convert if BGR is explicitly specified, otherwise RGB is assumed.
+             if input_color_format == "bgr" or (is_numpy and input_color_format is None):
+                 # BGR -> RGB
+                 tensor_image = tensor_image[[2, 1, 0], :, :]
+
+             return tensor_image
+
+         if isinstance(images, list):
+             # Resize each image individually, then stack to a batch
+             resized_images = [resize_transform(_to_tensor(img)) for img in images]
+             tensor_batch = torch.stack(resized_images, dim=0)
+         else:
+             # Handle single image or pre-batched tensor
+             tensor_batch = resize_transform(_to_tensor(images))
+
+         # Ensure there is a batch dimension for single images
+         if tensor_batch.ndim == 3:
+             tensor_batch = tensor_batch.unsqueeze(0)
+
+         # Perform dtype conversion and normalization on the whole batch for efficiency
+         if tensor_batch.dtype == torch.uint8:
+             tensor_batch = tensor_batch.to(torch.float32) / 255.0
+
+         transformed_batch = normalize_transform(tensor_batch)
+         return transformed_batch
+
+     return _preprocess
+
+
+ class PerceptionEncoderTorch(TextImageEmbeddingModel):
+     def __init__(
+         self,
+         model: pe.CLIP,
+         device: torch.device,
+     ):
+         self.model = model
+         self.device = device
+         self.preprocessor = create_preprocessor(model.image_size)
+         self.tokenizer = transforms.get_text_tokenizer(model.context_length)
+
+     @classmethod
+     def from_pretrained(
+         cls, model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, **kwargs
+     ) -> "PerceptionEncoderTorch":
+         # here the model name came from the path before, which may not match directly with how our registry works
+         # instead, should this be adapted to read a config file that is served as part of the model package?
+         # model_config = model_name_or_path.split("/")[-1]
+         # checkpoint_path = os.path.join(model_name_or_path, "model.pt")
+
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=["config.json", "model.pt"],
+         )
+
+         model_config_file = model_package_content["config.json"]
+         model_weights_file = model_package_content["model.pt"]
+         config = load_config(model_config_file)
+
+         model = pe.CLIP.from_config(
+             config.vision_encoder_config,
+             pretrained=True,
+             checkpoint_path=model_weights_file,
+         )
+         model = model.to(device)
+         model.eval()
+
+         return cls(
+             model=model,
+             device=device,
+         )
+
+     def embed_images(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         img_in = self.preprocessor(images, input_color_format=input_color_format).to(
+             self.device
+         )
+
+         if self.device.type == "cpu" or self.device.type == "mps":
+             with torch.inference_mode():
+                 image_features, _, _ = self.model(img_in, None)
+                 embeddings = image_features.float()
+         else:
+             with torch.inference_mode(), torch.autocast(self.device.type):
+                 image_features, _, _ = self.model(img_in, None)
+                 embeddings = image_features.float()
+
+         return embeddings
+
+     def embed_text(
+         self,
+         texts: Union[str, List[str]],
+         **kwargs,
+     ) -> torch.Tensor:
+         if isinstance(texts, list):
+             texts_to_embed = texts
+         else:
+             texts_to_embed = [texts]
+
+         # results = []
+         # The original implementation had batching here based on CLIP_MAX_BATCH_SIZE, but not entirely sure how to handle that with Tensor output
+         # I will leave it out for now, see https://github.com/roboflow/inference/blob/main/inference/models/perception_encoder/perception_encoder.py#L227
+         tokenized = self.tokenizer(texts_to_embed).to(self.device)
+         if self.device.type == "cpu" or self.device.type == "mps":
+             with torch.no_grad():
+                 _, text_features, _ = self.model(None, tokenized)
+         else:
+             with torch.inference_mode(), torch.autocast(self.device.type):
+                 _, text_features, _ = self.model(None, tokenized)
+
+         embeddings = text_features.float()
+         return embeddings
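
A minimal usage sketch of the perception encoder wrapper above (the module path is inferred from the file listing; the package directory is a placeholder and must contain the config.json and model.pt files expected by from_pretrained):

    import cv2
    import torch
    from inference_models.models.perception_encoder.perception_encoder_pytorch import (
        PerceptionEncoderTorch,
    )

    model = PerceptionEncoderTorch.from_pretrained(
        "/path/to/perception-encoder-package",  # placeholder path
        device=torch.device("cpu"),
    )
    image = cv2.imread("example.jpg")  # BGR numpy array; converted to RGB internally
    image_embeddings = model.embed_images(image)            # shape: (1, embed_dim)
    text_embeddings = model.embed_text(["a red tractor"])   # shape: (1, embed_dim)
    similarity = torch.nn.functional.cosine_similarity(image_embeddings, text_embeddings)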