inference-models 0.18.3 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,675 @@
+ import hashlib
+ from typing import Dict, Generator, List, Optional, Tuple, TypeVar, Union
+
+ import numpy as np
+ import torch
+ from segment_anything import sam_model_registry
+ from segment_anything.modeling import Sam
+ from segment_anything.utils.transforms import ResizeLongestSide
+
+ from inference_models import ColorFormat
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.errors import CorruptedModelPackageError, ModelInputError
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.sam.cache import (
+     SamImageEmbeddingsCache,
+     SamImageEmbeddingsCacheNullObject,
+     SamLowResolutionMasksCache,
+     SamLowResolutionMasksCacheNullObject,
+ )
+ from inference_models.models.sam.entities import SAMImageEmbeddings, SAMPrediction
+ from inference_models.utils.file_system import read_json
+
+ T = TypeVar("T")
+
+ MAX_SAM_BATCH_SIZE = 8
+
+ ArrayOrTensor = Union[np.ndarray, torch.Tensor]
+
+
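+ # SAMTorch wraps a Segment Anything (SAM) checkpoint loaded from a local model package
+ # ("model.pth" + "sam_configuration.json") and exposes image embedding and prompt-driven
+ # mask prediction, with optional caches for image embeddings and low-resolution masks.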
+ class SAMTorch:
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         max_batch_size: int = MAX_SAM_BATCH_SIZE,
+         sam_image_embeddings_cache: Optional[SamImageEmbeddingsCache] = None,
+         sam_low_resolution_masks_cache: Optional[SamLowResolutionMasksCache] = None,
+         **kwargs,
+     ) -> "SAMTorch":
+         if sam_image_embeddings_cache is None:
+             sam_image_embeddings_cache = SamImageEmbeddingsCacheNullObject()
+         if sam_low_resolution_masks_cache is None:
+             sam_low_resolution_masks_cache = SamLowResolutionMasksCacheNullObject()
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "model.pth",
+                 "sam_configuration.json",
+             ],
+         )
+         try:
+             version = decode_sam_version(
+                 config_path=model_package_content["sam_configuration.json"]
+             )
+         except Exception as error:
+             raise CorruptedModelPackageError(
+                 message="Could not decode SAM model version. If you see this error running inference locally, "
+                 "verify the contents of the model package. If you see the error running on the Roboflow "
+                 "platform, contact us to get help.",
+                 help_url="https://todo",
+             ) from error
+         try:
+             sam_model = sam_model_registry[version](
+                 checkpoint=model_package_content["model.pth"]
+             ).to(device)
+         except Exception as error:
+             raise CorruptedModelPackageError(
+                 message=f"Could not initialize SAM model - cause: {error}. If you see this error running "
+                 f"locally, verify the installation of inference and the contents of the model package. "
+                 f"If you use the Roboflow platform, contact us to get help.",
+                 help_url="https://todo",
+             ) from error
+         transform = ResizeLongestSide(sam_model.image_encoder.img_size)
+         return cls(
+             model=sam_model,
+             transform=transform,
+             device=device,
+             max_batch_size=max_batch_size,
+             sam_image_embeddings_cache=sam_image_embeddings_cache,
+             sam_low_resolution_masks_cache=sam_low_resolution_masks_cache,
+         )
+
+     def __init__(
+         self,
+         model: Sam,
+         transform: ResizeLongestSide,
+         device: torch.device,
+         max_batch_size: int,
+         sam_image_embeddings_cache: SamImageEmbeddingsCache,
+         sam_low_resolution_masks_cache: SamLowResolutionMasksCache,
+     ):
+         self._model = model
+         self._transform = transform
+         self._device = device
+         self._max_batch_size = max_batch_size
+         self._sam_image_embeddings_cache = sam_image_embeddings_cache
+         self._sam_low_resolution_masks_cache = sam_low_resolution_masks_cache
+
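+     # Computes SAM image embeddings for the provided images, reusing cached embeddings
+     # (keyed by a hash of the raw pixels) and running the image encoder only for cache misses.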
+     def embed_images(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         use_embeddings_cache: bool = True,
+         **kwargs,
+     ) -> List[SAMImageEmbeddings]:
+         model_input_images, image_hashes, original_image_sizes = (
+             self.pre_process_images(
+                 images=images,
+                 input_color_format=input_color_format,
+                 **kwargs,
+             )
+         )
+         embeddings_from_cache: Dict[int, SAMImageEmbeddings] = {}
+         images_to_compute = []
+         for idx, (image, image_hash) in enumerate(
+             zip(model_input_images, image_hashes)
+         ):
+             cache_content = None
+             if use_embeddings_cache:
+                 cache_content = self._sam_image_embeddings_cache.retrieve_embeddings(
+                     key=image_hash
+                 )
+             if cache_content is not None:
+                 cache_content = cache_content.to(device=self._device)
+                 embeddings_from_cache[idx] = cache_content
+             else:
+                 images_to_compute.append(image)
+         if len(images_to_compute) > 0:
+             images_to_compute = torch.stack(images_to_compute, dim=0)
+             computed_embeddings = self.forward_image_embeddings(
+                 model_input_images=images_to_compute,
+             )
+             computed_embeddings_idx = 0
+             result_embeddings = []
+             for i in range(len(model_input_images)):
+                 if i in embeddings_from_cache:
+                     result_embeddings.append(embeddings_from_cache[i].embeddings)
+                 else:
+                     result_embeddings.append(
+                         computed_embeddings[computed_embeddings_idx]
+                     )
+                     computed_embeddings_idx += 1
+         else:
+             result_embeddings = [
+                 embeddings_from_cache[i].embeddings
+                 for i in range(len(model_input_images))
+             ]
+         results = []
+         for image_hash, image_size, image_embeddings in zip(
+             image_hashes, original_image_sizes, result_embeddings
+         ):
+             result = SAMImageEmbeddings(
+                 image_hash=image_hash,
+                 image_size_hw=image_size,
+                 embeddings=image_embeddings,
+             )
+             results.append(result)
+             if use_embeddings_cache:
+                 self._sam_image_embeddings_cache.save_embeddings(
+                     key=image_hash, embeddings=result
+                 )
+         return results
+
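+     # Normalizes single or batched numpy / torch inputs into one NCHW tensor resized with
+     # ResizeLongestSide, returning it with per-image content hashes and original (H, W) sizes.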
+     def pre_process_images(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, List[str], List[Tuple[int, int]]]:
+         if isinstance(images, torch.Tensor):
+             images = images.to(device=self._device)
+             if images.device.type == "cuda":
+                 images = images.float()
+             if len(images.shape) == 4:
+                 image_hashes = [compute_image_hash(image=image) for image in images]
+                 if input_color_format == "bgr":
+                     images = images.flip(1).contiguous()
+                 original_image_sizes = [tuple(images.shape[2:4])] * images.shape[0]
+                 model_input_images = self._transform.apply_image_torch(image=images)
+             else:
+                 image_hashes = [compute_image_hash(image=images)]
+                 if input_color_format == "bgr":
+                     images = images.flip(0).contiguous()
+                 original_image_sizes = [tuple(images.shape[1:3])]
+                 model_input_images = self._transform.apply_image_torch(
+                     image=images.unsqueeze(dim=0)
+                 )
+         else:
+             if isinstance(images, list):
+                 image_hashes = [compute_image_hash(image=image) for image in images]
+                 original_image_sizes = []
+                 model_input_images = []
+                 for image in images:
+                     if isinstance(image, np.ndarray):
+                         original_image_sizes.append(image.shape[:2])
+                         if input_color_format in {None, "bgr"}:
+                             image = np.ascontiguousarray(image[:, :, ::-1])
+                         input_image = self._transform.apply_image(image=image)
+                         input_image = (
+                             torch.as_tensor(input_image, device=self._device)
+                             .permute(2, 0, 1)
+                             .contiguous()
+                         )
+                         model_input_images.append(input_image)
+                     else:
+                         original_image_sizes.append(tuple(image.shape[1:3]))
+                         image = image.to(self._device)
+                         if image.device.type == "cuda":
+                             image = image.float()
+                         if input_color_format == "bgr":
+                             image = image.flip(0).contiguous()
+                         input_image = self._transform.apply_image_torch(
+                             image=image.unsqueeze(dim=0)
+                         )[0]
+                         model_input_images.append(input_image)
+                 model_input_images = torch.stack(model_input_images, dim=0)
+             else:
+                 image_hashes = [compute_image_hash(image=images)]
+                 original_image_sizes = [images.shape[:2]]
+                 if input_color_format in {None, "bgr"}:
+                     images = np.ascontiguousarray(images[:, :, ::-1])
+                 model_input_images = self._transform.apply_image(image=images)
+                 model_input_images = (
+                     torch.as_tensor(model_input_images, device=self._device)
+                     .permute(2, 0, 1)
+                     .contiguous()[None, :, :, :]
+                 )
+         return model_input_images, image_hashes, original_image_sizes
+
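+     # Runs the SAM image encoder in inference mode, processing the input in chunks of at
+     # most `max_batch_size` images to bound memory usage.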
+     @torch.inference_mode()
+     def forward_image_embeddings(
+         self, model_input_images: torch.Tensor, **kwargs
+     ) -> torch.Tensor:
+         result_embeddings = []
+         for i in range(0, model_input_images.shape[0], self._max_batch_size):
+             input_images_batch = model_input_images[
+                 i : i + self._max_batch_size
+             ].contiguous()
+             pre_processed_images_batch = self._model.preprocess(input_images_batch)
+             batch_embeddings = self._model.image_encoder(pre_processed_images_batch).to(
+                 device=self._device
+             )
+             result_embeddings.append(batch_embeddings)
+         return torch.cat(result_embeddings, dim=0)
+
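+     # Predicts masks for the given images or precomputed embeddings using optional point,
+     # box and low-resolution mask prompts. Prompts are broadcast to the batch size,
+     # validated, transformed to model input space and decoded image by image.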
+     def segment_images(
+         self,
+         images: Optional[
+             Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]]
+         ] = None,
+         embeddings: Optional[
+             Union[List[SAMImageEmbeddings], SAMImageEmbeddings]
+         ] = None,
+         point_coordinates: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None,
+         point_labels: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None,
+         boxes: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None,
+         mask_input: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None,
+         multi_mask_output: bool = True,
+         return_logits: bool = False,
+         input_color_format: Optional[ColorFormat] = None,
+         mask_threshold: Optional[float] = None,
+         enforce_mask_input: bool = False,
+         use_mask_input_cache: bool = True,
+         use_embeddings_cache: bool = True,
+         **kwargs,
+     ) -> List[SAMPrediction]:
+         if images is None and embeddings is None:
+             raise ModelInputError(
+                 message="Attempted to use the SAM model segment_images(...) method without valid input - "
+                 "neither the `images` nor the `embeddings` parameter was given. If you run inference locally, "
+                 "verify your integration, making sure that the model interface is used correctly. If you are "
+                 "running on the Roboflow platform, contact us to get help.",
+                 help_url="https://todo",
+             )
+         if images is not None:
+             embeddings = self.embed_images(
+                 images=images,
+                 input_color_format=input_color_format,
+                 use_embeddings_cache=use_embeddings_cache,
+                 **kwargs,
+             )
+         else:
+             embeddings = maybe_wrap_in_list(value=embeddings)
+         embeddings_tensors = [e.embeddings.to(self._device) for e in embeddings]
+         image_hashes = [e.image_hash for e in embeddings]
+         original_image_sizes = [e.image_size_hw for e in embeddings]
+         point_coordinates = maybe_wrap_in_list(value=point_coordinates)
+         point_labels = maybe_wrap_in_list(value=point_labels)
+         boxes = maybe_wrap_in_list(value=boxes)
+         mask_input = maybe_wrap_in_list(value=mask_input)
+         masks_from_the_cache = [
+             (
+                 self._sam_low_resolution_masks_cache.retrieve_mask(key=image_hash)
+                 if use_mask_input_cache
+                 else None
+             )
+             for image_hash in image_hashes
+         ]
+         if enforce_mask_input and mask_input is None:
+             if not all(e is not None for e in masks_from_the_cache):
+                 raise ModelInputError(
+                     message="Attempted to use the SAM model segment_images(...) method enforcing the presence of "
+                     "low-resolution mask input without providing the mask explicitly (causing a fallback to "
+                     "SAM cache lookup, which failed for at least one image) - this problem may be temporary, "
+                     "but may also be a result of a bug or invalid integration. If you run inference locally, "
+                     "verify your integration, making sure that the model interface is used correctly. If you "
+                     "are running on the Roboflow platform, contact us to get help.",
+                     help_url="https://todo",
+                 )
+             mask_input = [mask.to(self._device) for mask in masks_from_the_cache]
+         point_coordinates, point_labels, boxes, mask_input = equalize_batch_size(
+             embeddings_batch_size=len(embeddings),
+             point_coordinates=point_coordinates,
+             point_labels=point_labels,
+             boxes=boxes,
+             mask_input=mask_input,
+         )
+         point_coordinates, point_labels, boxes, mask_input = pre_process_prompts(
+             point_coordinates=point_coordinates,
+             point_labels=point_labels,
+             boxes=boxes,
+             mask_input=mask_input,
+             device=self._device,
+             transform=self._transform,
+             original_image_sizes=original_image_sizes,
+         )
+         predictions = []
+         for (
+             image_embedding,
+             image_hash,
+             image_size,
+             image_point_coordinates,
+             image_point_labels,
+             image_boxes,
+             image_mask_input,
+         ) in generate_model_inputs(
+             embeddings=embeddings_tensors,
+             image_hashes=image_hashes,
+             original_image_sizes=original_image_sizes,
+             point_coordinates=point_coordinates,
+             point_labels=point_labels,
+             boxes=boxes,
+             mask_input=mask_input,
+         ):
+             prediction = predict_for_single_image(
+                 model=self._model,
+                 transform=self._transform,
+                 embeddings=image_embedding,
+                 original_image_size=image_size,
+                 point_coordinates=image_point_coordinates,
+                 point_labels=image_point_labels,
+                 boxes=image_boxes,
+                 mask_input=image_mask_input,
+                 multi_mask_output=multi_mask_output,
+                 return_logits=return_logits,
+                 mask_threshold=mask_threshold,
+             )
+             if use_mask_input_cache and len(prediction[0].shape) == 3:
+                 max_score_id = torch.argmax(prediction[1]).item()
+                 self._sam_low_resolution_masks_cache.save_mask(
+                     key=image_hash, mask=prediction[2][max_score_id].unsqueeze(dim=0)
+                 )
+             parsed_prediction = SAMPrediction(
+                 masks=prediction[0],
+                 scores=prediction[1],
+                 logits=prediction[2],
+             )
+             predictions.append(parsed_prediction)
+         return predictions
+
+
+ def decode_sam_version(config_path: str) -> str:
+     config = read_json(path=config_path)
+     version = config["version"]
+     if not isinstance(version, str):
+         raise ValueError("Could not decode SAM model version")
+     return version
+
+
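+ # Cache keys for embeddings and low-resolution masks are SHA-1 digests of the raw pixel bytes.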
+ def compute_image_hash(image: Union[torch.Tensor, np.ndarray]) -> str:
+     if isinstance(image, torch.Tensor):
+         image = image.cpu().numpy()
+     return hash_function(value=image.tobytes())
+
+
+ def hash_function(value: Union[str, bytes]) -> str:
+     return hashlib.sha1(value).hexdigest()
+
+
+ def maybe_wrap_in_list(value: Optional[Union[T, List[T]]]) -> Optional[List[T]]:
+     if value is None:
+         return None
+     if isinstance(value, list):
+         return value
+     return [value]
+
+
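+ # Broadcasts single-element prompt lists to the embeddings batch size and validates that
+ # point, box and mask prompts are mutually consistent before they are pre-processed.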
+ def equalize_batch_size(
+     embeddings_batch_size: int,
+     point_coordinates: Optional[List[ArrayOrTensor]],
+     point_labels: Optional[List[ArrayOrTensor]],
+     boxes: Optional[List[ArrayOrTensor]],
+     mask_input: Optional[List[ArrayOrTensor]],
+ ) -> Tuple[
+     Optional[List[ArrayOrTensor]],
+     Optional[List[ArrayOrTensor]],
+     Optional[List[ArrayOrTensor]],
+     Optional[List[ArrayOrTensor]],
+ ]:
+     if (
+         point_coordinates is not None
+         and len(point_coordinates) != embeddings_batch_size
+     ):
+         if len(point_coordinates) != 1:
+             raise ModelInputError(
+                 message="When using the SAM model, parameter `point_coordinates` was provided with an invalid "
+                 f"value indicating a different input batch size ({len(point_coordinates)}) than the provided "
+                 f"images / embeddings ({embeddings_batch_size}). If you run inference locally, verify your "
+                 "integration, making sure that the model interface is used correctly. "
+                 "If you are running on the Roboflow platform, contact us to get help.",
+                 help_url="https://todo",
+             )
+         point_coordinates = point_coordinates * embeddings_batch_size
+     if point_labels is not None and len(point_labels) != embeddings_batch_size:
+         if len(point_labels) != 1:
+             raise ModelInputError(
+                 message="When using the SAM model, parameter `point_labels` was provided with an invalid "
+                 f"value indicating a different input batch size ({len(point_labels)}) than the provided "
+                 f"images / embeddings ({embeddings_batch_size}). If you run inference locally, verify your "
+                 "integration, making sure that the model interface is used correctly. "
+                 "If you are running on the Roboflow platform, contact us to get help.",
+                 help_url="https://todo",
+             )
+         point_labels = point_labels * embeddings_batch_size
+     if boxes is not None and len(boxes) != embeddings_batch_size:
+         if len(boxes) != 1:
+             raise ModelInputError(
+                 message="When using the SAM model, parameter `boxes` was provided with an invalid "
+                 f"value indicating a different input batch size ({len(boxes)}) than the provided "
+                 f"images / embeddings ({embeddings_batch_size}). If you run inference locally, verify your "
+                 "integration, making sure that the model interface is used correctly. "
+                 "If you are running on the Roboflow platform, contact us to get help.",
+                 help_url="https://todo",
+             )
+         boxes = boxes * embeddings_batch_size
+     if mask_input is not None and len(mask_input) != embeddings_batch_size:
+         if len(mask_input) != 1:
+             raise ModelInputError(
+                 message="When using the SAM model, parameter `mask_input` was provided with an invalid "
+                 f"value indicating a different input batch size ({len(mask_input)}) than the provided "
+                 f"images / embeddings ({embeddings_batch_size}). If you run inference locally, verify your "
+                 "integration, making sure that the model interface is used correctly. "
+                 "If you are running on the Roboflow platform, contact us to get help.",
+                 help_url="https://todo",
+             )
+         mask_input = mask_input * embeddings_batch_size
+     prompts_first_dimension_characteristics = set()
+     if point_coordinates is not None:
+         point_coordinates_characteristic = "-".join(
+             [str(p.shape[0]) for p in point_coordinates]
+         )
+         prompts_first_dimension_characteristics.add(point_coordinates_characteristic)
+     if point_labels is not None:
+         point_labels_characteristic = "-".join([str(l.shape[0]) for l in point_labels])
+         prompts_first_dimension_characteristics.add(point_labels_characteristic)
+     if boxes is not None:
+         boxes_characteristic = "-".join(
+             [str(b.shape[0]) if len(b.shape) > 1 else "1" for b in boxes]
+         )
+         prompts_first_dimension_characteristics.add(boxes_characteristic)
+     if len(prompts_first_dimension_characteristics) > 1:
+         raise ModelInputError(
+             message="When using the SAM model with a combination of `point_coordinates`, `point_labels` and "
+             "`boxes`, the model expects an identical number of elements for each prompt component. "
+             "If you run inference locally, verify your integration, making sure that the model interface is "
+             "used correctly. If you are running on the Roboflow platform, contact us to get help.",
+             help_url="https://todo",
+         )
+     if mask_input is not None:
+         mask_input = [i[None, :, :] if len(i.shape) == 2 else i for i in mask_input]
+         if any(len(i.shape) != 3 or i.shape[0] != 1 for i in mask_input):
+             raise ModelInputError(
+                 message="When using the SAM model with `mask_input`, each mask must be a 3D tensor of shape (1, H, W). "
+                 "If you run inference locally, verify your integration, making sure that the model interface is "
+                 "used correctly. If you are running on the Roboflow platform, contact us to get help.",
+                 help_url="https://todo",
+             )
+     if boxes is not None:
+         batched_boxes_provided = False
+         for box in boxes:
+             if len(box.shape) > 1 and box.shape[0] > 1:
+                 batched_boxes_provided = True
+         if batched_boxes_provided and any(
+             e is not None for e in [point_coordinates, point_labels, mask_input]
+         ):
+             raise ModelInputError(
+                 message="When using SAM, providing batched boxes (multiple RoIs for a single image) makes it "
+                 "impossible to use other components of the prompt - like `point_coordinates`, `point_labels` "
+                 "or `mask_input` - and such a situation was detected. "
+                 "If you run inference locally, verify your integration, making sure that the model interface is "
+                 "used correctly. If you are running on the Roboflow platform, contact us to get help.",
+                 help_url="https://todo",
+             )
+     return point_coordinates, point_labels, boxes, mask_input
+
+
+ def maybe_broadcast_list(value: Optional[List[T]], n: int) -> Optional[List[T]]:
+     if value is None:
+         return None
+     return value * n
+
+
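+ # Moves prompt arrays to the target device as batched tensors and maps point and box
+ # coordinates into the resized model input space with ResizeLongestSide.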
+ def pre_process_prompts(
+     point_coordinates: Optional[List[ArrayOrTensor]],
+     point_labels: Optional[List[ArrayOrTensor]],
+     boxes: Optional[List[ArrayOrTensor]],
+     mask_input: Optional[List[ArrayOrTensor]],
+     device: torch.device,
+     transform: ResizeLongestSide,
+     original_image_sizes: List[Tuple[int, int]],
+ ) -> Tuple[
+     Optional[List[torch.Tensor]],
+     Optional[List[torch.Tensor]],
+     Optional[List[torch.Tensor]],
+     Optional[List[torch.Tensor]],
+ ]:
+     if point_labels is not None and point_coordinates is None:
+         raise ModelInputError(
+             message="When using the SAM model, `point_labels` were provided without `point_coordinates`, which "
+             "makes the input invalid. If you run inference locally, verify your integration, making sure that "
+             "the model interface is used correctly. If you are running on the Roboflow platform, contact us "
+             "to get help.",
+             help_url="https://todo",
+         )
+     if point_coordinates is not None:
+         if point_labels is None:
+             raise ModelInputError(
+                 message="When using the SAM model, `point_coordinates` were provided without `point_labels`, which "
+                 "makes the input invalid. If you run inference locally, verify your integration, making sure "
+                 "that the model interface is used correctly. If you are running on the Roboflow platform, "
+                 "contact us to get help.",
+                 help_url="https://todo",
+             )
+         point_coordinates = [
+             (
+                 c.to(device)[None, :, :]
+                 if isinstance(c, torch.Tensor)
+                 else torch.from_numpy(c).to(device)[None, :, :]
+             )
+             for c in point_coordinates
+         ]
+         point_labels = [
+             (
+                 l.to(device)[None, :]
+                 if isinstance(l, torch.Tensor)
+                 else torch.from_numpy(l).to(device)[None, :]
+             )
+             for l in point_labels
+         ]
+         point_coordinates = [
+             transform.apply_coords_torch(point_coords, image_shape)
+             for point_coords, image_shape in zip(
+                 point_coordinates, original_image_sizes
+             )
+         ]
+     if boxes is not None:
+         boxes = [
+             (
+                 box.to(device)[None, :]
+                 if isinstance(box, torch.Tensor)
+                 else torch.from_numpy(box).to(device)[None, :]
+             )
+             for box in boxes
+         ]
+         boxes = [
+             transform.apply_boxes_torch(box, image_shape)
+             for box, image_shape in zip(boxes, original_image_sizes)
+         ]
+     if mask_input is not None:
+         mask_input = [
+             (
+                 mask.to(device)[None, :, :]
+                 if isinstance(mask, torch.Tensor)
+                 else torch.from_numpy(mask).to(device)[None, :, :]
+             )
+             for mask in mask_input
+         ]
+     return point_coordinates, point_labels, boxes, mask_input
+
+
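+ # Yields one tuple of (embedding, image hash, original size, prompt components) per image,
+ # substituting None for any prompt component that was not provided.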
+ def generate_model_inputs(
+     embeddings: List[torch.Tensor],
+     image_hashes: List[str],
+     original_image_sizes: List[Tuple[int, int]],
+     point_coordinates: Optional[List[torch.Tensor]],
+     point_labels: Optional[List[torch.Tensor]],
+     boxes: Optional[List[torch.Tensor]],
+     mask_input: Optional[List[torch.Tensor]],
+ ) -> Generator[
+     Tuple[
+         torch.Tensor,
+         str,
+         Tuple[int, int],
+         Optional[torch.Tensor],
+         Optional[torch.Tensor],
+         Optional[torch.Tensor],
+         Optional[torch.Tensor],
+     ],
+     None,
+     None,
+ ]:
+     if point_coordinates is None:
+         point_coordinates = [None] * len(embeddings)
+     if point_labels is None:
+         point_labels = [None] * len(embeddings)
+     if boxes is None:
+         boxes = [None] * len(embeddings)
+     if mask_input is None:
+         mask_input = [None] * len(embeddings)
+     for embedding, hash_value, image_size, coords, labels, box, mask in zip(
+         embeddings,
+         image_hashes,
+         original_image_sizes,
+         point_coordinates,
+         point_labels,
+         boxes,
+         mask_input,
+     ):
+         yield embedding, hash_value, image_size, coords, labels, box, mask
+
+
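+ # Decodes masks for a single image: encodes the prompts, runs the mask decoder on the
+ # precomputed image embedding, upscales the low-resolution masks to the original image
+ # size and, unless logits are requested, thresholds them into binary masks.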
+ @torch.inference_mode()
+ def predict_for_single_image(
+     model: Sam,
+     transform: ResizeLongestSide,
+     embeddings: torch.Tensor,
+     original_image_size: Tuple[int, int],
+     point_coordinates: Optional[torch.Tensor],
+     point_labels: Optional[torch.Tensor],
+     boxes: Optional[torch.Tensor] = None,
+     mask_input: Optional[torch.Tensor] = None,
+     multi_mask_output: bool = True,
+     return_logits: bool = False,
+     mask_threshold: Optional[float] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     embeddings = embeddings.unsqueeze(dim=0)
+     if point_coordinates is not None:
+         points = (point_coordinates, point_labels)
+     else:
+         points = None
+     sparse_embeddings, dense_embeddings = model.prompt_encoder(
+         points=points,
+         boxes=boxes,
+         masks=mask_input,
+     )
+     low_res_masks, iou_predictions = model.mask_decoder(
+         image_embeddings=embeddings,
+         image_pe=model.prompt_encoder.get_dense_pe(),
+         sparse_prompt_embeddings=sparse_embeddings,
+         dense_prompt_embeddings=dense_embeddings,
+         multimask_output=multi_mask_output,
+     )
+     model_input_size = transform.get_preprocess_shape(
+         original_image_size[0], original_image_size[1], transform.target_length
+     )
+     masks = model.postprocess_masks(
+         low_res_masks, model_input_size, original_image_size
+     )
+     if not return_logits:
+         threshold = mask_threshold or model.mask_threshold
+         masks = masks > threshold
+     if masks.shape[0] == 1:
+         return masks[0], iou_predictions[0], low_res_masks[0]
+     else:
+         return masks, iou_predictions, low_res_masks