inference-models 0.18.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,905 @@
import hashlib
import json
from copy import copy
from typing import Dict, Generator, List, Optional, Tuple, TypeVar, Union

import numpy as np
import torch
from sam2.build_sam import build_sam2
from sam2.modeling.sam2_base import SAM2Base
from sam2.utils.transforms import SAM2Transforms

from inference_models import ColorFormat
from inference_models.configuration import DEFAULT_DEVICE
from inference_models.errors import (
    AssumptionError,
    CorruptedModelPackageError,
    ModelInputError,
)
from inference_models.models.common.model_packages import get_model_package_contents
from inference_models.models.sam2.cache import (
    Sam2ImageEmbeddingsCache,
    Sam2ImageEmbeddingsCacheNullObject,
    Sam2LowResolutionMasksCache,
    Sam2LowResolutionMasksCacheNullObject,
)
from inference_models.models.sam2.entities import (
    SAM2ImageEmbeddings,
    SAM2MaskCacheEntry,
    SAM2Prediction,
)
from inference_models.utils.file_system import read_json

ArrayOrTensor = Union[np.ndarray, torch.Tensor]
T = TypeVar("T")

MAX_SAM2_BATCH_SIZE = 8

SUPPORTED_VERSIONS = {
    "sam2_hiera_t",
    "sam2_hiera_s",
    "sam2_hiera_b+",
    "sam2_hiera_l",
    "sam2.1_hiera_t",
    "sam2.1_hiera_s",
    "sam2.1_hiera_b+",
    "sam2.1_hiera_l",
}

class SAM2Torch:

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        max_batch_size: int = MAX_SAM2_BATCH_SIZE,
        disable_sam2_torch_jit_transforms: bool = True,
        sam2_image_embeddings_cache: Optional[Sam2ImageEmbeddingsCache] = None,
        sam2_low_resolution_masks_cache: Optional[Sam2LowResolutionMasksCache] = None,
        **kwargs,
    ) -> "SAM2Torch":
        if sam2_image_embeddings_cache is None:
            sam2_image_embeddings_cache = Sam2ImageEmbeddingsCacheNullObject()
        if sam2_low_resolution_masks_cache is None:
            sam2_low_resolution_masks_cache = Sam2LowResolutionMasksCacheNullObject()
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=[
                "model.pt",
                "sam_configuration.json",
            ],
        )
        try:
            version = decode_sam_version(
                config_path=model_package_content["sam_configuration.json"]
            )
        except Exception as error:
            raise CorruptedModelPackageError(
                message="Could not decode SAM2 model version. If you see this error running inference locally, "
                "verify the contents of the model package. If you see the error running on Roboflow platform - "
                "contact us to get help.",
                help_url="https://todo",
            ) from error
        if version not in SUPPORTED_VERSIONS:
            raise CorruptedModelPackageError(
                message=f"Detected unsupported version of SAM2 model: {version}. Supported versions "
                f"are {SUPPORTED_VERSIONS}. If you run inference locally, verify the correctness of "
                f"the SAM2 model package. If you see the error running on Roboflow platform - "
                "contact us to get help.",
                help_url="https://todo",
            )
        model_config = f"{version}.yaml"
        sam2_model = build_sam2(
            model_config, model_package_content["model.pt"], device=device
        )
        transforms = SAM2Transforms(
            resolution=sam2_model.image_size,
            mask_threshold=0.0,
            max_hole_area=0.0,
            max_sprinkle_area=0.0,
            disable_torch_jit=disable_sam2_torch_jit_transforms,
        )
        return cls(
            model=sam2_model,
            transform=transforms,
            device=device,
            max_batch_size=max_batch_size,
            sam2_image_embeddings_cache=sam2_image_embeddings_cache,
            sam2_low_resolution_masks_cache=sam2_low_resolution_masks_cache,
        )

    def __init__(
        self,
        model: SAM2Base,
        transform: SAM2Transforms,
        device: torch.device,
        max_batch_size: int,
        sam2_image_embeddings_cache: Sam2ImageEmbeddingsCache,
        sam2_low_resolution_masks_cache: Sam2LowResolutionMasksCache,
    ):
        self._model = model
        self._transform = transform
        self._device = device
        self._max_batch_size = max_batch_size
        self._bb_feat_sizes = [
            (256, 256),
            (128, 128),
            (64, 64),
        ]
        self._sam2_image_embeddings_cache = sam2_image_embeddings_cache
        self._sam2_low_resolution_masks_cache = sam2_low_resolution_masks_cache

    def embed_images(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        use_embeddings_cache: bool = True,
        **kwargs,
    ) -> List[SAM2ImageEmbeddings]:
        model_input_images, image_hashes, original_image_sizes = (
            self.pre_process_images(
                images=images,
                **kwargs,
            )
        )
        embeddings_from_cache: Dict[int, SAM2ImageEmbeddings] = {}
        images_to_compute, hashes_of_images_to_compute, sizes_of_images_to_compute = (
            [],
            [],
            [],
        )
        for idx, (image, image_hash, image_size) in enumerate(
            zip(model_input_images, image_hashes, original_image_sizes)
        ):
            cache_content = None
            if use_embeddings_cache:
                cache_content = self._sam2_image_embeddings_cache.retrieve_embeddings(
                    key=image_hash
                )
            if cache_content is not None:
                cache_content = cache_content.to(device=self._device)
                embeddings_from_cache[idx] = cache_content
            else:
                images_to_compute.append(image)
                hashes_of_images_to_compute.append(image_hash)
                sizes_of_images_to_compute.append(image_size)
        if len(images_to_compute) > 0:
            images_to_compute = torch.stack(images_to_compute, dim=0)
            computed_embeddings = self.forward_image_embeddings(
                model_input_images=images_to_compute,
                image_hashes=hashes_of_images_to_compute,
                original_image_sizes=sizes_of_images_to_compute,
            )
            computed_embeddings_idx = 0
            result_embeddings = []
            for i in range(len(model_input_images)):
                if i in embeddings_from_cache:
                    result_embeddings.append(embeddings_from_cache[i])
                else:
                    result_embeddings.append(
                        computed_embeddings[computed_embeddings_idx]
                    )
                    computed_embeddings_idx += 1
        else:
            result_embeddings = [
                embeddings_from_cache[i] for i in range(len(model_input_images))
            ]
        if use_embeddings_cache:
            for embeddings in result_embeddings:
                self._sam2_image_embeddings_cache.save_embeddings(
                    key=embeddings.image_hash, embeddings=embeddings
                )
        return result_embeddings

    def pre_process_images(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        **kwargs,
    ) -> Tuple[torch.Tensor, List[str], List[Tuple[int, int]]]:
        if isinstance(images, torch.Tensor):
            images = images.to(device=self._device)
            if len(images.shape) == 4:
                image_hashes = [compute_image_hash(image=image) for image in images]
                original_image_sizes = [tuple(images.shape[2:4])] * images.shape[0]
                model_input_images = self._transform.transforms(images / 255.0)
            else:
                image_hashes = [compute_image_hash(image=images)]
                original_image_sizes = [tuple(images.shape[1:3])]
                model_input_images = self._transform.transforms(
                    (images / 255).unsqueeze(dim=0)
                )
        else:
            if isinstance(images, list):
                image_hashes = [compute_image_hash(image=image) for image in images]
                original_image_sizes = []
                model_input_images = []
                for image in images:
                    if isinstance(image, np.ndarray):
                        original_image_sizes.append(image.shape[:2])
                        input_image = self._transform(image).to(self._device)
                        model_input_images.append(input_image)
                    else:
                        original_image_sizes.append(tuple(image.shape[1:3]))
                        image = image.to(self._device)
                        input_image = self._transform.transforms(image / 255)
                        model_input_images.append(input_image)
                model_input_images = torch.stack(model_input_images, dim=0)
            else:
                image_hashes = [compute_image_hash(image=images)]
                original_image_sizes = [images.shape[:2]]
                model_input_images = (
                    self._transform(images).to(self._device).unsqueeze(dim=0)
                )
        return model_input_images, image_hashes, original_image_sizes

    @torch.inference_mode()
    def forward_image_embeddings(
        self,
        model_input_images: torch.Tensor,
        image_hashes: List[str],
        original_image_sizes: List[Tuple[int, int]],
        **kwargs,
    ) -> List[SAM2ImageEmbeddings]:
        result_embeddings = []
        for i in range(0, model_input_images.shape[0], self._max_batch_size):
            input_images_batch = model_input_images[
                i : i + self._max_batch_size
            ].contiguous()
            batch_size = input_images_batch.shape[0]
            backbone_out = self._model.forward_image(input_images_batch)
            _, vision_feats, _, _ = self._model._prepare_backbone_features(backbone_out)
            if self._model.directly_add_no_mem_embed:
                vision_feats[-1] = vision_feats[-1] + self._model.no_mem_embed
            feats = [
                feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
                for feat, feat_size in zip(
                    vision_feats[::-1], self._bb_feat_sizes[::-1]
                )
            ][::-1]
            for image_idx in range(batch_size):
                image_embeddings = feats[-1][image_idx].unsqueeze(dim=0)
                high_resolution_features = [
                    feature[image_idx].unsqueeze(dim=0) for feature in feats[:-1]
                ]
                result_embeddings.append(
                    SAM2ImageEmbeddings(
                        image_hash=image_hashes[i + image_idx],
                        image_size_hw=original_image_sizes[i + image_idx],
                        embeddings=image_embeddings,
                        high_resolution_features=high_resolution_features,
                    )
                )
        return result_embeddings

    def segment_images(
        self,
        images: Optional[
            Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]]
        ] = None,
        embeddings: Optional[
            Union[List[SAM2ImageEmbeddings], SAM2ImageEmbeddings]
        ] = None,
        point_coordinates: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None,
        point_labels: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None,
        boxes: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None,
        mask_input: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None,
        multi_mask_output: bool = True,
        return_logits: bool = False,
        input_color_format: Optional[ColorFormat] = None,
        mask_threshold: Optional[float] = None,
        load_from_mask_input_cache: bool = False,
        save_to_mask_input_cache: bool = False,
        use_embeddings_cache: bool = True,
        **kwargs,
    ) -> List[SAM2Prediction]:
        if images is None and embeddings is None:
            raise ModelInputError(
                message="Attempted to use SAM2 model segment_images(...) method without providing valid input - "
                "neither `images` nor `embeddings` parameter is given. If you run inference locally, "
                "verify your integration making sure that the model interface is used correctly. Running "
                "on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
        if images is not None:
            embeddings = self.embed_images(
                images=images,
                input_color_format=input_color_format,
                use_embeddings_cache=use_embeddings_cache,
                **kwargs,
            )
        else:
            embeddings = maybe_wrap_in_list(value=embeddings)
        image_hashes = [e.image_hash for e in embeddings]
        original_image_sizes = [e.image_size_hw for e in embeddings]
        point_coordinates = maybe_wrap_in_list(value=point_coordinates)
        point_labels = maybe_wrap_in_list(value=point_labels)
        boxes = maybe_wrap_in_list(value=boxes)
        mask_input = maybe_wrap_in_list(value=mask_input)
        point_coordinates, point_labels, boxes, mask_input = equalize_batch_size(
            embeddings_batch_size=len(embeddings),
            point_coordinates=point_coordinates,
            point_labels=point_labels,
            boxes=boxes,
            mask_input=mask_input,
        )
        point_coordinates, point_labels, boxes, mask_input = pre_process_prompts(
            point_coordinates=point_coordinates,
            point_labels=point_labels,
            boxes=boxes,
            mask_input=mask_input,
            device=self._device,
            transform=self._transform,
            original_image_sizes=original_image_sizes,
        )
        predictions = []
        for (
            image_embedding,
            image_hash,
            image_size,
            image_point_coordinates,
            image_point_labels,
            image_boxes,
            image_mask_input,
        ) in generate_model_inputs(
            embeddings=embeddings,
            image_hashes=image_hashes,
            original_image_sizes=original_image_sizes,
            point_coordinates=point_coordinates,
            point_labels=point_labels,
            boxes=boxes,
            mask_input=mask_input,
        ):
            serialized_prompt, prompt_hash = None, None
            if save_to_mask_input_cache or load_from_mask_input_cache:
                serialized_prompt = serialize_prompt(
                    point_coordinates=image_point_coordinates,
                    point_labels=image_point_labels,
                    boxes=image_boxes,
                )
                prompt_hash = hash_serialized_prompt(
                    serialized_prompt=serialized_prompt
                )
            if image_mask_input is None and load_from_mask_input_cache:
                image_mask_input = attempt_load_image_mask_from_cache(
                    image_hash=image_hash,
                    serialized_prompt_hash=prompt_hash,
                    serialized_prompt=serialized_prompt,
                    sam2_low_resolution_masks_cache=self._sam2_low_resolution_masks_cache,
                    device=self._device,
                )
            prediction = predict_for_single_image(
                model=self._model,
                transform=self._transform,
                embeddings=image_embedding,
                original_image_size=image_size,
                point_coordinates=image_point_coordinates,
                point_labels=image_point_labels,
                boxes=image_boxes,
                mask_input=image_mask_input,
                multi_mask_output=multi_mask_output,
                return_logits=return_logits,
                mask_threshold=mask_threshold,
            )
            if save_to_mask_input_cache and len(prediction[0].shape) == 3:
                max_score_id = torch.argmax(prediction[1]).item()
                mask = SAM2MaskCacheEntry(
                    prompt_hash=prompt_hash,
                    serialized_prompt=serialized_prompt,
                    mask=prediction[2][max_score_id].unsqueeze(dim=0),
                )
                self._sam2_low_resolution_masks_cache.save_mask(
                    key=image_hash,
                    mask=mask,
                )
            parsed_prediction = SAM2Prediction(
                masks=prediction[0],
                scores=prediction[1],
                logits=prediction[2],
            )
            predictions.append(parsed_prediction)
        return predictions

def decode_sam_version(config_path: str) -> str:
    config = read_json(path=config_path)
    version = config["version"]
    if not isinstance(version, str):
        raise ValueError("Could not decode SAM model version")
    return version


def compute_image_hash(image: Union[torch.Tensor, np.ndarray]) -> str:
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()
    return hash_function(value=image.tobytes())


def hash_function(value: Union[str, bytes]) -> str:
    return hashlib.sha1(value).hexdigest()


def maybe_wrap_in_list(value: Optional[Union[T, List[T]]]) -> Optional[List[T]]:
    if value is None:
        return None
    if isinstance(value, list):
        return value
    return [value]

def equalize_batch_size(
    embeddings_batch_size: int,
    point_coordinates: Optional[List[ArrayOrTensor]],
    point_labels: Optional[List[ArrayOrTensor]],
    boxes: Optional[List[ArrayOrTensor]],
    mask_input: Optional[List[ArrayOrTensor]],
) -> Tuple[
    Optional[List[ArrayOrTensor]],
    Optional[List[ArrayOrTensor]],
    Optional[List[ArrayOrTensor]],
    Optional[List[ArrayOrTensor]],
]:
    if (
        point_coordinates is not None
        and len(point_coordinates) != embeddings_batch_size
    ):
        if len(point_coordinates) != 1:
            raise ModelInputError(
                message="When using SAM2 model, parameter `point_coordinates` was provided with invalid "
                f"value indicating different input batch size ({len(point_coordinates)}) than provided "
                f"images / embeddings ({embeddings_batch_size}). If you run inference locally, verify your "
                "integration making sure that the model interface is used correctly. "
                "Running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
        point_coordinates = point_coordinates * embeddings_batch_size
    if point_labels is not None and len(point_labels) != embeddings_batch_size:
        if len(point_labels) != 1:
            raise ModelInputError(
                message="When using SAM2 model, parameter `point_labels` was provided with invalid "
                f"value indicating different input batch size ({len(point_labels)}) than provided "
                f"images / embeddings ({embeddings_batch_size}). If you run inference locally, verify your "
                "integration making sure that the model interface is used correctly. "
                "Running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
        point_labels = point_labels * embeddings_batch_size
    if boxes is not None and len(boxes) != embeddings_batch_size:
        if len(boxes) != 1:
            raise ModelInputError(
                message="When using SAM2 model, parameter `boxes` was provided with invalid "
                f"value indicating different input batch size ({len(boxes)}) than provided "
                f"images / embeddings ({embeddings_batch_size}). If you run inference locally, verify your "
                "integration making sure that the model interface is used correctly. "
                "Running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
        boxes = boxes * embeddings_batch_size
    if mask_input is not None and len(mask_input) != embeddings_batch_size:
        if len(mask_input) != 1:
            raise ModelInputError(
                message="When using SAM2 model, parameter `mask_input` was provided with invalid "
                f"value indicating different input batch size ({len(mask_input)}) than provided "
                f"images / embeddings ({embeddings_batch_size}). If you run inference locally, verify your "
                "integration making sure that the model interface is used correctly. "
                "Running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
        mask_input = mask_input * embeddings_batch_size
    prompts_first_dimension_characteristics = set()
    at_max_one_box_expected = False
    if point_coordinates is not None:
        point_coordinates_characteristic = "-".join(
            [str(p.shape[0]) for p in point_coordinates]
        )
        prompts_first_dimension_characteristics.add(point_coordinates_characteristic)
        points_dimensions = set(len(p.shape) for p in point_coordinates)
        if len(points_dimensions) != 1:
            raise ModelInputError(
                message="When using SAM2 model, `point_coordinates` was provided with different shapes for "
                "different input images, which makes the input invalid. "
                "If you run inference locally, verify your integration making sure that the model interface is "
                "used correctly. Running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
        if points_dimensions.pop() == 2:
            at_max_one_box_expected = True
    if point_labels is not None:
        point_labels_characteristic = "-".join([str(l.shape[0]) for l in point_labels])
        prompts_first_dimension_characteristics.add(point_labels_characteristic)
    if len(prompts_first_dimension_characteristics) > 1:
        raise ModelInputError(
            message="When using SAM2 model, in scenario when combination of `point_coordinates` and `point_labels` "
            "is provided, the model expects an identical number of elements for each prompt component. "
            "If you run inference locally, verify your integration making sure that the model interface is "
            "used correctly. Running on Roboflow platform - contact us to get help.",
            help_url="https://todo",
        )
    if boxes is not None:
        boxes_characteristic = "-".join(
            [str(b.shape[0]) if len(b.shape) > 1 else "1" for b in boxes]
        )
        prompts_first_dimension_characteristics.add(boxes_characteristic)
        if at_max_one_box_expected:
            if not all(b.shape[0] == 1 if len(b.shape) > 1 else True for b in boxes):
                raise ModelInputError(
                    message="When using SAM2 model, with `point_coordinates` provided for a single box, each box in "
                    "`boxes` parameter must only define a single bounding box. "
                    "If you run inference locally, verify your integration making sure that the model "
                    "interface is used correctly. Running on Roboflow platform - contact us to get help.",
                    help_url="https://todo",
                )
        elif len(prompts_first_dimension_characteristics) > 1:
            raise ModelInputError(
                message="When using SAM2 model, in scenario when combination of `point_coordinates`, `point_labels`, "
                "`boxes` is provided, the model expects an identical number of elements for each prompt component. "
                "If you run inference locally, verify your integration making sure that the model interface is "
                "used correctly. Running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
    if mask_input is not None:
        mask_input = [i[None, :, :] if len(i.shape) == 2 else i for i in mask_input]
        if any(len(i.shape) != 3 or i.shape[0] != 1 for i in mask_input):
            raise ModelInputError(
                message="When using SAM2 model with `mask_input`, each mask must be a 3D tensor of shape (1, H, W). "
                "If you run inference locally, verify your integration making sure that the model interface is "
                "used correctly. Running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
    return point_coordinates, point_labels, boxes, mask_input

def generate_model_inputs(
    embeddings: List[SAM2ImageEmbeddings],
    image_hashes: List[str],
    original_image_sizes: List[Tuple[int, int]],
    point_coordinates: Optional[List[torch.Tensor]],
    point_labels: Optional[List[torch.Tensor]],
    boxes: Optional[List[torch.Tensor]],
    mask_input: Optional[List[torch.Tensor]],
) -> Generator[
    Tuple[
        SAM2ImageEmbeddings,
        str,
        Tuple[int, int],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ],
    None,
    None,
]:
    if point_coordinates is None:
        point_coordinates = [None] * len(embeddings)
    if point_labels is None:
        point_labels = [None] * len(embeddings)
    if boxes is None:
        boxes = [None] * len(embeddings)
    if mask_input is None:
        mask_input = [None] * len(embeddings)
    for embedding, hash_value, image_size, coords, labels, box, mask in zip(
        embeddings,
        image_hashes,
        original_image_sizes,
        point_coordinates,
        point_labels,
        boxes,
        mask_input,
    ):
        yield embedding, hash_value, image_size, coords, labels, box, mask


def pre_process_prompts(
    point_coordinates: Optional[List[ArrayOrTensor]],
    point_labels: Optional[List[ArrayOrTensor]],
    boxes: Optional[List[ArrayOrTensor]],
    mask_input: Optional[List[ArrayOrTensor]],
    device: torch.device,
    transform: SAM2Transforms,
    original_image_sizes: List[Tuple[int, int]],
    normalize_coordinates: bool = True,
) -> Tuple[
    Optional[List[torch.Tensor]],
    Optional[List[torch.Tensor]],
    Optional[List[torch.Tensor]],
    Optional[List[torch.Tensor]],
]:
    (
        processed_point_coordinates,
        processed_point_labels,
        processed_boxes,
        processed_mask_input,
    ) = (None, None, None, None)
    if point_labels is not None and point_coordinates is None:
        raise ModelInputError(
            message="When using SAM2 model, provided `point_labels` without `point_coordinates` which makes "
            "invalid input. If you run inference locally, verify your integration making sure that the "
            "model interface is used correctly. Running on Roboflow platform - contact us to get help.",
            help_url="https://todo",
        )
    if point_coordinates is not None:
        if point_labels is None:
            raise ModelInputError(
                message="When using SAM2 model, provided `point_coordinates` without `point_labels` which makes "
                "invalid input. If you run inference locally, verify your integration making sure that the "
                "model interface is used correctly. Running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
        processed_point_coordinates = []
        processed_point_labels = []
        for single_label, single_point_coordinates, image_size in zip(
            point_labels, point_coordinates, original_image_sizes
        ):
            if isinstance(single_point_coordinates, torch.Tensor):
                single_point_coordinates = single_point_coordinates.to(
                    dtype=torch.float, device=device
                )
            else:
                single_point_coordinates = torch.as_tensor(
                    single_point_coordinates, dtype=torch.float, device=device
                )
            single_point_coordinates = transform.transform_coords(
                single_point_coordinates,
                normalize=normalize_coordinates,
                orig_hw=image_size,
            )
            dimension_to_unsqueeze = len(single_point_coordinates.shape) == 2
            if dimension_to_unsqueeze:
                single_point_coordinates = single_point_coordinates[None, ...]
            processed_point_coordinates.append(single_point_coordinates)
            if isinstance(single_label, torch.Tensor):
                single_label = single_label.to(dtype=torch.int, device=device)
            else:
                single_label = torch.as_tensor(
                    single_label, dtype=torch.int, device=device
                )
            if dimension_to_unsqueeze:
                single_label = single_label[None, ...]
            processed_point_labels.append(single_label)
    if boxes is not None:
        processed_boxes = []
        for box, image_size in zip(boxes, original_image_sizes):
            if isinstance(box, torch.Tensor):
                box = box.to(dtype=torch.float, device=device)
            else:
                box = torch.as_tensor(box, dtype=torch.float, device=device)
            box = transform.transform_boxes(
                box,
                normalize=normalize_coordinates,
                orig_hw=image_size,
            )  # Bx2x2
            processed_boxes.append(box)
    if mask_input is not None:
        processed_mask_input = []
        for single_mask in mask_input:
            if isinstance(single_mask, torch.Tensor):
                single_mask = single_mask.to(dtype=torch.float, device=device)
            else:
                single_mask = torch.as_tensor(
                    single_mask, dtype=torch.float, device=device
                )
            if len(single_mask.shape) == 3:
                single_mask = single_mask[None, :, :, :]
            processed_mask_input.append(single_mask)
    return (
        processed_point_coordinates,
        processed_point_labels,
        processed_boxes,
        processed_mask_input,
    )

@torch.inference_mode()
def predict_for_single_image(
    model: SAM2Base,
    transform: SAM2Transforms,
    embeddings: SAM2ImageEmbeddings,
    original_image_size: Tuple[int, int],
    point_coordinates: Optional[torch.Tensor],
    point_labels: Optional[torch.Tensor],
    boxes: Optional[torch.Tensor] = None,
    mask_input: Optional[torch.Tensor] = None,
    multi_mask_output: bool = True,
    return_logits: bool = False,
    mask_threshold: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if point_coordinates is not None:
        concat_points = (point_coordinates, point_labels)
    else:
        concat_points = None
    if boxes is not None:
        box_coords = boxes.reshape(-1, 2, 2)
        box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
        box_labels = box_labels.repeat(boxes.size(0), 1)
        # we merge "boxes" and "points" into a single "concat_points" input (where
        # boxes are added at the beginning) to sam_prompt_encoder
        if concat_points is not None:
            concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
            concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
            concat_points = (concat_coords, concat_labels)
        else:
            concat_points = (box_coords, box_labels)
    sparse_embeddings, dense_embeddings = model.sam_prompt_encoder(
        points=concat_points,
        boxes=None,
        masks=mask_input,
    )
    batched_mode = concat_points is not None and concat_points[0].shape[0] > 1
    low_res_masks, iou_predictions, _, _ = model.sam_mask_decoder(
        image_embeddings=embeddings.embeddings,
        image_pe=model.sam_prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=multi_mask_output,
        repeat_image=batched_mode,
        high_res_features=embeddings.high_resolution_features,
    )
    masks = transform.postprocess_masks(low_res_masks, original_image_size)
    low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
    if not return_logits:
        masks = masks > (mask_threshold or 0.0)
    if masks.shape[0] == 1:
        return masks[0], iou_predictions[0], low_res_masks[0]
    else:
        return masks, iou_predictions, low_res_masks

def serialize_prompt(
    point_coordinates: Optional[torch.Tensor],
    point_labels: Optional[torch.Tensor],
    boxes: Optional[torch.Tensor],
) -> List[dict]:
    if point_coordinates is None and point_labels is None and boxes is None:
        return []
    sizes = set()
    if point_coordinates is not None:
        sizes.add(point_coordinates.shape[0])
    if point_labels is not None:
        sizes.add(point_labels.shape[0])
    if boxes is not None:
        sizes.add(boxes.shape[0])
    if len(sizes) != 1:
        raise AssumptionError(
            message="In SAM2 implementation, after pre-processing, all prompt elements must have the same "
            "leading dimension. This assumption just got violated. This is most likely a bug. "
            "You can help us sorting out this problem by submitting an issue: "
            "https://github.com/roboflow/inference/issues",
            help_url="https://todo",
        )
    broadcast_size = sizes.pop()
    point_coordinates_list = (
        point_coordinates.tolist()
        if point_coordinates is not None
        else [None] * broadcast_size
    )
    point_labels_list = (
        point_labels.tolist() if point_labels is not None else [None] * broadcast_size
    )
    boxes_list = (
        boxes.reshape(-1).tolist() if boxes is not None else [None] * broadcast_size
    )
    results = []
    for points, labels, box in zip(
        point_coordinates_list, point_labels_list, boxes_list
    ):
        points_serialized = []
        if points is not None and labels is not None:
            for point, label in zip(points, labels):
                points_serialized.append(
                    {
                        "x": (
                            point[0].item()
                            if isinstance(point[0], torch.Tensor)
                            else point[0]
                        ),
                        "y": (
                            point[1].item()
                            if isinstance(point[1], torch.Tensor)
                            else point[1]
                        ),
                        "positive": (
                            label.item() if isinstance(label, torch.Tensor) else label
                        ),
                    }
                )
        if box is not None:
            box_serialized = box
        else:
            box_serialized = None
        results.append({"points": points_serialized, "box": box_serialized})
    return results


def hash_serialized_prompt(serialized_prompt: List[dict]) -> str:
    serialized = json.dumps(serialized_prompt, sort_keys=True, separators=(",", ":"))
    return hashlib.sha1(serialized.encode("utf-8")).hexdigest()


def attempt_load_image_mask_from_cache(
    image_hash: str,
    serialized_prompt_hash: str,
    serialized_prompt: List[dict],
    sam2_low_resolution_masks_cache: Sam2LowResolutionMasksCache,
    device: torch.device,
) -> Optional[torch.Tensor]:
    all_masks_for_image = sam2_low_resolution_masks_cache.retrieve_all_masks_for_image(
        key=image_hash
    )
    if not all_masks_for_image:
        return None
    if len(serialized_prompt) == 0:
        return None
    return find_prior_prompt_in_cache(
        serialized_prompt_hash=serialized_prompt_hash,
        serialized_prompt=serialized_prompt,
        matching_cache_entries=all_masks_for_image,
        device=device,
    )


def find_prior_prompt_in_cache(
    serialized_prompt_hash: str,
    serialized_prompt: List[dict],
    matching_cache_entries: List[SAM2MaskCacheEntry],
    device: torch.device,
) -> Optional[torch.Tensor]:
    maxed_size = 0
    best_match: Optional[SAM2MaskCacheEntry] = None
    desired_size = len(serialized_prompt) - 1
    for cache_entry in matching_cache_entries[::-1]:
        is_viable = is_prompt_strict_subset(
            assumed_sub_set_prompt=(
                cache_entry.prompt_hash,
                cache_entry.serialized_prompt,
            ),
            assumed_super_set_prompt=(serialized_prompt_hash, serialized_prompt),
        )
        if not is_viable:
            continue

        # short circuit search if we find prompt with one less point (most recent possible mask)
        current_cache_entry_prompt_size = len(cache_entry.serialized_prompt)
        if current_cache_entry_prompt_size == desired_size:
            return cache_entry.mask.to(device=device)
        if current_cache_entry_prompt_size >= maxed_size:
            maxed_size = current_cache_entry_prompt_size
            best_match = cache_entry
    if best_match is None:
        # no cached prompt is a strict subset of the current one
        return None
    return best_match.mask.to(device=device)


def is_prompt_strict_subset(
    assumed_sub_set_prompt: Tuple[str, List[dict]],
    assumed_super_set_prompt: Tuple[str, List[dict]],
) -> bool:
    if assumed_sub_set_prompt[0] == assumed_super_set_prompt[0]:
        return False
    super_set_prompt_copy = copy(assumed_super_set_prompt[1])
    for sub_set_prompt_element in assumed_sub_set_prompt[1]:
        found_match = False
        for super_set_prompt_element in super_set_prompt_copy:
            boxes_matching = (
                sub_set_prompt_element["box"] == super_set_prompt_element["box"]
            )
            if not boxes_matching:
                continue
            sub_set_prompt_element_points = {
                get_hashable_point(point=point)
                for point in sub_set_prompt_element.get("points", [])
            }
            super_set_prompt_element_points = {
                get_hashable_point(point=point)
                for point in super_set_prompt_element.get("points", [])
            }
            if sub_set_prompt_element_points <= super_set_prompt_element_points:
                super_set_prompt_copy.remove(super_set_prompt_element)
                found_match = True
                break
        if not found_match:
            return False
    # every prompt in subset has a matching super prompt
    return True


def get_hashable_point(point: dict) -> str:
    return json.dumps(point, sort_keys=True, separators=(",", ":"))
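
For orientation, a minimal usage sketch of the SAM2Torch interface added in this file. The import path is inferred from the package file list above; the model package directory, image, and prompt values are hypothetical placeholders (in practice, model packages are typically resolved through the library's auto-loaders rather than a hard-coded path).

import numpy as np

from inference_models.models.sam2.sam2_torch import SAM2Torch

# Hypothetical local model package directory containing model.pt and sam_configuration.json.
model = SAM2Torch.from_pretrained("./sam2.1_hiera_s_package")

image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder HWC uint8 image
predictions = model.segment_images(
    images=image,
    point_coordinates=np.array([[320, 240]]),  # single (x, y) prompt point
    point_labels=np.array([1]),  # 1 marks the point as positive
)
print(predictions[0].masks.shape, predictions[0].scores)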