inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/sam2/sam2_torch.py
@@ -0,0 +1,905 @@
1
+ import hashlib
2
+ import json
3
+ from copy import copy
4
+ from typing import Dict, Generator, List, Optional, Tuple, TypeVar, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from sam2.build_sam import build_sam2
9
+ from sam2.modeling.sam2_base import SAM2Base
10
+ from sam2.utils.transforms import SAM2Transforms
11
+
12
+ from inference_models import ColorFormat
13
+ from inference_models.configuration import DEFAULT_DEVICE
14
+ from inference_models.errors import (
15
+ AssumptionError,
16
+ CorruptedModelPackageError,
17
+ ModelInputError,
18
+ )
19
+ from inference_models.models.common.model_packages import get_model_package_contents
20
+ from inference_models.models.sam2.cache import (
21
+ Sam2ImageEmbeddingsCache,
22
+ Sam2ImageEmbeddingsCacheNullObject,
23
+ Sam2LowResolutionMasksCache,
24
+ Sam2LowResolutionMasksCacheNullObject,
25
+ )
26
+ from inference_models.models.sam2.entities import (
27
+ SAM2ImageEmbeddings,
28
+ SAM2MaskCacheEntry,
29
+ SAM2Prediction,
30
+ )
31
+ from inference_models.utils.file_system import read_json
32
+
33
+ ArrayOrTensor = Union[np.ndarray, torch.Tensor]
34
+ T = TypeVar("T")
35
+
36
+ MAX_SAM2_BATCH_SIZE = 8
37
+
38
+ SUPPORTED_VERSIONS = {
39
+ "sam2_hiera_t",
40
+ "sam2_hiera_s",
41
+ "sam2_hiera_b+",
42
+ "sam2_hiera_l",
43
+ "sam2.1_hiera_t",
44
+ "sam2.1_hiera_s",
45
+ "sam2.1_hiera_b+",
46
+ "sam2.1_hiera_l",
47
+ }
48
+
49
+
50
+ class SAM2Torch:
51
+
52
+ @classmethod
53
+ def from_pretrained(
54
+ cls,
55
+ model_name_or_path: str,
56
+ device: torch.device = DEFAULT_DEVICE,
57
+ max_batch_size: int = MAX_SAM2_BATCH_SIZE,
58
+ disable_sam2_torch_jit_transforms: bool = True,
59
+ sam2_image_embeddings_cache: Optional[Sam2ImageEmbeddingsCache] = None,
60
+ sam2_low_resolution_masks_cache: Optional[Sam2LowResolutionMasksCache] = None,
61
+ **kwargs,
62
+ ) -> "SAM2Torch":
63
+ if sam2_image_embeddings_cache is None:
64
+ sam2_image_embeddings_cache = Sam2ImageEmbeddingsCacheNullObject()
65
+ if sam2_low_resolution_masks_cache is None:
66
+ sam2_low_resolution_masks_cache = Sam2LowResolutionMasksCacheNullObject()
67
+ model_package_content = get_model_package_contents(
68
+ model_package_dir=model_name_or_path,
69
+ elements=[
70
+ "model.pt",
71
+ "sam_configuration.json",
72
+ ],
73
+ )
74
+ try:
75
+ version = decode_sam_version(
76
+ config_path=model_package_content["sam_configuration.json"]
77
+ )
78
+ except Exception as error:
79
+ raise CorruptedModelPackageError(
80
+ message="Cold not decode SAM2 model version. If you see this error running inference locally, "
81
+ "verify the contents of model package. If you see the error running on Roboflow platform - "
82
+ "contact us to get help.",
83
+ help_url="https://todo",
84
+ ) from error
85
+ if version not in SUPPORTED_VERSIONS:
86
+ raise CorruptedModelPackageError(
87
+ message=f"Detected unsupported version of SAM2 model: {version}. Supported versions: "
88
+ f"are {SUPPORTED_VERSIONS}. If you run inference locally, verify the correctness of "
89
+ f"SAM2 model package. If you see the error running on Roboflow platform - "
90
+ "contact us to get help.",
91
+ help_url="https://todo",
92
+ )
93
+ model_config = f"{version}.yaml"
94
+ sam2_model = build_sam2(
95
+ model_config, model_package_content["model.pt"], device=device
96
+ )
97
+ transforms = SAM2Transforms(
98
+ resolution=sam2_model.image_size,
99
+ mask_threshold=0.0,
100
+ max_hole_area=0.0,
101
+ max_sprinkle_area=0.0,
102
+ disable_torch_jit=disable_sam2_torch_jit_transforms,
103
+ )
104
+ return cls(
105
+ model=sam2_model,
106
+ transform=transforms,
107
+ device=device,
108
+ max_batch_size=max_batch_size,
109
+ sam2_image_embeddings_cache=sam2_image_embeddings_cache,
110
+ sam2_low_resolution_masks_cache=sam2_low_resolution_masks_cache,
111
+ )
112
+
113
+ def __init__(
114
+ self,
115
+ model: SAM2Base,
116
+ transform: SAM2Transforms,
117
+ device: torch.device,
118
+ max_batch_size: int,
119
+ sam2_image_embeddings_cache: Sam2ImageEmbeddingsCache,
120
+ sam2_low_resolution_masks_cache: Sam2LowResolutionMasksCache,
121
+ ):
122
+ self._model = model
123
+ self._transform = transform
124
+ self._device = device
125
+ self._max_batch_size = max_batch_size
126
+ self._bb_feat_sizes = [
127
+ (256, 256),
128
+ (128, 128),
129
+ (64, 64),
130
+ ]
131
+ self._sam2_image_embeddings_cache = sam2_image_embeddings_cache
132
+ self._sam2_low_resolution_masks_cache = sam2_low_resolution_masks_cache
133
+
134
+ def embed_images(
135
+ self,
136
+ images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
137
+ use_embeddings_cache: bool = True,
138
+ **kwargs,
139
+ ) -> List[SAM2ImageEmbeddings]:
140
+ model_input_images, image_hashes, original_image_sizes = (
141
+ self.pre_process_images(
142
+ images=images,
143
+ **kwargs,
144
+ )
145
+ )
146
+ embeddings_from_cache: Dict[int, SAM2ImageEmbeddings] = {}
147
+ images_to_compute, hashes_of_images_to_compute, sizes_of_images_to_compute = (
148
+ [],
149
+ [],
150
+ [],
151
+ )
152
+ for idx, (image, image_hash, image_size) in enumerate(
153
+ zip(model_input_images, image_hashes, original_image_sizes)
154
+ ):
155
+ cache_content = None
156
+ if use_embeddings_cache:
157
+ cache_content = self._sam2_image_embeddings_cache.retrieve_embeddings(
158
+ key=image_hash
159
+ )
160
+ if cache_content is not None:
161
+ cache_content = cache_content.to(device=self._device)
162
+ embeddings_from_cache[idx] = cache_content
163
+ else:
164
+ images_to_compute.append(image)
165
+ hashes_of_images_to_compute.append(image_hash)
166
+ sizes_of_images_to_compute.append(image_size)
167
+ if len(images_to_compute) > 0:
168
+ images_to_compute = torch.stack(images_to_compute, dim=0)
169
+ computed_embeddings = self.forward_image_embeddings(
170
+ model_input_images=images_to_compute,
171
+ image_hashes=hashes_of_images_to_compute,
172
+ original_image_sizes=sizes_of_images_to_compute,
173
+ )
174
+ computed_embeddings_idx = 0
175
+ result_embeddings = []
176
+ for i in range(len(model_input_images)):
177
+ if i in embeddings_from_cache:
178
+ result_embeddings.append(embeddings_from_cache[i])
179
+ else:
180
+ result_embeddings.append(
181
+ computed_embeddings[computed_embeddings_idx]
182
+ )
183
+ computed_embeddings_idx += 1
184
+ else:
185
+ result_embeddings = [
186
+ embeddings_from_cache[i] for i in range(len(model_input_images))
187
+ ]
188
+ if use_embeddings_cache:
189
+ for embeddings in result_embeddings:
190
+ self._sam2_image_embeddings_cache.save_embeddings(
191
+ key=embeddings.image_hash, embeddings=embeddings
192
+ )
193
+ return result_embeddings
194
+
195
+ def pre_process_images(
196
+ self,
197
+ images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
198
+ **kwargs,
199
+ ) -> Tuple[torch.Tensor, List[str], List[Tuple[int, int]]]:
200
+ if isinstance(images, torch.Tensor):
201
+ images = images.to(device=self._device)
202
+ if len(images.shape) == 4:
203
+ image_hashes = [compute_image_hash(image=image) for image in images]
204
+ original_image_sizes = [tuple(images.shape[2:4])] * images.shape[0]
205
+ model_input_images = self._transform.transforms(images / 255.0)
206
+ else:
207
+ image_hashes = [compute_image_hash(image=images)]
208
+ original_image_sizes = [tuple(images.shape[1:3])]
209
+ model_input_images = self._transform.transforms(
210
+ (images / 255).unsqueeze(dim=0)
211
+ )
212
+ else:
213
+ if isinstance(images, list):
214
+ image_hashes = [compute_image_hash(image=image) for image in images]
215
+ original_image_sizes = []
216
+ model_input_images = []
217
+ for image in images:
218
+ if isinstance(image, np.ndarray):
219
+ original_image_sizes.append(image.shape[:2])
220
+ input_image = self._transform(image).to(self._device)
221
+ model_input_images.append(input_image)
222
+ else:
223
+ original_image_sizes.append(tuple(image.shape[1:3]))
224
+ image = image.to(self._device)
225
+ input_image = self._transform.transforms(image / 255)
226
+ model_input_images.append(input_image)
227
+ model_input_images = torch.stack(model_input_images, dim=0)
228
+ else:
229
+ image_hashes = [compute_image_hash(image=images)]
230
+ original_image_sizes = [images.shape[:2]]
231
+ model_input_images = (
232
+ self._transform(images).to(self._device).unsqueeze(dim=0)
233
+ )
234
+ return model_input_images, image_hashes, original_image_sizes
235
+
236
+ @torch.inference_mode()
237
+ def forward_image_embeddings(
238
+ self,
239
+ model_input_images: torch.Tensor,
240
+ image_hashes: List[str],
241
+ original_image_sizes: List[Tuple[int, int]],
242
+ **kwargs,
243
+ ) -> List[SAM2ImageEmbeddings]:
244
+ result_embeddings = []
245
+ for i in range(0, model_input_images.shape[0], self._max_batch_size):
246
+ input_images_batch = model_input_images[
247
+ i : i + self._max_batch_size
248
+ ].contiguous()
249
+ batch_size = input_images_batch.shape[0]
250
+ backbone_out = self._model.forward_image(input_images_batch)
251
+ _, vision_feats, _, _ = self._model._prepare_backbone_features(backbone_out)
252
+ if self._model.directly_add_no_mem_embed:
253
+ vision_feats[-1] = vision_feats[-1] + self._model.no_mem_embed
254
+ feats = [
255
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
256
+ for feat, feat_size in zip(
257
+ vision_feats[::-1], self._bb_feat_sizes[::-1]
258
+ )
259
+ ][::-1]
260
+ for image_idx in range(batch_size):
261
+ image_embeddings = feats[-1][image_idx].unsqueeze(dim=0)
262
+ high_resolution_features = [
263
+ feature[image_idx].unsqueeze(dim=0) for feature in feats[:-1]
264
+ ]
265
+ result_embeddings.append(
266
+ SAM2ImageEmbeddings(
267
+ image_hash=image_hashes[i + image_idx],
268
+ image_size_hw=original_image_sizes[i + image_idx],
269
+ embeddings=image_embeddings,
270
+ high_resolution_features=high_resolution_features,
271
+ )
272
+ )
273
+ return result_embeddings
274
+
275
+ def segment_images(
276
+ self,
277
+ images: Optional[
278
+ Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]]
279
+ ] = None,
280
+ embeddings: Optional[
281
+ Union[List[SAM2ImageEmbeddings], SAM2ImageEmbeddings]
282
+ ] = None,
283
+ point_coordinates: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None,
284
+ point_labels: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None,
285
+ boxes: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None,
286
+ mask_input: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None,
287
+ multi_mask_output: bool = True,
288
+ return_logits: bool = False,
289
+ input_color_format: Optional[ColorFormat] = None,
290
+ mask_threshold: Optional[float] = None,
291
+ load_from_mask_input_cache: bool = False,
292
+ save_to_mask_input_cache: bool = False,
293
+ use_embeddings_cache: bool = True,
294
+ **kwargs,
295
+ ) -> List[SAM2Prediction]:
296
+ if images is None and embeddings is None:
297
+ raise ModelInputError(
298
+ message="Attempted to use SAM model segment_images(...) method not providing valid input - "
299
+ "neither `images` nor `embeddings` parameter is given. If you run inference locally, "
300
+ "verify your integration making sure that the model interface is used correctly. Running "
301
+ "on Roboflow platform - contact us to get help.",
302
+ help_url="https://todo",
303
+ )
304
+ if images is not None:
305
+ embeddings = self.embed_images(
306
+ images=images,
307
+ input_color_format=input_color_format,
308
+ use_embeddings_cache=use_embeddings_cache,
309
+ **kwargs,
310
+ )
311
+ else:
312
+ embeddings = maybe_wrap_in_list(value=embeddings)
313
+ image_hashes = [e.image_hash for e in embeddings]
314
+ original_image_sizes = [e.image_size_hw for e in embeddings]
315
+ point_coordinates = maybe_wrap_in_list(value=point_coordinates)
316
+ point_labels = maybe_wrap_in_list(value=point_labels)
317
+ boxes = maybe_wrap_in_list(value=boxes)
318
+ mask_input = maybe_wrap_in_list(value=mask_input)
319
+ point_coordinates, point_labels, boxes, mask_input = equalize_batch_size(
320
+ embeddings_batch_size=len(embeddings),
321
+ point_coordinates=point_coordinates,
322
+ point_labels=point_labels,
323
+ boxes=boxes,
324
+ mask_input=mask_input,
325
+ )
326
+ point_coordinates, point_labels, boxes, mask_input = pre_process_prompts(
327
+ point_coordinates=point_coordinates,
328
+ point_labels=point_labels,
329
+ boxes=boxes,
330
+ mask_input=mask_input,
331
+ device=self._device,
332
+ transform=self._transform,
333
+ original_image_sizes=original_image_sizes,
334
+ )
335
+ predictions = []
336
+ for (
337
+ image_embedding,
338
+ image_hash,
339
+ image_size,
340
+ image_point_coordinates,
341
+ image_point_labels,
342
+ image_boxes,
343
+ image_mask_input,
344
+ ) in generate_model_inputs(
345
+ embeddings=embeddings,
346
+ image_hashes=image_hashes,
347
+ original_image_sizes=original_image_sizes,
348
+ point_coordinates=point_coordinates,
349
+ point_labels=point_labels,
350
+ boxes=boxes,
351
+ mask_input=mask_input,
352
+ ):
353
+ serialized_prompt, prompt_hash = None, None
354
+ if save_to_mask_input_cache or load_from_mask_input_cache:
355
+ serialized_prompt = serialize_prompt(
356
+ point_coordinates=image_point_coordinates,
357
+ point_labels=image_point_labels,
358
+ boxes=image_boxes,
359
+ )
360
+ prompt_hash = hash_serialized_prompt(
361
+ serialized_prompt=serialized_prompt
362
+ )
363
+ if image_mask_input is None and load_from_mask_input_cache:
364
+ image_mask_input = attempt_load_image_mask_from_cache(
365
+ image_hash=image_hash,
366
+ serialized_prompt_hash=prompt_hash,
367
+ serialized_prompt=serialized_prompt,
368
+ sam2_low_resolution_masks_cache=self._sam2_low_resolution_masks_cache,
369
+ device=self._device,
370
+ )
371
+ prediction = predict_for_single_image(
372
+ model=self._model,
373
+ transform=self._transform,
374
+ embeddings=image_embedding,
375
+ original_image_size=image_size,
376
+ point_coordinates=image_point_coordinates,
377
+ point_labels=image_point_labels,
378
+ boxes=image_boxes,
379
+ mask_input=image_mask_input,
380
+ multi_mask_output=multi_mask_output,
381
+ return_logits=return_logits,
382
+ mask_threshold=mask_threshold,
383
+ )
384
+ if save_to_mask_input_cache and len(prediction[0].shape) == 3:
385
+ max_score_id = torch.argmax(prediction[1]).item()
386
+ mask = SAM2MaskCacheEntry(
387
+ prompt_hash=prompt_hash,
388
+ serialized_prompt=serialized_prompt,
389
+ mask=prediction[2][max_score_id].unsqueeze(dim=0),
390
+ )
391
+ self._sam2_low_resolution_masks_cache.save_mask(
392
+ key=image_hash,
393
+ mask=mask,
394
+ )
395
+ parsed_prediction = SAM2Prediction(
396
+ masks=prediction[0],
397
+ scores=prediction[1],
398
+ logits=prediction[2],
399
+ )
400
+ predictions.append(parsed_prediction)
401
+ return predictions
402
+
403
+
404
+ def decode_sam_version(config_path: str) -> str:
405
+ config = read_json(path=config_path)
406
+ version = config["version"]
407
+ if not isinstance(version, str):
408
+ raise ValueError("Could not decode SAM model version")
409
+ return version
410
+
411
+
412
+ def compute_image_hash(image: Union[torch.Tensor, np.ndarray]) -> str:
413
+ if isinstance(image, torch.Tensor):
414
+ image = image.cpu().numpy()
415
+ return hash_function(value=image.tobytes())
416
+
417
+
418
+ def hash_function(value: Union[str, bytes]) -> str:
419
+ return hashlib.sha1(value).hexdigest()
420
+
421
+
422
+ def maybe_wrap_in_list(value: Optional[Union[T, List[T]]]) -> Optional[List[T]]:
423
+ if value is None:
424
+ return None
425
+ if isinstance(value, list):
426
+ return value
427
+ return [value]
428
+
429
+
430
+ def equalize_batch_size(
431
+ embeddings_batch_size: int,
432
+ point_coordinates: Optional[List[ArrayOrTensor]],
433
+ point_labels: Optional[List[ArrayOrTensor]],
434
+ boxes: Optional[List[ArrayOrTensor]],
435
+ mask_input: Optional[List[ArrayOrTensor]],
436
+ ) -> Tuple[
437
+ Optional[List[ArrayOrTensor]],
438
+ Optional[List[ArrayOrTensor]],
439
+ Optional[List[ArrayOrTensor]],
440
+ Optional[List[ArrayOrTensor]],
441
+ ]:
442
+ if (
443
+ point_coordinates is not None
444
+ and len(point_coordinates) != embeddings_batch_size
445
+ ):
446
+ if len(point_coordinates) != 1:
447
+ raise ModelInputError(
448
+ message="When using SAM2 model, parameter `point_coordinates` was provided with invalid "
449
+ f"value indicating different input batch size ({len(point_coordinates)}) than provided "
450
+ f"images / embeddings ({embeddings_batch_size}). If you run inference locally, verify your "
451
+ "integration making sure that the model interface is used correctly. "
452
+ "Running on Roboflow platform - contact us to get help.",
453
+ help_url="https://todo",
454
+ )
455
+ point_coordinates = point_coordinates * embeddings_batch_size
456
+ if point_labels is not None and len(point_labels) != embeddings_batch_size:
457
+ if len(point_labels) != 1:
458
+ raise ModelInputError(
459
+ message="When using SAM2 model, parameter `point_labels` was provided with invalid "
460
+ f"value indicating different input batch size ({len(point_labels)}) than provided "
461
+ f"images / embeddings ({embeddings_batch_size}). If you run inference locally, verify your "
462
+ "integration making sure that the model interface is used correctly. "
463
+ "Running on Roboflow platform - contact us to get help.",
464
+ help_url="https://todo",
465
+ )
466
+ point_labels = point_labels * embeddings_batch_size
467
+ if boxes is not None and len(boxes) != embeddings_batch_size:
468
+ if len(boxes) != 1:
469
+ raise ModelInputError(
470
+ message="When using SAM2 model, parameter `boxes` was provided with invalid "
471
+ f"value indicating different input batch size ({len(boxes)}) than provided "
472
+ f"images / embeddings ({embeddings_batch_size}). If you run inference locally, verify your "
473
+ "integration making sure that the model interface is used correctly. "
474
+ "Running on Roboflow platform - contact us to get help.",
475
+ help_url="https://todo",
476
+ )
477
+ boxes = boxes * embeddings_batch_size
478
+ if mask_input is not None and len(mask_input) != embeddings_batch_size:
479
+ if len(mask_input) != 1:
480
+ raise ModelInputError(
481
+ message="When using SAM2 model, parameter `mask_input` was provided with invalid "
482
+ f"value indicating different input batch size ({len(mask_input)}) than provided "
483
+ f"images / embeddings ({embeddings_batch_size}). If you run inference locally, verify your "
484
+ "integration making sure that the model interface is used correctly. "
485
+ "Running on Roboflow platform - contact us to get help.",
486
+ help_url="https://todo",
487
+ )
488
+ mask_input = mask_input * embeddings_batch_size
489
+ prompts_first_dimension_characteristics = set()
490
+ at_max_one_box_expected = False
491
+ if point_coordinates is not None:
492
+ point_coordinates_characteristic = "-".join(
493
+ [str(p.shape[0]) for p in point_coordinates]
494
+ )
495
+ prompts_first_dimension_characteristics.add(point_coordinates_characteristic)
496
+ points_dimensions = set(len(p.shape) for p in point_coordinates)
497
+ if len(points_dimensions) != 1:
498
+ raise ModelInputError(
499
+ message="When using SAM2 model, in scenario when combination of `point_coordinates` provided with "
500
+ "different shapes for different input images, which makes the input invalid. "
501
+ "If you run inference locally, verify your integration making sure that the model interface is "
502
+ "used correctly. Running on Roboflow platform - contact us to get help.",
503
+ help_url="https://todo",
504
+ )
505
+ if points_dimensions.pop() == 2:
506
+ at_max_one_box_expected = True
507
+ if point_labels is not None:
508
+ point_labels_characteristic = "-".join([str(l.shape[0]) for l in point_labels])
509
+ prompts_first_dimension_characteristics.add(point_labels_characteristic)
510
+ if len(prompts_first_dimension_characteristics) > 1:
511
+ raise ModelInputError(
512
+ message="When using SAM2 model, in scenario when combination of `point_coordinates` and `point_labels` "
513
+ "provided, the model expect identical number of elements for each prompt component. "
514
+ "If you run inference locally, verify your integration making sure that the model interface is "
515
+ "used correctly. Running on Roboflow platform - contact us to get help.",
516
+ help_url="https://todo",
517
+ )
518
+ if boxes is not None:
519
+ boxes_characteristic = "-".join(
520
+ [str(b.shape[0]) if len(b.shape) > 1 else "1" for b in boxes]
521
+ )
522
+ prompts_first_dimension_characteristics.add(boxes_characteristic)
523
+ if at_max_one_box_expected:
524
+ if not all(b.shape[0] == 1 if len(b.shape) > 1 else True for b in boxes):
525
+ raise ModelInputError(
526
+ message="When using SAM2 model, with `point_coordinates` provided for single box, each box in "
527
+ "`boxes` parameter must only define single bounding box."
528
+ "If you run inference locally, verify your integration making sure that the model "
529
+ "interface is used correctly. Running on Roboflow platform - contact us to get help.",
530
+ help_url="https://todo",
531
+ )
532
+ elif len(prompts_first_dimension_characteristics) > 1:
533
+ raise ModelInputError(
534
+ message="When using SAM2 model, in scenario when combination of `point_coordinates`, `point_labels`, "
535
+ "`boxes` provided, the model expect identical number of elements for each prompt component. "
536
+ "If you run inference locally, verify your integration making sure that the model interface is "
537
+ "used correctly. Running on Roboflow platform - contact us to get help.",
538
+ help_url="https://todo",
539
+ )
540
+ if mask_input is not None:
541
+ mask_input = [i[None, :, :] if len(i.shape) == 2 else i for i in mask_input]
542
+ if any(len(i.shape) != 3 or i.shape[0] != 1 for i in mask_input):
543
+ raise ModelInputError(
544
+ message="When using SAM2 model with `mask_input`, each mask must be 3D tensor of shape (1, H, W). "
545
+ "If you run inference locally, verify your integration making sure that the model interface is "
546
+ "used correctly. Running on Roboflow platform - contact us to get help.",
547
+ help_url="https://todo",
548
+ )
549
+ return point_coordinates, point_labels, boxes, mask_input
550
+
551
+
552
+ def generate_model_inputs(
553
+ embeddings: List[SAM2ImageEmbeddings],
554
+ image_hashes: List[str],
555
+ original_image_sizes: List[Tuple[int, int]],
556
+ point_coordinates: Optional[List[torch.Tensor]],
557
+ point_labels: Optional[List[torch.Tensor]],
558
+ boxes: Optional[List[torch.Tensor]],
559
+ mask_input: Optional[List[torch.Tensor]],
560
+ ) -> Generator[
561
+ Tuple[
562
+ SAM2ImageEmbeddings,
563
+ str,
564
+ Tuple[int, int],
565
+ Optional[torch.Tensor],
566
+ Optional[torch.Tensor],
567
+ Optional[torch.Tensor],
568
+ Optional[torch.Tensor],
569
+ ],
570
+ None,
571
+ None,
572
+ ]:
573
+ if point_coordinates is None:
574
+ point_coordinates = [None] * len(embeddings)
575
+ if point_labels is None:
576
+ point_labels = [None] * len(embeddings)
577
+ if boxes is None:
578
+ boxes = [None] * len(embeddings)
579
+ if mask_input is None:
580
+ mask_input = [None] * len(embeddings)
581
+ for embedding, hash_value, image_size, coords, labels, box, mask in zip(
582
+ embeddings,
583
+ image_hashes,
584
+ original_image_sizes,
585
+ point_coordinates,
586
+ point_labels,
587
+ boxes,
588
+ mask_input,
589
+ ):
590
+ yield embedding, hash_value, image_size, coords, labels, box, mask
591
+
592
+
593
+ def pre_process_prompts(
594
+ point_coordinates: Optional[List[ArrayOrTensor]],
595
+ point_labels: Optional[List[ArrayOrTensor]],
596
+ boxes: Optional[List[ArrayOrTensor]],
597
+ mask_input: Optional[List[ArrayOrTensor]],
598
+ device: torch.device,
599
+ transform: SAM2Transforms,
600
+ original_image_sizes: List[Tuple[int, int]],
601
+ normalize_coordinates: bool = True,
602
+ ) -> Tuple[
603
+ Optional[List[torch.Tensor]],
604
+ Optional[List[torch.Tensor]],
605
+ Optional[List[torch.Tensor]],
606
+ Optional[List[torch.Tensor]],
607
+ ]:
608
+ (
609
+ processed_point_coordinates,
610
+ processed_point_labels,
611
+ processed_boxes,
612
+ processed_mask_input,
613
+ ) = (None, None, None, None)
614
+ if point_labels is not None and point_coordinates is None:
615
+ raise ModelInputError(
616
+ message="When using SAM2 model, provided `point_coordinates` without `point_labels` which makes "
617
+ "invalid input. If you run inference locally, verify your integration making sure that the "
618
+ "model interface is used correctly. Running on Roboflow platform - contact us to get help.",
619
+ help_url="https://todo",
620
+ )
621
+ if point_coordinates is not None:
622
+ if point_labels is None:
623
+ raise ModelInputError(
624
+ message="When using SAM2 model, provided `point_coordinates` without `point_labels` which makes "
625
+ "invalid input. If you run inference locally, verify your integration making sure that the "
626
+ "model interface is used correctly. Running on Roboflow platform - contact us to get help.",
627
+ help_url="https://todo",
628
+ )
629
+ processed_point_coordinates = []
630
+ processed_point_labels = []
631
+ for single_label, single_point_coordinates, image_size in zip(
632
+ point_labels, point_coordinates, original_image_sizes
633
+ ):
634
+ if isinstance(single_point_coordinates, torch.Tensor):
635
+ single_point_coordinates = single_point_coordinates.to(
636
+ dtype=torch.float, device=device
637
+ )
638
+ else:
639
+ single_point_coordinates = torch.as_tensor(
640
+ single_point_coordinates, dtype=torch.float, device=device
641
+ )
642
+ single_point_coordinates = transform.transform_coords(
643
+ single_point_coordinates,
644
+ normalize=normalize_coordinates,
645
+ orig_hw=image_size,
646
+ )
647
+ dimension_to_unsqueeze = len(single_point_coordinates.shape) == 2
648
+ if dimension_to_unsqueeze:
649
+ single_point_coordinates = single_point_coordinates[None, ...]
650
+ processed_point_coordinates.append(single_point_coordinates)
651
+ if isinstance(single_label, torch.Tensor):
652
+ single_label = single_label.to(dtype=torch.int, device=device)
653
+ else:
654
+ single_label = torch.as_tensor(
655
+ single_label, dtype=torch.int, device=device
656
+ )
657
+ if dimension_to_unsqueeze:
658
+ single_label = single_label[None, ...]
659
+ processed_point_labels.append(single_label)
660
+ if boxes is not None:
661
+ processed_boxes = []
662
+ for box, image_size in zip(boxes, original_image_sizes):
663
+ if isinstance(box, torch.Tensor):
664
+ box = box.to(dtype=torch.float, device=device)
665
+ else:
666
+ box = torch.as_tensor(box, dtype=torch.float, device=device)
667
+ box = transform.transform_boxes(
668
+ box,
669
+ normalize=normalize_coordinates,
670
+ orig_hw=image_size,
671
+ ) # Bx2x2
672
+ processed_boxes.append(box)
673
+ if mask_input is not None:
674
+ processed_mask_input = []
675
+ for single_mask in mask_input:
676
+ if isinstance(single_mask, torch.Tensor):
677
+ single_mask = single_mask.to(dtype=torch.float, device=device)
678
+ else:
679
+ single_mask = torch.as_tensor(
680
+ single_mask, dtype=torch.float, device=device
681
+ )
682
+ if len(single_mask.shape) == 3:
683
+ single_mask = single_mask[None, :, :, :]
684
+ processed_mask_input.append(single_mask)
685
+ return (
686
+ processed_point_coordinates,
687
+ processed_point_labels,
688
+ processed_boxes,
689
+ processed_mask_input,
690
+ )
691
+
692
+
693
+ @torch.inference_mode()
694
+ def predict_for_single_image(
695
+ model: SAM2Base,
696
+ transform: SAM2Transforms,
697
+ embeddings: SAM2ImageEmbeddings,
698
+ original_image_size: Tuple[int, int],
699
+ point_coordinates: Optional[torch.Tensor],
700
+ point_labels: Optional[torch.Tensor],
701
+ boxes: Optional[torch.Tensor] = None,
702
+ mask_input: Optional[torch.Tensor] = None,
703
+ multi_mask_output: bool = True,
704
+ return_logits: bool = False,
705
+ mask_threshold: Optional[float] = None,
706
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
707
+ if point_coordinates is not None:
708
+ concat_points = (point_coordinates, point_labels)
709
+ else:
710
+ concat_points = None
711
+ if boxes is not None:
712
+ box_coords = boxes.reshape(-1, 2, 2)
713
+ box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
714
+ box_labels = box_labels.repeat(boxes.size(0), 1)
715
+ # we merge "boxes" and "points" into a single "concat_points" input (where
716
+ # boxes are added at the beginning) to sam_prompt_encoder
717
+ if concat_points is not None:
718
+ concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
719
+ concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
720
+ concat_points = (concat_coords, concat_labels)
721
+ else:
722
+ concat_points = (box_coords, box_labels)
723
+ sparse_embeddings, dense_embeddings = model.sam_prompt_encoder(
724
+ points=concat_points,
725
+ boxes=None,
726
+ masks=mask_input,
727
+ )
728
+ batched_mode = concat_points is not None and concat_points[0].shape[0] > 1
729
+ low_res_masks, iou_predictions, _, _ = model.sam_mask_decoder(
730
+ image_embeddings=embeddings.embeddings,
731
+ image_pe=model.sam_prompt_encoder.get_dense_pe(),
732
+ sparse_prompt_embeddings=sparse_embeddings,
733
+ dense_prompt_embeddings=dense_embeddings,
734
+ multimask_output=multi_mask_output,
735
+ repeat_image=batched_mode,
736
+ high_res_features=embeddings.high_resolution_features,
737
+ )
738
+ masks = transform.postprocess_masks(low_res_masks, original_image_size)
739
+ low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
740
+ if not return_logits:
741
+ masks = masks > (mask_threshold or 0.0)
742
+ if masks.shape[0] == 1:
743
+ return masks[0], iou_predictions[0], low_res_masks[0]
744
+ else:
745
+ return masks, iou_predictions, low_res_masks
746
+
747
+
748
+ def serialize_prompt(
749
+ point_coordinates: Optional[torch.Tensor],
750
+ point_labels: Optional[torch.Tensor],
751
+ boxes: Optional[torch.Tensor],
752
+ ) -> List[dict]:
753
+ if point_coordinates is None and point_labels is None and boxes is None:
754
+ return []
755
+ sizes = set()
756
+ if point_coordinates is not None:
757
+ sizes.add(point_coordinates.shape[0])
758
+ if point_labels is not None:
759
+ sizes.add(point_labels.shape[0])
760
+ if boxes is not None:
761
+ sizes.add(boxes.shape[0])
762
+ if len(sizes) != 1:
763
+ raise AssumptionError(
764
+ message="In SAM2 implementation, after pre-processing, all prompt elements must have the same "
765
+ "leading dimension. This assumption just got violated. This is most likely a bug. "
766
+ "You can help us sorting out this problem by submitting an issue: "
767
+ "https://github.com/roboflow/inference/issues",
768
+ help_url="https://todo",
769
+ )
770
+ broadcast_size = sizes.pop()
771
+ point_coordinates_list = (
772
+ point_coordinates.tolist()
773
+ if point_coordinates is not None
774
+ else [None] * broadcast_size
775
+ )
776
+ point_labels_list = (
777
+ point_labels.tolist() if point_labels is not None else [None] * broadcast_size
778
+ )
779
+ boxes_list = (
780
+ boxes.reshape(-1).tolist() if boxes is not None else [None] * broadcast_size
781
+ )
782
+ results = []
783
+ for points, labels, box in zip(
784
+ point_coordinates_list, point_labels_list, boxes_list
785
+ ):
786
+ points_serialized = []
787
+ if points is not None and labels is not None:
788
+ for point, label in zip(points, labels):
789
+ points_serialized.append(
790
+ {
791
+ "x": (
792
+ point[0].item()
793
+ if isinstance(point[0], torch.Tensor)
794
+ else point[0]
795
+ ),
796
+ "y": (
797
+ point[1].item()
798
+ if isinstance(point[1], torch.Tensor)
799
+ else point[1]
800
+ ),
801
+ "positive": (
802
+ label.item() if isinstance(labels, torch.Tensor) else label
803
+ ),
804
+ }
805
+ )
806
+ if box is not None:
807
+ box_serialized = box
808
+ else:
809
+ box_serialized = None
810
+ results.append({"points": points_serialized, "box": box_serialized})
811
+ return results
812
+
813
+
814
+ def hash_serialized_prompt(serialized_prompt: List[dict]) -> str:
815
+ serialized = json.dumps(serialized_prompt, sort_keys=True, separators=(",", ":"))
816
+ return hashlib.sha1(serialized.encode("utf-8")).hexdigest()
817
+
818
+
819
+ def attempt_load_image_mask_from_cache(
820
+ image_hash: str,
821
+ serialized_prompt_hash: str,
822
+ serialized_prompt: List[dict],
823
+ sam2_low_resolution_masks_cache: Sam2LowResolutionMasksCache,
824
+ device: torch.device,
825
+ ) -> Optional[torch.Tensor]:
826
+ all_masks_for_image = sam2_low_resolution_masks_cache.retrieve_all_masks_for_image(
827
+ key=image_hash
828
+ )
829
+ if not all_masks_for_image:
830
+ return None
831
+ if len(serialized_prompt) == 0:
832
+ return None
833
+ return find_prior_prompt_in_cache(
834
+ serialized_prompt_hash=serialized_prompt_hash,
835
+ serialized_prompt=serialized_prompt,
836
+ matching_cache_entries=all_masks_for_image,
837
+ device=device,
838
+ )
839
+
840
+
841
+ def find_prior_prompt_in_cache(
842
+ serialized_prompt_hash: str,
843
+ serialized_prompt: List[dict],
844
+ matching_cache_entries: List[SAM2MaskCacheEntry],
845
+ device: torch.device,
846
+ ) -> Optional[torch.Tensor]:
847
+ maxed_size = 0
848
+ best_match: Optional[SAM2MaskCacheEntry] = None
849
+ desired_size = len(serialized_prompt) - 1
850
+ for cache_entry in matching_cache_entries[::-1]:
851
+ is_viable = is_prompt_strict_subset(
852
+ assumed_sub_set_prompt=(
853
+ cache_entry.prompt_hash,
854
+ cache_entry.serialized_prompt,
855
+ ),
856
+ assumed_super_set_prompt=(serialized_prompt_hash, serialized_prompt),
857
+ )
858
+ if not is_viable:
859
+ continue
860
+
861
+ # short-circuit the search if we find a prompt with one less point (the most recent possible mask)
862
+ current_cache_entry_prompt_size = len(cache_entry.serialized_prompt)
863
+ if current_cache_entry_prompt_size == desired_size:
864
+ return cache_entry.mask.to(device=device)
865
+ if current_cache_entry_prompt_size >= maxed_size:
866
+ maxed_size = current_cache_entry_prompt_size
867
+ best_match = cache_entry
868
+ # best_match may still be None if no cached prompt is a strict subset of the current one
+ return best_match.mask.to(device=device) if best_match is not None else None
869
+
870
+
871
+ def is_prompt_strict_subset(
872
+ assumed_sub_set_prompt: Tuple[str, List[dict]],
873
+ assumed_super_set_prompt: Tuple[str, List[dict]],
874
+ ) -> bool:
875
+ if assumed_sub_set_prompt[0] == assumed_super_set_prompt[0]:
876
+ return False
877
+ super_set_prompt_copy = copy(assumed_super_set_prompt[1])
878
+ for sub_set_prompt_element in assumed_sub_set_prompt[1]:
879
+ found_match = False
880
+ for super_set_prompt_element in super_set_prompt_copy:
881
+ boxes_matching = (
882
+ sub_set_prompt_element["box"] == super_set_prompt_element["box"]
883
+ )
884
+ if not boxes_matching:
885
+ continue
886
+ sub_set_prompt_element_points = {
887
+ get_hashable_point(point=point)
888
+ for point in sub_set_prompt_element.get("points", [])
889
+ }
890
+ super_set_prompt_element_points = {
891
+ get_hashable_point(point=point)
892
+ for point in super_set_prompt_element.get("points", [])
893
+ }
894
+ if sub_set_prompt_element_points <= super_set_prompt_element_points:
895
+ super_set_prompt_copy.remove(super_set_prompt_element)
896
+ found_match = True
897
+ break
898
+ if not found_match:
899
+ return False
900
+ # every prompt in subset has a matching super prompt
901
+ return True
902
+
903
+
904
+ def get_hashable_point(point: dict) -> str:
905
+ return json.dumps(point, sort_keys=True, separators=(",", ":"))
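
For reference, the following is a minimal usage sketch of the SAM2Torch class added in this file. It is based only on the signatures visible in the diff above (from_pretrained, segment_images, SAM2Prediction); the model package path and the prompt values are hypothetical placeholders, and the import path is inferred from the file list (inference_models/models/sam2/sam2_torch.py).

    import numpy as np

    from inference_models.models.sam2.sam2_torch import SAM2Torch

    # Hypothetical path to a local SAM2 model package containing
    # model.pt and sam_configuration.json, as expected by from_pretrained(...)
    model = SAM2Torch.from_pretrained("/path/to/sam2_model_package")

    # Placeholder HWC uint8 image; any image loaded as a numpy array works
    image = np.zeros((480, 640, 3), dtype=np.uint8)

    predictions = model.segment_images(
        images=[image],
        point_coordinates=[np.array([[320, 240]])],  # one point prompt (x, y)
        point_labels=[np.array([1])],                # 1 marks the point as positive
    )
    # SAM2Prediction carries masks, scores, and logits for each input image
    masks = predictions[0].masks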