inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/owlv2/owlv2_hf.py
@@ -0,0 +1,695 @@
+from collections import defaultdict
+from typing import Dict, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torchvision
+from transformers import Owlv2ForObjectDetection, Owlv2Processor
+from transformers.models.owlv2.modeling_owlv2 import Owlv2ObjectDetectionOutput, box_iou
+
+from inference_models import Detections, OpenVocabularyObjectDetectionModel
+from inference_models.configuration import (
+    ALLOW_LOCAL_STORAGE_ACCESS_FOR_REFERENCE_DATA,
+    ALLOW_NON_HTTPS_URL_INPUT,
+    ALLOW_URL_INPUT,
+    ALLOW_URL_INPUT_WITHOUT_FQDN,
+    BLACKLISTED_DESTINATIONS_FOR_URL_INPUT,
+    DEFAULT_DEVICE,
+    WHITELISTED_DESTINATIONS_FOR_URL_INPUT,
+)
+from inference_models.entities import ImageDimensions
+from inference_models.errors import ModelInputError
+from inference_models.models.base.types import PreprocessedInputs, PreprocessingMetadata
+from inference_models.models.common.roboflow.pre_processing import (
+    extract_input_images_dimensions,
+)
+from inference_models.models.owlv2.cache import (
+    OwlV2ClassEmbeddingsCache,
+    OwlV2ClassEmbeddingsCacheNullObject,
+    OwlV2ImageEmbeddingsCache,
+    OwlV2ImageEmbeddingsCacheNullObject,
+    hash_reference_examples,
+)
+from inference_models.models.owlv2.entities import (
+    NEGATIVE_EXAMPLE,
+    POSITIVE_EXAMPLE,
+    ImageEmbeddings,
+    LazyReferenceExample,
+    ReferenceExample,
+    ReferenceExamplesClassEmbeddings,
+    ReferenceExamplesEmbeddings,
+)
+from inference_models.models.owlv2.reference_dataset import (
+    LazyImageWrapper,
+    compute_image_hash,
+)
+
+Query = Dict[
+    str,
+    Tuple[Union[int, float], Union[int, float], Union[int, float], Union[int, float]],
+]
+
+
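+# Hugging Face implementation of OWLv2 open-vocabulary detection. Besides the
+# standard text-prompted path, it supports few-shot detection driven by visual
+# reference examples, with optional caching of image and class embeddings.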
+class OWLv2HF(
+    OpenVocabularyObjectDetectionModel[
+        torch.Tensor, List[ImageDimensions], Owlv2ObjectDetectionOutput
+    ]
+):
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name_or_path: str,
+        device: torch.device = DEFAULT_DEVICE,
+        local_files_only: bool = True,
+        owlv2_class_embeddings_cache: Optional[OwlV2ClassEmbeddingsCache] = None,
+        owlv2_images_embeddings_cache: Optional[OwlV2ImageEmbeddingsCache] = None,
+        allow_url_input: bool = ALLOW_URL_INPUT,
+        allow_non_https_url: bool = ALLOW_NON_HTTPS_URL_INPUT,
+        allow_url_without_fqdn: bool = ALLOW_URL_INPUT_WITHOUT_FQDN,
+        whitelisted_domains: Optional[List[str]] = None,
+        blacklisted_domains: Optional[List[str]] = None,
+        allow_local_storage_access_for_reference_images: bool = ALLOW_LOCAL_STORAGE_ACCESS_FOR_REFERENCE_DATA,
+        owlv2_enforce_model_compilation: bool = False,
+        **kwargs,
+    ) -> "OpenVocabularyObjectDetectionModel":
+        if owlv2_class_embeddings_cache is None:
+            owlv2_class_embeddings_cache = OwlV2ClassEmbeddingsCacheNullObject()
+        if owlv2_images_embeddings_cache is None:
+            owlv2_images_embeddings_cache = OwlV2ImageEmbeddingsCacheNullObject()
+        if whitelisted_domains is None:
+            whitelisted_domains = WHITELISTED_DESTINATIONS_FOR_URL_INPUT
+        if blacklisted_domains is None:
+            blacklisted_domains = BLACKLISTED_DESTINATIONS_FOR_URL_INPUT
+        processor = Owlv2Processor.from_pretrained(
+            model_name_or_path,
+            local_files_only=local_files_only,
+            use_fast=True,
+        )
+        model = Owlv2ForObjectDetection.from_pretrained(
+            model_name_or_path,
+            local_files_only=local_files_only,
+        ).to(device)
+        instance = cls(
+            model=model,
+            processor=processor,
+            device=device,
+            owlv2_class_embeddings_cache=owlv2_class_embeddings_cache,
+            owlv2_images_embeddings_cache=owlv2_images_embeddings_cache,
+            allow_url_input=allow_url_input,
+            allow_non_https_url=allow_non_https_url,
+            allow_url_without_fqdn=allow_url_without_fqdn,
+            whitelisted_domains=whitelisted_domains,
+            blacklisted_domains=blacklisted_domains,
+            allow_local_storage_access_for_reference_images=allow_local_storage_access_for_reference_images,
+        )
+        if owlv2_enforce_model_compilation:
+            instance.optimize_for_inference()
+        return instance
+
+    def __init__(
+        self,
+        model: Owlv2ForObjectDetection,
+        processor: Owlv2Processor,
+        device: torch.device,
+        owlv2_class_embeddings_cache: OwlV2ClassEmbeddingsCache,
+        owlv2_images_embeddings_cache: OwlV2ImageEmbeddingsCache,
+        allow_url_input: bool,
+        allow_non_https_url: bool,
+        allow_url_without_fqdn: bool,
+        whitelisted_domains: Optional[List[str]],
+        blacklisted_domains: Optional[List[str]],
+        allow_local_storage_access_for_reference_images: bool,
+    ):
+        self._model = model
+        self._processor = processor
+        self._device = device
+        self._owlv2_class_embeddings_cache = owlv2_class_embeddings_cache
+        self._owlv2_images_embeddings_cache = owlv2_images_embeddings_cache
+        self._allow_url_input = allow_url_input
+        self._allow_non_https_url = allow_non_https_url
+        self._allow_url_without_fqdn = allow_url_without_fqdn
+        self._whitelisted_domains = whitelisted_domains
+        self._blacklisted_domains = blacklisted_domains
+        self._allow_local_storage_access_for_reference_images = (
+            allow_local_storage_access_for_reference_images
+        )
+        self._compiled = False
+
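+    # Warm-up hook: compiles the vision tower with torch.compile, then runs one
+    # dummy inference so the compilation cost is paid before real requests.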
+    def optimize_for_inference(self) -> None:
+        if self._compiled:
+            return None
+        self._model.owlv2.vision_model = torch.compile(self._model.owlv2.vision_model)
+        example_image = torch.randint(
+            low=0, high=255, size=(3, 128, 128), dtype=torch.uint8
+        ).to(self._device)
+        _ = self.infer(example_image, ["some", "other"])
+        self._compiled = True
+
+    def pre_process(
+        self,
+        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+        **kwargs,
+    ) -> Tuple[PreprocessedInputs, PreprocessingMetadata]:
+        image_dimensions = extract_input_images_dimensions(images=images)
+        inputs = self._processor(images=images, return_tensors="pt")
+        return inputs["pixel_values"].to(self._device), image_dimensions
+
+    def forward(
+        self,
+        pre_processed_images: torch.Tensor,
+        classes: List[str],
+        **kwargs,
+    ) -> Owlv2ObjectDetectionOutput:
+        input_ids = self._processor(text=[classes], return_tensors="pt")[
+            "input_ids"
+        ].to(self._device)
+        with torch.inference_mode():
+            return self._model(input_ids=input_ids, pixel_values=pre_processed_images)
+
+    def post_process(
+        self,
+        model_results: Owlv2ObjectDetectionOutput,
+        pre_processing_meta: List[ImageDimensions],
+        conf_thresh: float = 0.1,
+        iou_thresh: float = 0.45,
+        class_agnostic: bool = False,
+        max_detections: int = 100,
+        **kwargs,
+    ) -> List[Detections]:
+        target_sizes = [(dim.height, dim.width) for dim in pre_processing_meta]
+        post_processed_outputs = self._processor.post_process_grounded_object_detection(
+            outputs=model_results,
+            target_sizes=target_sizes,
+            threshold=conf_thresh,
+        )
+        results = []
+        for i in range(len(post_processed_outputs)):
+            boxes, scores, labels = (
+                post_processed_outputs[i]["boxes"],
+                post_processed_outputs[i]["scores"],
+                post_processed_outputs[i]["labels"],
+            )
+            nms_class_ids = torch.zeros_like(labels) if class_agnostic else labels
+            keep = torchvision.ops.batched_nms(boxes, scores, nms_class_ids, iou_thresh)
+            keep = keep[:max_detections]
+            results.append(
+                Detections(
+                    xyxy=boxes[keep].contiguous().int(),
+                    class_id=labels[keep].contiguous().int(),
+                    confidence=scores[keep].contiguous(),
+                )
+            )
+        return results
+
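+    # Few-shot path: instead of text prompts, classes are defined by reference
+    # images with positive/negative example boxes, which are turned into class
+    # embeddings and matched against each image's proposal embeddings.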
+    def infer_with_reference_examples(
+        self,
+        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+        reference_examples: List[ReferenceExample],
+        confidence_threshold: float = 0.99,
+        iou_threshold: float = 0.3,
+        max_detections: int = 300,
+    ) -> List[Detections]:
+        reference_embeddings = self.prepare_reference_examples_embeddings(
+            reference_examples=reference_examples,
+            iou_threshold=iou_threshold,
+        )
+        return self.infer_with_reference_examples_embeddings(
+            images=images,
+            class_embeddings=reference_embeddings.class_embeddings,
+            confidence_threshold=confidence_threshold,
+            iou_threshold=iou_threshold,
+            max_detections=max_detections,
+        )
+
+    def infer_with_reference_examples_embeddings(
+        self,
+        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+        class_embeddings: Dict[str, ReferenceExamplesClassEmbeddings],
+        confidence_threshold: float = 0.99,
+        iou_threshold: float = 0.3,
+        max_detections: int = 300,
+    ) -> List[Detections]:
+        images_embeddings, images_dimensions = self.embed_images(
+            images=images, max_detections=max_detections
+        )
+        images_predictions = self.forward_pass_with_precomputed_embeddings(
+            images_embeddings=images_embeddings,
+            class_embeddings=class_embeddings,
+            confidence_threshold=confidence_threshold,
+            iou_threshold=iou_threshold,
+        )
+        return self.post_process_predictions_for_precomputed_embeddings(
+            predictions=images_predictions,
+            images_dimensions=images_dimensions,
+            max_detections=max_detections,
+            iou_threshold=iou_threshold,
+        )
+
+    def forward_pass_with_precomputed_embeddings(
+        self,
+        images_embeddings: List[ImageEmbeddings],
+        class_embeddings: Dict[str, ReferenceExamplesClassEmbeddings],
+        confidence_threshold: float = 0.99,
+        iou_threshold: float = 0.3,
+    ) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        results = []
+        for image_embedding in images_embeddings:
+            image_embedding = image_embedding.to(self._device)
+            class_mapping, class_names = make_class_mapping(
+                class_names=class_embeddings.keys()
+            )
+            all_predicted_boxes, all_predicted_classes, all_predicted_scores = (
+                [],
+                [],
+                [],
+            )
+            for (
+                class_name,
+                reference_examples_class_embeddings,
+            ) in class_embeddings.items():
+                boxes, classes, scores = get_class_predictions_from_embeddings(
+                    reference_examples_class_embeddings=reference_examples_class_embeddings,
+                    image_class_embeddings=image_embedding.image_class_embeddings,
+                    image_boxes=image_embedding.boxes,
+                    confidence_threshold=confidence_threshold,
+                    class_mapping=class_mapping,
+                    class_name=class_name,
+                    iou_threshold=iou_threshold,
+                )
+                all_predicted_boxes.append(boxes)
+                all_predicted_classes.append(classes)
+                all_predicted_scores.append(scores)
+            if not all_predicted_boxes:
+                results.append(
+                    (torch.empty((0,)), torch.empty((0,)), torch.empty((0,)))
+                )
+                continue
+            all_predicted_boxes = torch.cat(all_predicted_boxes, dim=0)
+            all_predicted_classes = torch.cat(all_predicted_classes, dim=0)
+            all_predicted_scores = torch.cat(all_predicted_scores, dim=0)
+            results.append(
+                (all_predicted_boxes, all_predicted_classes, all_predicted_scores)
+            )
+        return results
+
+    def post_process_predictions_for_precomputed_embeddings(
+        self,
+        predictions: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
+        images_dimensions: List[ImageDimensions],
+        max_detections: int = 300,
+        iou_threshold: float = 0.3,
+    ) -> List[Detections]:
+        results = []
+        for image_predictions, image_dimensions in zip(predictions, images_dimensions):
+            all_predicted_boxes, all_predicted_classes, all_predicted_scores = (
+                image_predictions
+            )
+            if all_predicted_boxes.numel() == 0:
+                results.append(
+                    Detections(
+                        xyxy=torch.empty(
+                            (0, 4), dtype=torch.int32, device=self._device
+                        ),
+                        confidence=torch.empty(
+                            (0,), dtype=torch.float32, device=self._device
+                        ),
+                        class_id=torch.empty(
+                            (0,), dtype=torch.int32, device=self._device
+                        ),
+                    )
+                )
+                continue
+            survival_indices = torchvision.ops.nms(
+                to_corners(all_predicted_boxes), all_predicted_scores, iou_threshold
+            )
+            all_predicted_boxes = all_predicted_boxes[survival_indices]
+            all_predicted_classes = all_predicted_classes[survival_indices]
+            all_predicted_scores = all_predicted_scores[survival_indices]
+            if len(all_predicted_boxes) > max_detections:
+                all_predicted_boxes = all_predicted_boxes[:max_detections]
+                all_predicted_classes = all_predicted_classes[:max_detections]
+                all_predicted_scores = all_predicted_scores[:max_detections]
+            xyxy = xywh_normalized_to_xyxy(
+                boxes_xywh=all_predicted_boxes,
+                image_size_wh=(image_dimensions.width, image_dimensions.height),
+            )
+            results.append(
+                Detections(
+                    xyxy=xyxy.int(),
+                    confidence=all_predicted_scores,
+                    class_id=all_predicted_classes.int(),
+                )
+            )
+        return results
+
+    def prepare_reference_examples_embeddings(
+        self,
+        reference_examples: List[ReferenceExample],
+        iou_threshold: float,
+        return_image_embeddings: bool = False,
+    ) -> ReferenceExamplesEmbeddings:
+        lazy_reference_examples = [
+            LazyReferenceExample(
+                image=LazyImageWrapper.init(
+                    image=example.image,
+                    allow_url_input=self._allow_url_input,
+                    allow_non_https_url=self._allow_non_https_url,
+                    allow_url_without_fqdn=self._allow_url_without_fqdn,
+                    whitelisted_domains=self._whitelisted_domains,
+                    blacklisted_domains=self._blacklisted_domains,
+                    allow_local_storage_access=self._allow_local_storage_access_for_reference_images,
+                ),
+                boxes=example.boxes,
+            )
+            for example in reference_examples
+        ]
+        examples_hash_key = hash_reference_examples(
+            reference_examples=lazy_reference_examples
+        )
+        cached_embeddings = self._owlv2_class_embeddings_cache.retrieve_embeddings(
+            key=examples_hash_key
+        )
+        if cached_embeddings is not None and not return_image_embeddings:
+            cached_embeddings = {
+                k: v.to(self._device) for k, v in cached_embeddings.items()
+            }
+            return ReferenceExamplesEmbeddings(
+                class_embeddings=cached_embeddings,
+                image_embeddings=None,
+            )
+        class_embeddings_dict = defaultdict(
+            lambda: {POSITIVE_EXAMPLE: [], NEGATIVE_EXAMPLE: []}
+        )
+        bool_to_literal = {True: POSITIVE_EXAMPLE, False: NEGATIVE_EXAMPLE}
+        image_embeddings_to_be_returned = {}
+        for reference_example in lazy_reference_examples:
+            image_embeddings = self.embed_image(image=reference_example.image)
+            if return_image_embeddings:
+                image_embeddings_to_be_returned[image_embeddings.image_hash] = (
+                    image_embeddings
+                )
+            coordinates = [
+                bbox.to_tuple(image_wh=image_embeddings.image_size_wh)
+                for bbox in reference_example.boxes
+            ]
+            classes = [box.cls for box in reference_example.boxes]
+            is_positive = [not box.negative for box in reference_example.boxes]
+            query = {image_embeddings.image_hash: coordinates}
+            image_class_embeddings_matching_query = self.query_images_for_bboxes(
+                query=query,
+                images_embeddings={image_embeddings.image_hash: image_embeddings},
+                iou_threshold=iou_threshold,
+            )
+            if image_class_embeddings_matching_query is None:
+                continue
+            for embedding, class_name, is_pos in zip(
+                image_class_embeddings_matching_query, classes, is_positive
+            ):
+                class_embeddings_dict[class_name][bool_to_literal[is_pos]].append(
+                    embedding
+                )
+        class_embeddings = {
+            class_name: ReferenceExamplesClassEmbeddings(
+                positive=(
+                    torch.stack(embeddings[POSITIVE_EXAMPLE])
+                    if embeddings[POSITIVE_EXAMPLE]
+                    else None
+                ),
+                negative=(
+                    torch.stack(embeddings[NEGATIVE_EXAMPLE])
+                    if embeddings[NEGATIVE_EXAMPLE]
+                    else None
+                ),
+            )
+            for class_name, embeddings in class_embeddings_dict.items()
+        }
+        self._owlv2_class_embeddings_cache.save_embeddings(
+            key=examples_hash_key, embeddings=class_embeddings
+        )
+        return ReferenceExamplesEmbeddings(
+            class_embeddings=class_embeddings,
+            image_embeddings=(
+                image_embeddings_to_be_returned if return_image_embeddings else None
+            ),
+        )
+
+    @torch.inference_mode()
+    def query_images_for_bboxes(
+        self,
+        query: Query,
+        images_embeddings: Dict[str, ImageEmbeddings],
+        iou_threshold: float,
+    ) -> Optional[torch.Tensor]:
+        query_embeddings = []
+        for image_hash, query_boxes in query.items():
+            image_embeddings = images_embeddings.get(image_hash)
+            if image_embeddings is None:
+                raise ModelInputError(
+                    message="Could not find image embeddings matching the bounding-box query for the "
+                    "OWLv2 model. This most likely means the model API was used incorrectly.",
+                    help_url="https://todo",
+                )
+            image_embeddings = image_embeddings.to(self._device)
+            query_boxes_tensor = torch.tensor(
+                query_boxes,
+                dtype=image_embeddings.boxes.dtype,
+                device=self._device,
+            )
+            if image_embeddings.boxes.numel() == 0 or query_boxes_tensor.numel() == 0:
+                continue
+            iou, _ = box_iou(
+                boxes1=to_corners(image_embeddings.boxes),
+                boxes2=to_corners(query_boxes_tensor),
+            )  # shape: (num_proposals, num_query_boxes)
+            ious, indices = torch.max(iou, dim=0)
+            # keep only query boxes with a sufficiently overlapping proposal
+            iou_mask = ious > iou_threshold
+            indices = indices[iou_mask]
+            if not indices.numel() > 0:
+                continue
+            matching_image_embeddings = image_embeddings.image_class_embeddings[indices]
+            query_embeddings.append(matching_image_embeddings)
+        if not query_embeddings:
+            return None
+        return torch.cat(query_embeddings, dim=0)
+
+    def embed_images(
+        self,
+        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+        max_detections: int = 300,
+    ) -> Tuple[List[ImageEmbeddings], List[ImageDimensions]]:
+        if isinstance(images, torch.Tensor):
+            if len(images.shape) == 3:
+                images = [images]
+            else:
+                images = torch.unbind(images, dim=0)
+        elif not isinstance(images, list):
+            images = [images]
+        results = []
+        image_dimensions = []
+        for image in images:
+            image_embedding = self.embed_image(
+                image=image, max_detections=max_detections
+            )
+            results.append(image_embedding)
+            image_dimensions.append(
+                ImageDimensions(
+                    height=image_embedding.image_size_wh[1],
+                    width=image_embedding.image_size_wh[0],
+                )
+            )
+        return results, image_dimensions
+
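+    # Embeds a single image, reusing cached embeddings keyed by image hash and
+    # keeping only the top `max_detections` proposals by objectness score.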
+    @torch.inference_mode()
+    def embed_image(
+        self,
+        image: Union[torch.Tensor, np.ndarray, LazyImageWrapper],
+        max_detections: int = 300,
+        unload_after_use: bool = True,
+    ) -> ImageEmbeddings:
+        if isinstance(image, LazyImageWrapper):
+            image_hash = image.get_hash()
+            image_instance = image.as_numpy()
+            if unload_after_use:
+                image.unload_image()
+        else:
+            image_hash = compute_image_hash(image=image)
+            image_instance = image
+        cached_embeddings = self._owlv2_images_embeddings_cache.retrieve_embeddings(
+            key=image_hash
+        )
+        if cached_embeddings:
+            return cached_embeddings
+        pixel_values, image_dimensions = self.pre_process(image_instance)
+        device_type = self._device.type
+        with torch.autocast(
+            device_type=device_type, dtype=torch.float16, enabled=device_type == "cuda"
+        ):
+            image_embeds, *_ = self._model.image_embedder(pixel_values=pixel_values)
+            batch_size, h, w, dim = image_embeds.shape
+            image_features = image_embeds.reshape(batch_size, h * w, dim)
+            objectness = self._model.objectness_predictor(image_features)
+            boxes = self._model.box_predictor(image_features, feature_map=image_embeds)
+            image_class_embeddings = self._model.class_head.dense0(image_features)
+            image_class_embeddings /= (
+                torch.linalg.norm(image_class_embeddings, ord=2, dim=-1, keepdim=True)
+                + 1e-6
+            )
+            logit_shift = self._model.class_head.logit_shift(image_features)
+            logit_scale = (
+                self._model.class_head.elu(
+                    self._model.class_head.logit_scale(image_features)
+                )
+                + 1
+            )
+            objectness = objectness.sigmoid()
+        objectness, boxes, image_class_embeddings, logit_shift, logit_scale = (
+            filter_tensors_by_objectness(
+                objectness,
+                boxes,
+                image_class_embeddings,
+                logit_shift,
+                logit_scale,
+                max_detections,
+            )
+        )
+        embeddings = ImageEmbeddings(
+            image_hash=image_hash,
+            objectness=objectness,
+            boxes=boxes,
+            image_class_embeddings=image_class_embeddings,
+            logit_shift=logit_shift,
+            logit_scale=logit_scale,
+            image_size_wh=(image_dimensions[0].width, image_dimensions[0].height),
+        )
+        self._owlv2_images_embeddings_cache.save_embeddings(embeddings=embeddings)
+        return embeddings
+
+
+def to_corners(box: torch.Tensor) -> torch.Tensor:
+    cx, cy, w, h = box.unbind(-1)
+    x1 = cx - w / 2
+    y1 = cy - h / 2
+    x2 = cx + w / 2
+    y2 = cy + h / 2
+    return torch.stack([x1, y1, x2, y2], dim=-1)
+
+
+def make_class_mapping(
+    class_names: Iterable[str],
+) -> Tuple[Dict[Tuple[str, str], int], List[str]]:
+    class_names = sorted(class_names)
+    class_map_positive = {
+        (class_name, POSITIVE_EXAMPLE): i for i, class_name in enumerate(class_names)
+    }
+    class_map_negative = {
+        (class_name, NEGATIVE_EXAMPLE): i + len(class_names)
+        for i, class_name in enumerate(class_names)
+    }
+    class_map = {**class_map_positive, **class_map_negative}
+    return class_map, class_names
+
+
+def filter_tensors_by_objectness(
+    objectness: torch.Tensor,
+    boxes: torch.Tensor,
+    image_class_embeds: torch.Tensor,
+    logit_shift: torch.Tensor,
+    logit_scale: torch.Tensor,
+    max_detections: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    objectness = objectness.squeeze(0)
+    objectness, objectness_indices = torch.topk(objectness, max_detections, dim=0)
+    boxes = boxes.squeeze(0)
+    image_class_embeds = image_class_embeds.squeeze(0)
+    logit_shift = logit_shift.squeeze(0).squeeze(1)
+    logit_scale = logit_scale.squeeze(0).squeeze(1)
+    boxes = boxes[objectness_indices]
+    image_class_embeds = image_class_embeds[objectness_indices]
+    logit_shift = logit_shift[objectness_indices]
+    logit_scale = logit_scale[objectness_indices]
+    return objectness, boxes, image_class_embeds, logit_shift, logit_scale
+
+
+def get_class_predictions_from_embeddings(
+    reference_examples_class_embeddings: ReferenceExamplesClassEmbeddings,
+    image_class_embeddings: torch.Tensor,
+    image_boxes: torch.Tensor,
+    confidence_threshold: float,
+    class_mapping: Dict[Tuple[str, str], int],
+    class_name: str,
+    iou_threshold: float,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    predicted_boxes_per_class = []
+    predicted_class_indices_per_class = []
+    predicted_scores_per_class = []
+    positive_arr_per_class = []
+    if reference_examples_class_embeddings.positive is not None:
+        pred_logits = torch.einsum(
+            "sd,nd->ns",
+            image_class_embeddings,
+            reference_examples_class_embeddings.positive,
+        )
+        prediction_scores = pred_logits.max(dim=0)[0]
+        prediction_scores = (prediction_scores + 1) / 2
+        score_mask = prediction_scores > confidence_threshold
+        predicted_boxes_per_class.append(image_boxes[score_mask])
+        scores = prediction_scores[score_mask]
+        predicted_scores_per_class.append(scores)
+        class_ind = class_mapping[(class_name, POSITIVE_EXAMPLE)]
+        predicted_class_indices_per_class.append(class_ind * torch.ones_like(scores))
+        positive_arr_per_class.append(torch.ones_like(scores))
+    if reference_examples_class_embeddings.negative is not None:
+        pred_logits = torch.einsum(
+            "sd,nd->ns",
+            image_class_embeddings,
+            reference_examples_class_embeddings.negative,
+        )
+        prediction_scores = pred_logits.max(dim=0)[0]
+        prediction_scores = (prediction_scores + 1) / 2
+        score_mask = prediction_scores > confidence_threshold
+        predicted_boxes_per_class.append(image_boxes[score_mask])
+        scores = prediction_scores[score_mask]
+        predicted_scores_per_class.append(scores)
+        class_ind = class_mapping[(class_name, NEGATIVE_EXAMPLE)]
+        predicted_class_indices_per_class.append(class_ind * torch.ones_like(scores))
+        positive_arr_per_class.append(torch.zeros_like(scores))
+    if not predicted_boxes_per_class:
+        return (
+            torch.empty((0, 4)),
+            torch.empty((0,)),
+            torch.empty((0,)),
+        )
+    # concatenate per-branch tensors
+    pred_boxes = torch.cat(predicted_boxes_per_class, dim=0).float()
+    pred_classes = torch.cat(predicted_class_indices_per_class, dim=0).float()
+    pred_scores = torch.cat(predicted_scores_per_class, dim=0).float()
+    positive = torch.cat(positive_arr_per_class, dim=0).float()
+    # NMS over positive and negative candidates together, so negative examples
+    # can suppress overlapping positive detections
+    survival_indices = torchvision.ops.nms(
+        to_corners(pred_boxes), pred_scores, iou_threshold
+    )
+    # filter to post-NMS survivors
+    pred_boxes = pred_boxes[survival_indices, :]
+    pred_classes = pred_classes[survival_indices]
+    pred_scores = pred_scores[survival_indices]
+    positive = positive[survival_indices]
+    is_positive = positive == 1
+    # return only the positive-class detections
+    return pred_boxes[is_positive], pred_classes[is_positive], pred_scores[is_positive]
+
+
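+# OWLv2 preprocessing pads inputs to a square, so normalized box coordinates
+# are scaled by the longer image side when mapped back to pixel space.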
+def xywh_normalized_to_xyxy(
+    boxes_xywh: torch.Tensor, image_size_wh: Tuple[int, int]
+) -> torch.Tensor:
+    max_dim = max(image_size_wh)
+    x_center = boxes_xywh[..., 0] * max_dim
+    y_center = boxes_xywh[..., 1] * max_dim
+    box_width = boxes_xywh[..., 2] * max_dim
+    box_height = boxes_xywh[..., 3] * max_dim
+    x1 = x_center - box_width / 2
+    y1 = y_center - box_height / 2
+    x2 = x_center + box_width / 2
+    y2 = y_center + box_height / 2
+    return torch.stack([x1, y1, x2, y2], dim=-1).to(device=boxes_xywh.device)
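
For orientation, a minimal usage sketch of the module above. The checkpoint name and the infer() entry point (inherited from OpenVocabularyObjectDetectionModel) are assumptions for illustration, not part of this diff:

import numpy as np

from inference_models.models.owlv2.owlv2_hf import OWLv2HF

# Hypothetical checkpoint identifier; any OWLv2 checkpoint laid out in the
# Hugging Face format should work here.
model = OWLv2HF.from_pretrained(
    "google/owlv2-base-patch16-ensemble",
    local_files_only=False,
)

# Placeholder frame; real inputs may be numpy arrays or torch tensors.
image = np.zeros((480, 640, 3), dtype=np.uint8)

# Text-prompted detection runs pre_process -> forward -> post_process and
# returns one Detections object per input image.
detections = model.infer(image, ["cat", "dog"])

# Few-shot detection instead takes ReferenceExample objects whose boxes are
# marked as positive or negative examples of each class:
# detections = model.infer_with_reference_examples(image, reference_examples)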