inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,281 @@
+ from dataclasses import dataclass
+ from typing import Any, List, Literal, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from inference_models import Detections
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat, ImageDimensions
+ from inference_models.errors import ModelRuntimeError
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.common.roboflow.pre_processing import images_to_pillow
+ from inference_models.utils.imports import import_class_from_file
+
+
+ @dataclass
+ class EncodedImage:
+     moondream_encoded_image: Any
+     image_dimensions: ImageDimensions
+
+
+ @dataclass
+ class Points:
+     xy: torch.Tensor
+     confidence: torch.Tensor
+     class_id: torch.Tensor
+
+
+ class MoonDream2HF:
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         **kwargs,
+     ) -> "MoonDream2HF":
+         if torch.mps.is_available():
+             raise ModelRuntimeError(
+                 message="This model cannot run on Apple devices with an MPS unit - the original "
+                 "implementation contains a bug preventing proper allocation of tensors, which "
+                 "causes a runtime error. Run this model on a machine with an NVIDIA GPU or an x86 CPU.",
+                 help_url="https://todo",
+             )
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=["hf_moondream.py"],
+         )
+         model_class = import_class_from_file(
+             file_path=model_package_content["hf_moondream.py"],
+             class_name="HfMoondream",
+         )
+         model = model_class.from_pretrained(model_name_or_path).to(device)
+         return cls(model=model, device=device)
+
+     def __init__(self, model, device: torch.device):
+         self._model = model
+         self._device = device
+
+     def detect(
+         self,
+         images: Union[
+             EncodedImage,
+             List[EncodedImage],
+             torch.Tensor,
+             List[torch.Tensor],
+             np.ndarray,
+             List[np.ndarray],
+         ],
+         classes: List[str],
+         max_tokens: int = 700,
+         input_color_format: Optional[ColorFormat] = None,
+     ) -> List[Detections]:
+         encoded_images = self.encode_images(
+             images=images, input_color_format=input_color_format
+         )
+         results = []
+         for encoded_image in encoded_images:
+             image_detections = []
+             for class_id, class_name in enumerate(classes):
+                 class_detections = self._model.detect(
+                     image=encoded_image.moondream_encoded_image,
+                     object=class_name,
+                     settings={"max_tokens": max_tokens},
+                 )["objects"]
+                 image_detections.append((class_id, class_detections))
+             image_results = post_process_detections(
+                 raw_detections=image_detections,
+                 image_dimensions=encoded_image.image_dimensions,
+                 device=self._device,
+             )
+             results.append(image_results)
+         return results
+
+     def caption(
+         self,
+         images: Union[
+             EncodedImage,
+             List[EncodedImage],
+             torch.Tensor,
+             List[torch.Tensor],
+             np.ndarray,
+             List[np.ndarray],
+         ],
+         length: Literal["normal", "short", "long"] = "normal",
+         max_tokens: int = 700,
+         input_color_format: Optional[ColorFormat] = None,
+     ) -> List[str]:
+         encoded_images = self.encode_images(
+             images=images, input_color_format=input_color_format
+         )
+         results = []
+         for encoded_image in encoded_images:
+             result = self._model.caption(
+                 image=encoded_image.moondream_encoded_image,
+                 length=length,
+                 settings={"max_tokens": max_tokens},
+             )
+             results.append(result["caption"].strip())
+         return results
+
+     def query(
+         self,
+         images: Union[
+             EncodedImage,
+             List[EncodedImage],
+             torch.Tensor,
+             List[torch.Tensor],
+             np.ndarray,
+             List[np.ndarray],
+         ],
+         question: str,
+         max_tokens: int = 700,
+         input_color_format: Optional[ColorFormat] = None,
+     ) -> List[str]:
+         encoded_images = self.encode_images(
+             images=images, input_color_format=input_color_format
+         )
+         results = []
+         for encoded_image in encoded_images:
+             result = self._model.query(
+                 image=encoded_image.moondream_encoded_image,
+                 question=question,
+                 settings={"max_tokens": max_tokens},
+             )
+             results.append(result["answer"].strip())
+         return results
+
+     def point(
+         self,
+         images: Union[
+             EncodedImage,
+             List[EncodedImage],
+             torch.Tensor,
+             List[torch.Tensor],
+             np.ndarray,
+             List[np.ndarray],
+         ],
+         classes: List[str],
+         max_tokens: int = 700,
+         input_color_format: Optional[ColorFormat] = None,
+     ) -> List[Points]:
+         encoded_images = self.encode_images(
+             images=images, input_color_format=input_color_format
+         )
+         results = []
+         for encoded_image in encoded_images:
+             image_points = []
+             for class_id, class_name in enumerate(classes):
+                 class_points = self._model.point(
+                     image=encoded_image.moondream_encoded_image,
+                     object=class_name,
+                     settings={"max_tokens": max_tokens},
+                 )["points"]
+                 image_points.append((class_id, class_points))
+             image_results = post_process_points(
+                 raw_points=image_points,
+                 image_dimensions=encoded_image.image_dimensions,
+                 device=self._device,
+             )
+             results.append(image_results)
+         return results
+
+     def encode_images(
+         self,
+         images: Union[
+             EncodedImage,
+             List[EncodedImage],
+             torch.Tensor,
+             List[torch.Tensor],
+             np.ndarray,
+             List[np.ndarray],
+         ],
+         input_color_format: Optional[ColorFormat] = None,
+     ) -> List[EncodedImage]:
+         if are_images_encoded(images=images):
+             if not isinstance(images, list):
+                 return [images]
+             return images
+         pillow_images, images_dimensions = images_to_pillow(
+             images=images,
+             input_color_format=input_color_format,
+             model_color_format="rgb",
+         )
+         result = []
+         for image, image_dimensions in zip(pillow_images, images_dimensions):
+             moondream_encoded = self._model.encode_image(image)
+             result.append(
+                 EncodedImage(
+                     moondream_encoded_image=moondream_encoded,
+                     image_dimensions=image_dimensions,
+                 )
+             )
+         return result
+
+
+ def are_images_encoded(
+     images: Union[
+         EncodedImage,
+         List[EncodedImage],
+         torch.Tensor,
+         List[torch.Tensor],
+         np.ndarray,
+         List[np.ndarray],
+     ],
+ ) -> bool:
+     if isinstance(images, list):
+         if not len(images):
+             raise ModelRuntimeError(
+                 message="Detected empty input to the model", help_url="https://todo"
+             )
+         return isinstance(images[0], EncodedImage)
+     return isinstance(images, EncodedImage)
+
+
+ def post_process_detections(
+     raw_detections: List[Tuple[int, List[dict]]],
+     image_dimensions: ImageDimensions,
+     device: torch.device,
+ ) -> Detections:
+     xyxy, confidence, class_id = [], [], []
+     for detection_class_id, raw_class_detections in raw_detections:
+         for raw_detection in raw_class_detections:
+             xyxy.append(
+                 [
+                     raw_detection["x_min"] * image_dimensions.width,
+                     raw_detection["y_min"] * image_dimensions.height,
+                     raw_detection["x_max"] * image_dimensions.width,
+                     raw_detection["y_max"] * image_dimensions.height,
+                 ]
+             )
+             class_id.append(detection_class_id)
+             confidence.append(1.0)
+     return Detections(
+         xyxy=torch.tensor(xyxy, device=device).round().int(),
+         class_id=torch.tensor(class_id, device=device).int(),
+         confidence=torch.tensor(confidence, device=device),
+     )
+
+
+ def post_process_points(
+     raw_points: List[Tuple[int, List[dict]]],
+     image_dimensions: ImageDimensions,
+     device: torch.device,
+ ) -> Points:
+     xy, confidence, class_id = [], [], []
+     for point_class_id, raw_class_points in raw_points:
+         for raw_point in raw_class_points:
+             xy.append(
+                 [
+                     raw_point["x"] * image_dimensions.width,
+                     raw_point["y"] * image_dimensions.height,
+                 ]
+             )
+             class_id.append(point_class_id)
+             confidence.append(1.0)
+     return Points(
+         xy=torch.tensor(xy, device=device).round().int(),
+         class_id=torch.tensor(class_id, device=device).int(),
+         confidence=torch.tensor(confidence, device=device),
+     )
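
Editor's note: the wrapper above fans a single open-vocabulary request out into one moondream call per class name and converts the model's normalized boxes back to pixel coordinates. A minimal usage sketch, assuming a local model package directory containing hf_moondream.py; the path and class list are hypothetical:

import numpy as np

from inference_models.models.moondream2.moondream2_hf import MoonDream2HF

# Hypothetical local model package directory with hf_moondream.py inside.
model = MoonDream2HF.from_pretrained("/path/to/moondream2-package")

image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder frame

# Open-vocabulary detection: one moondream call per class name,
# post-processed into a single Detections object per image.
detections = model.detect(images=image, classes=["person", "dog"])

# Captioning and visual question answering accept the same inputs;
# pre-encoding via encode_images() lets several tasks share one image encoding.
encoded = model.encode_images(images=image)
captions = model.caption(images=encoded, length="short")
answers = model.query(images=encoded, question="How many people are visible?")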
File without changes
@@ -0,0 +1,182 @@
+ import hashlib
+ import json
+ from abc import ABC, abstractmethod
+ from collections import OrderedDict
+ from threading import Lock
+ from typing import Dict, List, Optional
+
+ import torch
+
+ from inference_models.errors import EnvironmentConfigurationError
+ from inference_models.models.owlv2.entities import (
+     ImageEmbeddings,
+     LazyReferenceExample,
+     ReferenceExamplesClassEmbeddings,
+ )
+
+
+ class OwlV2ClassEmbeddingsCache(ABC):
+
+     @abstractmethod
+     def retrieve_embeddings(
+         self, key: str
+     ) -> Optional[Dict[str, ReferenceExamplesClassEmbeddings]]:
+         pass
+
+     @abstractmethod
+     def save_embeddings(
+         self, key: str, embeddings: Dict[str, ReferenceExamplesClassEmbeddings]
+     ) -> None:
+         pass
+
+
+ class OwlV2ClassEmbeddingsCacheNullObject(OwlV2ClassEmbeddingsCache):
+
+     def retrieve_embeddings(
+         self, key: str
+     ) -> Optional[Dict[str, ReferenceExamplesClassEmbeddings]]:
+         return None
+
+     def save_embeddings(
+         self, key: str, embeddings: Dict[str, ReferenceExamplesClassEmbeddings]
+     ) -> None:
+         pass
+
+
+ class InMemoryOwlV2ClassEmbeddingsCache(OwlV2ClassEmbeddingsCache):
+
+     @classmethod
+     def init(
+         cls, size_limit: Optional[int], send_to_cpu: bool = True
+     ) -> "InMemoryOwlV2ClassEmbeddingsCache":
+         return cls(
+             state=OrderedDict(),
+             size_limit=size_limit,
+             send_to_cpu=send_to_cpu,
+         )
+
+     def __init__(
+         self,
+         state: OrderedDict,
+         size_limit: Optional[int],
+         send_to_cpu: bool = True,
+     ):
+         self._state = state
+         self._size_limit = size_limit
+         self._send_to_cpu = send_to_cpu
+         self._state_lock = Lock()
+
+     def retrieve_embeddings(
+         self, key: str
+     ) -> Optional[Dict[str, ReferenceExamplesClassEmbeddings]]:
+         return self._state.get(key)
+
+     def save_embeddings(
+         self, key: str, embeddings: Dict[str, ReferenceExamplesClassEmbeddings]
+     ) -> None:
+         with self._state_lock:
+             if key in self._state:
+                 return None
+             self._ensure_cache_has_capacity()
+             if self._send_to_cpu:
+                 embeddings = {
+                     k: v.to(device=torch.device("cpu")) for k, v in embeddings.items()
+                 }
+             self._state[key] = embeddings
+
+     def _ensure_cache_has_capacity(self) -> None:
+         if self._size_limit is None:
+             # No limit configured - the cache may grow without bound.
+             return None
+         if self._size_limit < 1:
+             raise EnvironmentConfigurationError(
+                 message="In-memory cache size for OWLv2 embeddings was set to an invalid value. "
+                 "If you are running inference locally, adjust the settings of your deployment. If you "
+                 "see this error on the Roboflow platform, contact us to get help.",
+                 help_url="https://todo",
+             )
+         while len(self._state) > self._size_limit:
+             _ = self._state.popitem(last=False)
+
+
+ class OwlV2ImageEmbeddingsCache(ABC):
+
+     @abstractmethod
+     def retrieve_embeddings(self, key: str) -> Optional[ImageEmbeddings]:
+         pass
+
+     @abstractmethod
+     def save_embeddings(self, embeddings: ImageEmbeddings) -> None:
+         pass
+
+
+ class OwlV2ImageEmbeddingsCacheNullObject(OwlV2ImageEmbeddingsCache):
+
+     def retrieve_embeddings(self, key: str) -> Optional[ImageEmbeddings]:
+         return None
+
+     def save_embeddings(self, embeddings: ImageEmbeddings) -> None:
+         pass
+
+
+ class InMemoryOwlV2ImageEmbeddingsCache(OwlV2ImageEmbeddingsCache):
+
+     @classmethod
+     def init(
+         cls, size_limit: Optional[int], send_to_cpu: bool = True
+     ) -> "InMemoryOwlV2ImageEmbeddingsCache":
+         return cls(
+             state=OrderedDict(),
+             size_limit=size_limit,
+             send_to_cpu=send_to_cpu,
+         )
+
+     def __init__(
+         self,
+         state: OrderedDict,
+         size_limit: Optional[int],
+         send_to_cpu: bool = True,
+     ):
+         self._state = state
+         self._size_limit = size_limit
+         self._send_to_cpu = send_to_cpu
+         self._state_lock = Lock()
+
+     def retrieve_embeddings(self, key: str) -> Optional[ImageEmbeddings]:
+         return self._state.get(key)
+
+     def save_embeddings(self, embeddings: ImageEmbeddings) -> None:
+         with self._state_lock:
+             if embeddings.image_hash in self._state:
+                 return None
+             self._ensure_cache_has_capacity()
+             if self._send_to_cpu:
+                 embeddings = embeddings.to(device=torch.device("cpu"))
+             self._state[embeddings.image_hash] = embeddings
+
+     def _ensure_cache_has_capacity(self) -> None:
+         if self._size_limit is None:
+             # No limit configured - the cache may grow without bound.
+             return None
+         if self._size_limit < 1:
+             raise EnvironmentConfigurationError(
+                 message="In-memory cache size for OWLv2 embeddings was set to an invalid value. "
+                 "If you are running inference locally, adjust the settings of your deployment. If you "
+                 "see this error on the Roboflow platform, contact us to get help.",
+                 help_url="https://todo",
+             )
+         while len(self._state) > self._size_limit:
+             _ = self._state.popitem(last=False)
+
+
+ def hash_reference_examples(reference_examples: List[LazyReferenceExample]) -> str:
+     result = hashlib.sha1()
+     for example in reference_examples:
+         image_hash = example.image.get_hash()
+         result.update(image_hash.encode())
+         bboxes_hash_base = "---".join(
+             [
+                 json.dumps(box.model_dump(), sort_keys=True, separators=(",", ":"))
+                 for box in example.boxes
+             ]
+         )
+         result.update(bboxes_hash_base.encode())
+     return result.hexdigest()
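
Editor's note: a small behavioral sketch of the in-memory caches above. Eviction is FIFO (OrderedDict.popitem(last=False)), and because capacity is checked before each insert, the cache can briefly hold one entry more than size_limit. The embedding values below are minimal stand-ins for what the OWLv2 model would produce; the 512-wide tensor is an assumption, not the model's actual embedding size:

import torch

from inference_models.models.owlv2.cache import InMemoryOwlV2ClassEmbeddingsCache
from inference_models.models.owlv2.entities import ReferenceExamplesClassEmbeddings


def fake_embeddings() -> dict:
    # Stand-in payload shaped like the cached value: class name -> embeddings.
    return {
        "cat": ReferenceExamplesClassEmbeddings(
            positive=torch.zeros(1, 512), negative=None
        )
    }


cache = InMemoryOwlV2ClassEmbeddingsCache.init(size_limit=2, send_to_cpu=True)
for key in ("dataset-a", "dataset-b", "dataset-c", "dataset-d"):
    cache.save_embeddings(key, fake_embeddings())

# Saving "dataset-d" found 3 entries (> size_limit) and evicted the oldest.
assert cache.retrieve_embeddings("dataset-a") is None
assert cache.retrieve_embeddings("dataset-d") is not None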
@@ -0,0 +1,112 @@
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from pydantic import BaseModel, ConfigDict, Field
+
+ from inference_models.models.owlv2.reference_dataset import LazyImageWrapper
+
+ POSITIVE_EXAMPLE = "positive"
+ NEGATIVE_EXAMPLE = "negative"
+
+
+ class ReferenceBoundingBox(BaseModel):
+     x: Union[float, int]
+     y: Union[float, int]
+     w: Union[float, int]
+     h: Union[float, int]
+     cls: str
+     negative: bool = Field(default=False)
+     absolute: bool = Field(default=True)
+
+     def to_tuple(
+         self, image_wh: Optional[Tuple[int, int]] = None
+     ) -> Tuple[
+         Union[int, float], Union[int, float], Union[int, float], Union[int, float]
+     ]:
+         if image_wh is None or self.absolute is False:
+             return self.x, self.y, self.w, self.h
+         max_dim = max(image_wh)
+         return (
+             self.x / max_dim,
+             self.y / max_dim,
+             self.w / max_dim,
+             self.h / max_dim,
+         )
+
+
+ class ReferenceExample(BaseModel):
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+
+     image: Union[np.ndarray, torch.Tensor, str, bytes]
+     boxes: List[ReferenceBoundingBox]
+
+
+ @dataclass(frozen=True)
+ class LazyReferenceExample:
+     image: LazyImageWrapper
+     boxes: List[ReferenceBoundingBox]
+
+
+ @dataclass(frozen=True)
+ class ImageEmbeddings:
+     image_hash: str
+     objectness: torch.Tensor
+     boxes: torch.Tensor
+     image_class_embeddings: torch.Tensor
+     logit_shift: torch.Tensor
+     logit_scale: torch.Tensor
+     image_size_wh: Tuple[int, int]
+
+     def to(self, device: torch.device) -> "ImageEmbeddings":
+         return ImageEmbeddings(
+             image_hash=self.image_hash,
+             objectness=self.objectness.to(device=device),
+             boxes=self.boxes.to(device=device),
+             image_class_embeddings=self.image_class_embeddings.to(device=device),
+             logit_shift=self.logit_shift.to(device=device),
+             logit_scale=self.logit_scale.to(device=device),
+             image_size_wh=self.image_size_wh,
+         )
+
+
+ @dataclass(frozen=True)
+ class ReferenceExamplesClassEmbeddings:
+     positive: Optional[torch.Tensor]
+     negative: Optional[torch.Tensor]
+
+     def to(self, device: torch.device) -> "ReferenceExamplesClassEmbeddings":
+         return ReferenceExamplesClassEmbeddings(
+             positive=(
+                 self.positive.to(device=device) if self.positive is not None else None
+             ),
+             negative=(
+                 self.negative.to(device=device) if self.negative is not None else None
+             ),
+         )
+
+
+ @dataclass(frozen=True)
+ class ReferenceExamplesEmbeddings:
+     class_embeddings: Dict[str, ReferenceExamplesClassEmbeddings]
+     image_embeddings: Optional[Dict[str, ImageEmbeddings]]
+
+     @classmethod
+     def from_class_embeddings_dict(
+         cls,
+         class_embeddings: Dict[str, Dict[str, torch.Tensor]],
+         device: torch.device,
+     ) -> "ReferenceExamplesEmbeddings":
+         result = {}
+         for class_name, examples_class_embeddings in class_embeddings.items():
+             positive, negative = None, None
+             if POSITIVE_EXAMPLE in examples_class_embeddings:
+                 positive = examples_class_embeddings[POSITIVE_EXAMPLE]
+             if NEGATIVE_EXAMPLE in examples_class_embeddings:
+                 negative = examples_class_embeddings[NEGATIVE_EXAMPLE]
+             single_class_embeddings = ReferenceExamplesClassEmbeddings(
+                 positive=positive, negative=negative
+             ).to(device=device)
+             result[class_name] = single_class_embeddings
+         return cls(class_embeddings=result, image_embeddings=None)
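
Editor's note: to make the coordinate convention above concrete, to_tuple leaves relative boxes (absolute=False) untouched and divides absolute pixel boxes by the longer image side, presumably to match OWLv2's square-padded preprocessing. A quick sketch with hand-checked numbers:

from inference_models.models.owlv2.entities import ReferenceBoundingBox

# Absolute pixel box on a 640x480 image; max(640, 480) = 640 is the divisor.
box = ReferenceBoundingBox(x=100, y=50, w=200, h=80, cls="cat")
print(box.to_tuple(image_wh=(640, 480)))  # (0.15625, 0.078125, 0.3125, 0.125)

# An already-relative box passes through unchanged.
relative = ReferenceBoundingBox(x=0.1, y=0.2, w=0.3, h=0.4, cls="cat", absolute=False)
print(relative.to_tuple(image_wh=(640, 480)))  # (0.1, 0.2, 0.3, 0.4)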