inference_models-0.18.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0
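To cross-check a manifest like the one above against a locally downloaded artifact, the member names can be read directly from the wheel archive with the Python standard library. A minimal sketch follows; the local wheel filename is an assumption matching this release.

# Sketch: list the files packaged in the wheel and compare with the manifest above.
# The wheel path is a placeholder for a locally downloaded copy of this release.
import zipfile

wheel_path = "inference_models-0.18.3-py3-none-any.whl"
with zipfile.ZipFile(wheel_path) as wheel:
    for member in sorted(wheel.namelist()):
        print(member)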
inference_models/models/sam/sam_torch.py
@@ -0,0 +1,675 @@
import hashlib
from typing import Dict, Generator, List, Optional, Tuple, TypeVar, Union

import numpy as np
import torch
from segment_anything import sam_model_registry
from segment_anything.modeling import Sam
from segment_anything.utils.transforms import ResizeLongestSide

from inference_models import ColorFormat
from inference_models.configuration import DEFAULT_DEVICE
from inference_models.errors import CorruptedModelPackageError, ModelInputError
from inference_models.models.common.model_packages import get_model_package_contents
from inference_models.models.sam.cache import (
    SamImageEmbeddingsCache,
    SamImageEmbeddingsCacheNullObject,
    SamLowResolutionMasksCache,
    SamLowResolutionMasksCacheNullObject,
)
from inference_models.models.sam.entities import SAMImageEmbeddings, SAMPrediction
from inference_models.utils.file_system import read_json

T = TypeVar("T")

MAX_SAM_BATCH_SIZE = 8

ArrayOrTensor = Union[np.ndarray, torch.Tensor]


class SAMTorch:

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        max_batch_size: int = MAX_SAM_BATCH_SIZE,
        sam_image_embeddings_cache: Optional[SamImageEmbeddingsCache] = None,
        sam_low_resolution_masks_cache: Optional[SamLowResolutionMasksCache] = None,
        **kwargs,
    ) -> "SAMTorch":
        if sam_image_embeddings_cache is None:
            sam_image_embeddings_cache = SamImageEmbeddingsCacheNullObject()
        if sam_low_resolution_masks_cache is None:
            sam_low_resolution_masks_cache = SamLowResolutionMasksCacheNullObject()
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=[
                "model.pth",
                "sam_configuration.json",
            ],
        )
        try:
            version = decode_sam_version(
                config_path=model_package_content["sam_configuration.json"]
            )
        except Exception as error:
            raise CorruptedModelPackageError(
                message="Could not decode SAM model version. If you see this error running inference locally, "
                "verify the contents of model package. If you see the error running on Roboflow platform - "
                "contact us to get help.",
                help_url="https://todo",
            ) from error
        try:
            sam_model = sam_model_registry[version](
                checkpoint=model_package_content["model.pth"]
            ).to(device)
        except Exception as error:
            raise CorruptedModelPackageError(
                message=f"Could not initialize SAM model - cause: {error}. If you see this error running "
                f"locally - verify installation of inference and contents of model package. If you use "
                f"Roboflow platform, contact us to get help.",
                help_url="https://todo",
            ) from error
        transform = ResizeLongestSide(sam_model.image_encoder.img_size)
        return cls(
            model=sam_model,
            transform=transform,
            device=device,
            max_batch_size=max_batch_size,
            sam_image_embeddings_cache=sam_image_embeddings_cache,
            sam_low_resolution_masks_cache=sam_low_resolution_masks_cache,
        )

    def __init__(
        self,
        model: Sam,
        transform: ResizeLongestSide,
        device: torch.device,
        max_batch_size: int,
        sam_image_embeddings_cache: SamImageEmbeddingsCache,
        sam_low_resolution_masks_cache: SamLowResolutionMasksCache,
    ):
        self._model = model
        self._transform = transform
        self._device = device
        self._max_batch_size = max_batch_size
        self._sam_image_embeddings_cache = sam_image_embeddings_cache
        self._sam_low_resolution_masks_cache = sam_low_resolution_masks_cache

    def embed_images(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        use_embeddings_cache: bool = True,
        **kwargs,
    ) -> List[SAMImageEmbeddings]:
        model_input_images, image_hashes, original_image_sizes = (
            self.pre_process_images(
                images=images,
                input_color_format=input_color_format,
                **kwargs,
            )
        )
        embeddings_from_cache: Dict[int, SAMImageEmbeddings] = {}
        images_to_compute = []
        for idx, (image, image_hash) in enumerate(
            zip(model_input_images, image_hashes)
        ):
            cache_content = None
            if use_embeddings_cache:
                cache_content = self._sam_image_embeddings_cache.retrieve_embeddings(
                    key=image_hash
                )
            if cache_content is not None:
                cache_content = cache_content.to(device=self._device)
                embeddings_from_cache[idx] = cache_content
            else:
                images_to_compute.append(image)
        if len(images_to_compute) > 0:
            images_to_compute = torch.stack(images_to_compute, dim=0)
            computed_embeddings = self.forward_image_embeddings(
                model_input_images=images_to_compute,
            )
            computed_embeddings_idx = 0
            result_embeddings = []
            for i in range(len(model_input_images)):
                if i in embeddings_from_cache:
                    result_embeddings.append(embeddings_from_cache[i].embeddings)
                else:
                    result_embeddings.append(
                        computed_embeddings[computed_embeddings_idx]
                    )
                    computed_embeddings_idx += 1
        else:
            result_embeddings = [
                embeddings_from_cache[i].embeddings
                for i in range(len(model_input_images))
            ]
        results = []
        for image_hash, image_size, image_embeddings in zip(
            image_hashes, original_image_sizes, result_embeddings
        ):
            result = SAMImageEmbeddings(
                image_hash=image_hash,
                image_size_hw=image_size,
                embeddings=image_embeddings,
            )
            results.append(result)
            if use_embeddings_cache:
                self._sam_image_embeddings_cache.save_embeddings(
                    key=image_hash, embeddings=result
                )
        return results

    def pre_process_images(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, List[str], List[Tuple[int, int]]]:
        if isinstance(images, torch.Tensor):
            images = images.to(device=self._device)
            if images.device.type == "cuda":
                images = images.float()
            if len(images.shape) == 4:
                image_hashes = [compute_image_hash(image=image) for image in images]
                if input_color_format == "bgr":
                    # BGR -> RGB channel re-ordering for NCHW tensors
                    images = images[:, [2, 1, 0], :, :].contiguous()
                original_image_sizes = [tuple(images.shape[2:4])] * images.shape[0]
                model_input_images = self._transform.apply_image_torch(image=images)
            else:
                image_hashes = [compute_image_hash(image=images)]
                if input_color_format == "bgr":
                    # BGR -> RGB channel re-ordering for CHW tensors
                    images = images[[2, 1, 0], :, :].contiguous()
                original_image_sizes = [tuple(images.shape[1:3])]
                model_input_images = self._transform.apply_image_torch(
                    image=images.unsqueeze(dim=0)
                )
        else:
            if isinstance(images, list):
                image_hashes = [compute_image_hash(image=image) for image in images]
                original_image_sizes = []
                model_input_images = []
                for image in images:
                    if isinstance(image, np.ndarray):
                        original_image_sizes.append(image.shape[:2])
                        if input_color_format in {None, "bgr"}:
                            image = np.ascontiguousarray(image[:, :, ::-1])
                        input_image = self._transform.apply_image(image=image)
                        input_image = (
                            torch.as_tensor(input_image, device=self._device)
                            .permute(2, 0, 1)
                            .contiguous()
                        )
                        model_input_images.append(input_image)
                    else:
                        original_image_sizes.append(tuple(image.shape[1:3]))
                        image = image.to(self._device)
                        if image.device.type == "cuda":
                            image = image.float()
                        if input_color_format == "bgr":
                            image = image[[2, 1, 0], :, :].contiguous()
                        input_image = self._transform.apply_image_torch(
                            image=image.unsqueeze(dim=0)
                        )[0]
                        model_input_images.append(input_image)
                model_input_images = torch.stack(model_input_images, dim=0)
            else:
                image_hashes = [compute_image_hash(image=images)]
                original_image_sizes = [images.shape[:2]]
                if input_color_format in {None, "bgr"}:
                    images = np.ascontiguousarray(images[:, :, ::-1])
                model_input_images = self._transform.apply_image(image=images)
                model_input_images = (
                    torch.as_tensor(model_input_images, device=self._device)
                    .permute(2, 0, 1)
                    .contiguous()[None, :, :, :]
                )
        return model_input_images, image_hashes, original_image_sizes

    @torch.inference_mode()
    def forward_image_embeddings(
        self, model_input_images: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        result_embeddings = []
        for i in range(0, model_input_images.shape[0], self._max_batch_size):
            input_images_batch = model_input_images[
                i : i + self._max_batch_size
            ].contiguous()
            pre_processed_images_batch = self._model.preprocess(input_images_batch)
            batch_embeddings = self._model.image_encoder(pre_processed_images_batch).to(
                device=self._device
            )
            result_embeddings.append(batch_embeddings)
        return torch.cat(result_embeddings, dim=0)

    def segment_images(
        self,
        images: Optional[
            Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]]
        ] = None,
        embeddings: Optional[
            Union[List[SAMImageEmbeddings], SAMImageEmbeddings]
        ] = None,
        point_coordinates: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None,
        point_labels: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None,
        boxes: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None,
        mask_input: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None,
        multi_mask_output: bool = True,
        return_logits: bool = False,
        input_color_format: Optional[ColorFormat] = None,
        mask_threshold: Optional[float] = None,
        enforce_mask_input: bool = False,
        use_mask_input_cache: bool = True,
        use_embeddings_cache: bool = True,
        **kwargs,
    ) -> List[SAMPrediction]:
        if images is None and embeddings is None:
            raise ModelInputError(
                message="Attempted to use SAM model segment_images(...) method not providing valid input - "
                "neither `images` nor `embeddings` parameter is given. If you run inference locally, "
                "verify your integration making sure that the model interface is used correctly. Running "
                "on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
        if images is not None:
            embeddings = self.embed_images(
                images=images,
                input_color_format=input_color_format,
                use_embeddings_cache=use_embeddings_cache,
                **kwargs,
            )
        else:
            embeddings = maybe_wrap_in_list(value=embeddings)
        embeddings_tensors = [e.embeddings.to(self._device) for e in embeddings]
        image_hashes = [e.image_hash for e in embeddings]
        original_image_sizes = [e.image_size_hw for e in embeddings]
        point_coordinates = maybe_wrap_in_list(value=point_coordinates)
        point_labels = maybe_wrap_in_list(value=point_labels)
        boxes = maybe_wrap_in_list(value=boxes)
        mask_input = maybe_wrap_in_list(value=mask_input)
        masks_from_the_cache = [
            (
                self._sam_low_resolution_masks_cache.retrieve_mask(key=image_hash)
                if use_mask_input_cache
                else None
            )
            for image_hash in image_hashes
        ]
        if enforce_mask_input and mask_input is None:
            if not all(e is not None for e in masks_from_the_cache):
                raise ModelInputError(
                    message="Attempted to use SAM model segment_images(...) method enforcing the presence of "
                    "low-resolution mask input and not providing the mask explicitly (causing fallback to "
                    "SAM cache lookup which failed for at least one image) - this problem may be temporary, "
                    "but may also be a result of bug or invalid integration. If you run inference locally, "
                    "verify your integration making sure that the model interface is used correctly. Running "
                    "on Roboflow platform - contact us to get help.",
                    help_url="https://todo",
                )
            mask_input = [mask.to(self._device) for mask in masks_from_the_cache]
        point_coordinates, point_labels, boxes, mask_input = equalize_batch_size(
            embeddings_batch_size=len(embeddings),
            point_coordinates=point_coordinates,
            point_labels=point_labels,
            boxes=boxes,
            mask_input=mask_input,
        )
        point_coordinates, point_labels, boxes, mask_input = pre_process_prompts(
            point_coordinates=point_coordinates,
            point_labels=point_labels,
            boxes=boxes,
            mask_input=mask_input,
            device=self._device,
            transform=self._transform,
            original_image_sizes=original_image_sizes,
        )
        predictions = []
        for (
            image_embedding,
            image_hash,
            image_size,
            image_point_coordinates,
            image_point_labels,
            image_boxes,
            image_mask_input,
        ) in generate_model_inputs(
            embeddings=embeddings_tensors,
            image_hashes=image_hashes,
            original_image_sizes=original_image_sizes,
            point_coordinates=point_coordinates,
            point_labels=point_labels,
            boxes=boxes,
            mask_input=mask_input,
        ):
            prediction = predict_for_single_image(
                model=self._model,
                transform=self._transform,
                embeddings=image_embedding,
                original_image_size=image_size,
                point_coordinates=image_point_coordinates,
                point_labels=image_point_labels,
                boxes=image_boxes,
                mask_input=image_mask_input,
                multi_mask_output=multi_mask_output,
                return_logits=return_logits,
                mask_threshold=mask_threshold,
            )
            if use_mask_input_cache and len(prediction[0].shape) == 3:
                max_score_id = torch.argmax(prediction[1]).item()
                self._sam_low_resolution_masks_cache.save_mask(
                    key=image_hash, mask=prediction[2][max_score_id].unsqueeze(dim=0)
                )
            parsed_prediction = SAMPrediction(
                masks=prediction[0],
                scores=prediction[1],
                logits=prediction[2],
            )
            predictions.append(parsed_prediction)
        return predictions


def decode_sam_version(config_path: str) -> str:
    config = read_json(path=config_path)
    version = config["version"]
    if not isinstance(version, str):
        raise ValueError("Could not decode SAM model version")
    return version


def compute_image_hash(image: Union[torch.Tensor, np.ndarray]) -> str:
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()
    return hash_function(value=image.tobytes())


def hash_function(value: Union[str, bytes]) -> str:
    return hashlib.sha1(value).hexdigest()


def maybe_wrap_in_list(value: Optional[Union[T, List[T]]]) -> Optional[List[T]]:
    if value is None:
        return None
    if isinstance(value, list):
        return value
    return [value]


def equalize_batch_size(
    embeddings_batch_size: int,
    point_coordinates: Optional[List[ArrayOrTensor]],
    point_labels: Optional[List[ArrayOrTensor]],
    boxes: Optional[List[ArrayOrTensor]],
    mask_input: Optional[List[ArrayOrTensor]],
) -> Tuple[
    Optional[List[ArrayOrTensor]],
    Optional[List[ArrayOrTensor]],
    Optional[List[ArrayOrTensor]],
    Optional[List[ArrayOrTensor]],
]:
    if (
        point_coordinates is not None
        and len(point_coordinates) != embeddings_batch_size
    ):
        if len(point_coordinates) != 1:
            raise ModelInputError(
                message="When using SAM model, parameter `point_coordinates` was provided with invalid "
                f"value indicating different input batch size ({len(point_coordinates)}) than provided "
                f"images / embeddings ({embeddings_batch_size}). If you run inference locally, verify your "
                "integration making sure that the model interface is used correctly. "
                "Running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
        point_coordinates = point_coordinates * embeddings_batch_size
    if point_labels is not None and len(point_labels) != embeddings_batch_size:
        if len(point_labels) != 1:
            raise ModelInputError(
                message="When using SAM model, parameter `point_labels` was provided with invalid "
                f"value indicating different input batch size ({len(point_labels)}) than provided "
                f"images / embeddings ({embeddings_batch_size}). If you run inference locally, verify your "
                "integration making sure that the model interface is used correctly. "
                "Running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
        point_labels = point_labels * embeddings_batch_size
    if boxes is not None and len(boxes) != embeddings_batch_size:
        if len(boxes) != 1:
            raise ModelInputError(
                message="When using SAM model, parameter `boxes` was provided with invalid "
                f"value indicating different input batch size ({len(boxes)}) than provided "
                f"images / embeddings ({embeddings_batch_size}). If you run inference locally, verify your "
                "integration making sure that the model interface is used correctly. "
                "Running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
        boxes = boxes * embeddings_batch_size
    if mask_input is not None and len(mask_input) != embeddings_batch_size:
        if len(mask_input) != 1:
            raise ModelInputError(
                message="When using SAM model, parameter `mask_input` was provided with invalid "
                f"value indicating different input batch size ({len(mask_input)}) than provided "
                f"images / embeddings ({embeddings_batch_size}). If you run inference locally, verify your "
                "integration making sure that the model interface is used correctly. "
                "Running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
        mask_input = mask_input * embeddings_batch_size
    prompts_first_dimension_characteristics = set()
    if point_coordinates is not None:
        point_coordinates_characteristic = "-".join(
            [str(p.shape[0]) for p in point_coordinates]
        )
        prompts_first_dimension_characteristics.add(point_coordinates_characteristic)
    if point_labels is not None:
        point_labels_characteristic = "-".join([str(l.shape[0]) for l in point_labels])
        prompts_first_dimension_characteristics.add(point_labels_characteristic)
    if boxes is not None:
        boxes_characteristic = "-".join(
            [str(b.shape[0]) if len(b.shape) > 1 else "1" for b in boxes]
        )
        prompts_first_dimension_characteristics.add(boxes_characteristic)
    if len(prompts_first_dimension_characteristics) > 1:
        raise ModelInputError(
            message="When using SAM model, in scenario when combination of `point_coordinates` and `point_labels` and "
            "`boxes` provided, the model expects identical number of elements for each prompt component. "
            "If you run inference locally, verify your integration making sure that the model interface is "
            "used correctly. Running on Roboflow platform - contact us to get help.",
            help_url="https://todo",
        )
    if mask_input is not None:
        mask_input = [i[None, :, :] if len(i.shape) == 2 else i for i in mask_input]
        if any(len(i.shape) != 3 or i.shape[0] != 1 for i in mask_input):
            raise ModelInputError(
                message="When using SAM model with `mask_input`, each mask must be 3D tensor of shape (1, H, W). "
                "If you run inference locally, verify your integration making sure that the model interface is "
                "used correctly. Running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
    if boxes is not None:
        batched_boxes_provided = False
        for box in boxes:
            if len(box.shape) > 1 and box.shape[0] > 1:
                batched_boxes_provided = True
        if batched_boxes_provided and any(
            e is not None for e in [point_coordinates, point_labels, mask_input]
        ):
            raise ModelInputError(
                message="When using SAM, providing batched boxes (multiple RoIs for single image) makes it impossible "
                "to use other components of the prompt - like `point_coordinates`, `point_labels` "
                "or `mask_input` - and such situation was detected. "
                "If you run inference locally, verify your integration making sure that the model interface is "
                "used correctly. Running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
    return point_coordinates, point_labels, boxes, mask_input


def maybe_broadcast_list(value: Optional[List[T]], n: int) -> Optional[List[T]]:
    if value is None:
        return None
    return value * n


def pre_process_prompts(
    point_coordinates: Optional[List[ArrayOrTensor]],
    point_labels: Optional[List[ArrayOrTensor]],
    boxes: Optional[List[ArrayOrTensor]],
    mask_input: Optional[List[ArrayOrTensor]],
    device: torch.device,
    transform: ResizeLongestSide,
    original_image_sizes: List[Tuple[int, int]],
) -> Tuple[
    Optional[List[torch.Tensor]],
    Optional[List[torch.Tensor]],
    Optional[List[torch.Tensor]],
    Optional[List[torch.Tensor]],
]:
    if point_labels is not None and point_coordinates is None:
        raise ModelInputError(
            message="When using SAM model, provided `point_labels` without `point_coordinates` which makes invalid "
            "input. If you run inference locally, verify your integration making sure that the model "
            "interface is used correctly. Running on Roboflow platform - contact us to get help.",
            help_url="https://todo",
        )
    if point_coordinates is not None:
        if point_labels is None:
            raise ModelInputError(
                message="When using SAM model, provided `point_coordinates` without `point_labels` which makes invalid "
                "input. If you run inference locally, verify your integration making sure that the model "
                "interface is used correctly. Running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
        point_coordinates = [
            (
                c.to(device)[None, :, :]
                if isinstance(c, torch.Tensor)
                else torch.from_numpy(c).to(device)[None, :, :]
            )
            for c in point_coordinates
        ]
        point_labels = [
            (
                l.to(device)[None, :]
                if isinstance(l, torch.Tensor)
                else torch.from_numpy(l).to(device)[None, :]
            )
            for l in point_labels
        ]
        point_coordinates = [
            transform.apply_coords_torch(point_coords, image_shape)
            for point_coords, image_shape in zip(
                point_coordinates, original_image_sizes
            )
        ]
    if boxes is not None:
        boxes = [
            (
                box.to(device)[None, :]
                if isinstance(box, torch.Tensor)
                else torch.from_numpy(box).to(device)[None, :]
            )
            for box in boxes
        ]
        boxes = [
            transform.apply_boxes_torch(box, image_shape)
            for box, image_shape in zip(boxes, original_image_sizes)
        ]
    if mask_input is not None:
        mask_input = [
            (
                mask.to(device)[None, :, :]
                if isinstance(mask, torch.Tensor)
                else torch.from_numpy(mask).to(device)[None, :, :]
            )
            for mask in mask_input
        ]
    return point_coordinates, point_labels, boxes, mask_input


def generate_model_inputs(
    embeddings: List[torch.Tensor],
    image_hashes: List[str],
    original_image_sizes: List[Tuple[int, int]],
    point_coordinates: Optional[List[torch.Tensor]],
    point_labels: Optional[List[torch.Tensor]],
    boxes: Optional[List[torch.Tensor]],
    mask_input: Optional[List[torch.Tensor]],
) -> Generator[
    Tuple[
        torch.Tensor,
        str,
        Tuple[int, int],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ],
    None,
    None,
]:
    if point_coordinates is None:
        point_coordinates = [None] * len(embeddings)
    if point_labels is None:
        point_labels = [None] * len(embeddings)
    if boxes is None:
        boxes = [None] * len(embeddings)
    if mask_input is None:
        mask_input = [None] * len(embeddings)
    for embedding, hash_value, image_size, coords, labels, box, mask in zip(
        embeddings,
        image_hashes,
        original_image_sizes,
        point_coordinates,
        point_labels,
        boxes,
        mask_input,
    ):
        yield embedding, hash_value, image_size, coords, labels, box, mask


@torch.inference_mode()
def predict_for_single_image(
    model: Sam,
    transform: ResizeLongestSide,
    embeddings: torch.Tensor,
    original_image_size: Tuple[int, int],
    point_coordinates: Optional[torch.Tensor],
    point_labels: Optional[torch.Tensor],
    boxes: Optional[torch.Tensor] = None,
    mask_input: Optional[torch.Tensor] = None,
    multi_mask_output: bool = True,
    return_logits: bool = False,
    mask_threshold: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    embeddings = embeddings.unsqueeze(dim=0)
    if point_coordinates is not None:
        points = (point_coordinates, point_labels)
    else:
        points = None
    sparse_embeddings, dense_embeddings = model.prompt_encoder(
        points=points,
        boxes=boxes,
        masks=mask_input,
    )
    low_res_masks, iou_predictions = model.mask_decoder(
        image_embeddings=embeddings,
        image_pe=model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=multi_mask_output,
    )
    model_input_size = transform.get_preprocess_shape(
        original_image_size[0], original_image_size[1], transform.target_length
    )
    masks = model.postprocess_masks(
        low_res_masks, model_input_size, original_image_size
    )
    if not return_logits:
        threshold = mask_threshold or model.mask_threshold
        masks = masks > threshold
    if masks.shape[0] == 1:
        return masks[0], iou_predictions[0], low_res_masks[0]
    else:
        return masks, iou_predictions, low_res_masks
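For orientation, below is a minimal usage sketch of the SAMTorch interface added in the file above. It is not part of the package: the model package directory and the prompt values are placeholders, and it assumes segment-anything and the other declared dependencies are installed.

# Sketch only - exercising SAMTorch as published above; the path and prompt values are placeholders.
import numpy as np
from inference_models.models.sam.sam_torch import SAMTorch

# Directory expected to contain model.pth and sam_configuration.json
model = SAMTorch.from_pretrained("/path/to/sam-model-package")

image = np.zeros((480, 640, 3), dtype=np.uint8)  # HWC image; BGR is assumed when no color format is given
predictions = model.segment_images(
    images=image,
    point_coordinates=np.array([[320.0, 240.0]], dtype=np.float32),  # one prompt point (x, y)
    point_labels=np.array([1], dtype=np.int32),  # 1 = foreground point
    multi_mask_output=True,
)
print(predictions[0].masks.shape, predictions[0].scores)  # SAMPrediction carries masks / scores / logits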