inference-models 0.18.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,291 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import os.path
|
|
3
|
+
import re
|
|
4
|
+
import urllib.parse
|
|
5
|
+
from typing import List, Optional, Union
|
|
6
|
+
|
|
7
|
+
import backoff
|
|
8
|
+
import cv2
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pybase64
|
|
11
|
+
import requests
|
|
12
|
+
import torch
|
|
13
|
+
from requests import Timeout
|
|
14
|
+
from tldextract import tldextract
|
|
15
|
+
from tldextract.tldextract import ExtractResult
|
|
16
|
+
|
|
17
|
+
from inference_models.configuration import (
|
|
18
|
+
API_CALLS_MAX_TRIES,
|
|
19
|
+
IDEMPOTENT_API_REQUEST_CODES_TO_RETRY,
|
|
20
|
+
)
|
|
21
|
+
from inference_models.errors import ModelInputError, ModelRuntimeError, RetryError
|
|
22
|
+
|
|
23
|
+
BASE64_DATA_TYPE_PATTERN = re.compile(r"^data:image\/[a-z]+;base64,")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class LazyImageWrapper:
    """Image holder that resolves a referenced image only when it is needed.

    An instance keeps an in-memory image (numpy array or torch tensor), an
    image reference (``str`` or ``bytes``), or both, together with the access
    policy flags that are forwarded to ``load_image_reference`` whenever the
    reference has to be resolved.
    """

    @classmethod
    def init(
        cls,
        image: Union[np.ndarray, torch.Tensor, str, bytes],
        allow_url_input: bool,
        allow_non_https_url: bool,
        allow_url_without_fqdn: bool,
        whitelisted_domains: Optional[List[str]],
        blacklisted_domains: Optional[List[str]],
        allow_local_storage_access: bool,
    ):
        """Build a wrapper from any supported image input.

        Arrays and tensors are stored as the in-memory image; every other
        value is stored as a reference that is loaded lazily.
        """
        already_decoded = isinstance(image, (torch.Tensor, np.ndarray))
        return cls(
            allow_url_input=allow_url_input,
            allow_non_https_url=allow_non_https_url,
            allow_url_without_fqdn=allow_url_without_fqdn,
            whitelisted_domains=whitelisted_domains,
            blacklisted_domains=blacklisted_domains,
            allow_local_storage_access=allow_local_storage_access,
            image_in_memory=image if already_decoded else None,
            image_reference=None if already_decoded else image,
        )

    def __init__(
        self,
        allow_url_input: bool,
        allow_non_https_url: bool,
        allow_url_without_fqdn: bool,
        whitelisted_domains: Optional[List[str]],
        blacklisted_domains: Optional[List[str]],
        allow_local_storage_access: bool,
        image_in_memory: Optional[Union[np.ndarray, torch.Tensor]] = None,
        image_reference: Optional[Union[str, bytes]] = None,
        image_hash: Optional[str] = None,
    ):
        """Store the image sources and the policy used to resolve references.

        Raises:
            ModelRuntimeError: when neither ``image_in_memory`` nor
                ``image_reference`` is provided.
        """
        nothing_to_wrap = image_in_memory is None and image_reference is None
        if nothing_to_wrap:
            raise ModelRuntimeError(
                message="Attempted to use OWLv2 image lazy loading not providing neither image "
                "location nor image instance - this is invalid input. Contact Roboflow to get help.",
                help_url="https://todo",
            )
        # Image sources and the optional pre-computed hash.
        self._image_in_memory = image_in_memory
        self._image_reference = image_reference
        self._image_hash = image_hash
        # Policy flags, handed over verbatim to `load_image_reference`.
        self._allow_url_input = allow_url_input
        self._allow_non_https_url = allow_non_https_url
        self._allow_url_without_fqdn = allow_url_without_fqdn
        self._whitelisted_domains = whitelisted_domains
        self._blacklisted_domains = blacklisted_domains
        self._allow_local_storage_access = allow_local_storage_access

    def as_numpy(self) -> np.ndarray:
        """Return the image as a numpy array, caching the result on the instance.

        A stored tensor is converted with ``.cpu().numpy()``; when nothing is
        held in memory the reference is resolved via ``load_image_reference``.
        """
        cached = self._image_in_memory
        if cached is None:
            cached = load_image_reference(
                image_reference=self._image_reference,
                allow_url_input=self._allow_url_input,
                allow_non_https_url=self._allow_non_https_url,
                allow_url_without_fqdn=self._allow_url_without_fqdn,
                whitelisted_domains=self._whitelisted_domains,
                blacklisted_domains=self._blacklisted_domains,
                allow_local_storage_access=self._allow_local_storage_access,
            )
            self._image_in_memory = cached
            return cached
        if isinstance(cached, torch.Tensor):
            cached = cached.cpu().numpy()
            self._image_in_memory = cached
        return cached

    def get_hash(self) -> str:
        """Return the image hash, computing and memoizing it on first use.

        The reference itself is hashed when one is available; otherwise the
        raw bytes of the numpy representation are hashed.
        """
        if self._image_hash is None:
            if self._image_reference is None:
                payload = self.as_numpy().tobytes()
            else:
                payload = self._image_reference
            self._image_hash = hash_function(value=payload)
        return self._image_hash

    def unload_image(self) -> None:
        """Drop the in-memory image, but only if it can be re-loaded from a reference."""
        if self._image_reference is None or self._image_in_memory is None:
            return
        self._image_in_memory = None
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def load_image_reference(
    image_reference: Union[str, bytes],
    allow_url_input: bool,
    allow_non_https_url: bool,
    allow_url_without_fqdn: bool,
    whitelisted_domains: Optional[List[str]],
    blacklisted_domains: Optional[List[str]],
    allow_local_storage_access: bool,
) -> np.ndarray:
    """Resolve an image reference into a decoded image.

    Dispatch order: raw ``bytes`` are decoded directly; ``http(s)://``
    strings go through ``decode_image_from_url``; an existing local file is
    read with ``cv2.imread`` when local storage access is allowed; anything
    else is treated as a base64 payload.
    """
    if isinstance(image_reference, bytes):
        return decode_image_from_bytes(image_bytes=image_reference)
    if is_url(reference=image_reference):
        return decode_image_from_url(
            url=image_reference,
            allow_url_input=allow_url_input,
            allow_non_https_url=allow_non_https_url,
            allow_url_without_fqdn=allow_url_without_fqdn,
            whitelisted_domains=whitelisted_domains,
            blacklisted_domains=blacklisted_domains,
        )
    # The file system is only probed when local access is permitted.
    points_to_local_file = allow_local_storage_access and os.path.isfile(
        image_reference
    )
    if points_to_local_file:
        return cv2.imread(image_reference)
    return decode_image_from_base64(value=image_reference)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def decode_image_from_url(
    url: str,
    allow_url_input: bool,
    allow_non_https_url: bool,
    allow_url_without_fqdn: bool,
    whitelisted_domains: Optional[List[str]],
    blacklisted_domains: Optional[List[str]],
):
    """Validate an image URL against the access policy, then download and decode it.

    Checks run in a fixed order - URL input enabled, URL parseable, scheme
    allowed, FQDN requirement, whitelist, blacklist - and the first failing
    check raises ``ModelInputError``. Only when all of them pass is the
    content fetched with ``_get_from_url`` and decoded.
    """
    if not allow_url_input:
        raise ModelInputError(
            message="Providing images via URL is not supported in this configuration of `inference-models`.",
            help_url="https://todo",
        )
    try:
        url_parts = urllib.parse.urlparse(url)
    except ValueError as error:
        raise ModelInputError(
            message="Provided image URL is invalid.", help_url="https://todo"
        ) from error
    if not allow_non_https_url and url_parts.scheme != "https":
        raise ModelInputError(
            message="Providing images via non https:// URL is not supported in this configuration of `inference-models`.",
            help_url="https://todo",
        )
    # Running the extractor on netloc gets rid of potential ports and parses FQDNs.
    # NOTE(review): empty `suffix_list_urls` presumably prevents fetching the
    # public suffix list over the network - confirm against tldextract docs.
    extractor = tldextract.TLDExtract(suffix_list_urls=())
    network_location = extractor(url_parts.netloc)
    _ensure_resource_fqdn_allowed(
        fqdn=network_location.fqdn,
        allow_url_without_fqdn=allow_url_without_fqdn,
    )
    # Concatenated chunks give a comparable destination even when there is no
    # FQDN but only an address, which enables white-/black-list verification.
    destination = _concatenate_chunks_of_network_location(
        extraction_result=network_location
    )
    _ensure_location_matches_destination_whitelist(
        destination=destination,
        whitelisted_domains=whitelisted_domains,
    )
    _ensure_location_matches_destination_blacklist(
        destination=destination,
        blacklisted_domains=blacklisted_domains,
    )
    downloaded = _get_from_url(url=url)
    return decode_image_from_bytes(image_bytes=downloaded)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def decode_image_from_base64(value: str) -> np.ndarray:
    """Decode a base64-encoded image, with or without a data-URI header.

    A leading ``data:image/<type>;base64,`` prefix (as matched by
    ``BASE64_DATA_TYPE_PATTERN``) is stripped before strict base64 decoding.

    Args:
        value: Base64 payload, optionally prefixed with a data-URI header.

    Returns:
        Image decoded by ``decode_image_from_bytes``.

    Raises:
        ModelInputError: when any step of the decoding fails. The message
            includes only the first 16 characters of the reference.
    """
    try:
        value = BASE64_DATA_TYPE_PATTERN.sub("", value)
        decoded = pybase64.b64decode(value, validate=True)
        return decode_image_from_bytes(image_bytes=decoded)
    except Exception as error:
        # Only a short prefix is reported to keep the (possibly huge) payload
        # out of the error message.
        value_prefix = value[:16]
        raise ModelInputError(
            message=f"Could not decode base64 image from reference {value_prefix}.",
            help_url="https://todo",
        ) from error
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def decode_image_from_bytes(image_bytes: bytes) -> np.ndarray:
    """Decode an encoded image held in memory.

    Args:
        image_bytes: Encoded image content (file content, not raw pixels).

    Returns:
        Image decoded with ``cv2.imdecode`` using ``cv2.IMREAD_COLOR``.

    Raises:
        ModelInputError: when OpenCV cannot decode the payload. Without this
            check ``cv2.imdecode`` would hand back ``None`` and the failure
            would surface later, far away from its cause.
    """
    byte_array = np.frombuffer(image_bytes, dtype=np.uint8)
    image = cv2.imdecode(byte_array, cv2.IMREAD_COLOR)
    if image is None:
        raise ModelInputError(
            message="Could not decode image from the provided bytes.",
            help_url="https://todo",
        )
    return image
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def is_url(reference: str) -> bool:
    """Tell whether ``reference`` begins with an ``http://`` or ``https://`` prefix."""
    return reference.startswith(("http://", "https://"))
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def _ensure_resource_fqdn_allowed(fqdn: str, allow_url_without_fqdn: bool) -> None:
    """Reject URLs lacking an FQDN unless the configuration permits them.

    Raises:
        ModelInputError: when ``fqdn`` is empty and ``allow_url_without_fqdn``
            is false.
    """
    if fqdn or allow_url_without_fqdn:
        return None
    raise ModelInputError(
        message="Providing images via URL without FQDN is not supported in this configuration of `inference-models`.",
        help_url="https://todo",
    )
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def _concatenate_chunks_of_network_location(extraction_result: ExtractResult) -> str:
|
|
221
|
+
chunks = [
|
|
222
|
+
extraction_result.subdomain,
|
|
223
|
+
extraction_result.domain,
|
|
224
|
+
extraction_result.suffix,
|
|
225
|
+
]
|
|
226
|
+
non_empty_chunks = [chunk for chunk in chunks if chunk]
|
|
227
|
+
result = ".".join(non_empty_chunks)
|
|
228
|
+
if result.startswith("[") and result.endswith("]"):
|
|
229
|
+
# dropping brackets for IPv6
|
|
230
|
+
return result[1:-1]
|
|
231
|
+
return result
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def _ensure_location_matches_destination_whitelist(
    destination: str, whitelisted_domains: Optional[List[str]]
) -> None:
    """Require ``destination`` to be on the whitelist, when a whitelist is set.

    ``None`` disables the check; otherwise an exact membership test is used.

    Raises:
        ModelInputError: when a whitelist is given and ``destination`` is not
            one of its entries.
    """
    whitelist_enabled = whitelisted_domains is not None
    if whitelist_enabled and destination not in whitelisted_domains:
        raise ModelInputError(
            message="It is not allowed to reach image URL - prohibited by whitelisted destinations",
            help_url="https://todo",
        )
    return None
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _ensure_location_matches_destination_blacklist(
    destination: str,
    blacklisted_domains: Optional[List[str]],
) -> None:
    """Reject ``destination`` when it is on the blacklist, if a blacklist is set.

    ``None`` disables the check; otherwise an exact membership test is used.

    Raises:
        ModelInputError: when a blacklist is given and ``destination`` is one
            of its entries.
    """
    blacklist_enabled = blacklisted_domains is not None
    if blacklist_enabled and destination in blacklisted_domains:
        raise ModelInputError(
            message="It is not allowed to reach image URL - prohibited by blacklisted destinations.",
            help_url="https://todo",
        )
    return None
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
@backoff.on_exception(
    backoff.constant,
    exception=RetryError,
    max_tries=API_CALLS_MAX_TRIES,
    interval=1,
)
def _get_from_url(url: str, timeout: int = 5) -> bytes:
    """Download the content available under ``url``.

    The call is retried by the ``backoff`` decorator whenever ``RetryError``
    is raised, up to ``API_CALLS_MAX_TRIES`` tries at a constant interval.

    Args:
        url: Location of the resource to download.
        timeout: Timeout passed to ``requests.get``.

    Returns:
        Raw response body.

    Raises:
        RetryError: when the status code is listed in
            ``IDEMPOTENT_API_REQUEST_CODES_TO_RETRY`` or a connectivity
            problem occurs (after the retries are exhausted).
        requests.HTTPError: via ``raise_for_status()`` for other error codes.
    """
    # NOTE(review): `requests.get` follows redirects by default, so the final
    # host may differ from the one validated by the caller's white-/black-list
    # checks - confirm this is acceptable.
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code in IDEMPOTENT_API_REQUEST_CODES_TO_RETRY:
                raise RetryError(
                    message=f"File hosting returned {response.status_code}",
                    help_url="https://todo",
                )
            response.raise_for_status()
            return response.content
    except (ConnectionError, Timeout, requests.exceptions.ConnectionError) as error:
        # Chain explicitly so the original network failure stays attached as
        # the cause, matching the `from error` style used across this module.
        raise RetryError(
            message="Connectivity error",
            help_url="https://todo",
        ) from error
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def compute_image_hash(image: Union[torch.Tensor, np.ndarray]) -> str:
    """Hash the raw bytes of an image given as a numpy array or torch tensor.

    Tensors are first converted with ``.cpu().numpy()``.
    """
    as_array = image.cpu().numpy() if isinstance(image, torch.Tensor) else image
    return hash_function(value=as_array.tobytes())
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def hash_function(value: Union[str, bytes]) -> str:
    """Return the hexadecimal SHA-1 digest of ``value``.

    Args:
        value: Content to hash. ``str`` input is encoded as UTF-8 first,
            because ``hashlib`` accepts only bytes-like objects and would
            otherwise raise ``TypeError`` for string image references.

    Returns:
        Digest as a 40-character hexadecimal string.
    """
    if isinstance(value, str):
        value = value.encode("utf-8")
    return hashlib.sha1(value).hexdigest()
|
|
File without changes
|
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import List, Optional, Tuple, Union
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
from peft import PeftModel
|
|
7
|
+
from transformers import (
|
|
8
|
+
AutoProcessor,
|
|
9
|
+
BitsAndBytesConfig,
|
|
10
|
+
PaliGemmaForConditionalGeneration,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from inference_models.configuration import DEFAULT_DEVICE
|
|
14
|
+
from inference_models.entities import ColorFormat
|
|
15
|
+
from inference_models.models.common.roboflow.model_packages import (
|
|
16
|
+
InferenceConfig,
|
|
17
|
+
ResizeMode,
|
|
18
|
+
parse_inference_config,
|
|
19
|
+
)
|
|
20
|
+
from inference_models.models.common.roboflow.pre_processing import (
|
|
21
|
+
pre_process_network_input,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class PaliGemmaHF:
|
|
26
|
+
|
|
27
|
+
@classmethod
|
|
28
|
+
def from_pretrained(
|
|
29
|
+
cls,
|
|
30
|
+
model_name_or_path: str,
|
|
31
|
+
device: torch.device = DEFAULT_DEVICE,
|
|
32
|
+
trust_remote_code: bool = False,
|
|
33
|
+
local_files_only: bool = True,
|
|
34
|
+
quantization_config: Optional[BitsAndBytesConfig] = None,
|
|
35
|
+
disable_quantization: bool = False,
|
|
36
|
+
**kwargs,
|
|
37
|
+
) -> "PaliGemmaHF":
|
|
38
|
+
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
|
|
39
|
+
inference_config_path = os.path.join(
|
|
40
|
+
model_name_or_path, "inference_config.json"
|
|
41
|
+
)
|
|
42
|
+
inference_config = None
|
|
43
|
+
if os.path.exists(inference_config_path):
|
|
44
|
+
inference_config = parse_inference_config(
|
|
45
|
+
config_path=inference_config_path,
|
|
46
|
+
allowed_resize_modes={
|
|
47
|
+
ResizeMode.STRETCH_TO,
|
|
48
|
+
ResizeMode.LETTERBOX,
|
|
49
|
+
ResizeMode.CENTER_CROP,
|
|
50
|
+
ResizeMode.LETTERBOX_REFLECT_EDGES,
|
|
51
|
+
},
|
|
52
|
+
)
|
|
53
|
+
if (
|
|
54
|
+
quantization_config is None
|
|
55
|
+
and device.type == "cuda"
|
|
56
|
+
and not disable_quantization
|
|
57
|
+
):
|
|
58
|
+
quantization_config = BitsAndBytesConfig(
|
|
59
|
+
load_in_4bit=True,
|
|
60
|
+
bnb_4bit_quant_type="nf4",
|
|
61
|
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
|
62
|
+
)
|
|
63
|
+
adapter_config_path = os.path.join(model_name_or_path, "adapter_config.json")
|
|
64
|
+
if os.path.exists(adapter_config_path):
|
|
65
|
+
base_model_path = os.path.join(model_name_or_path, "base")
|
|
66
|
+
model = PaliGemmaForConditionalGeneration.from_pretrained(
|
|
67
|
+
base_model_path,
|
|
68
|
+
dtype=torch_dtype,
|
|
69
|
+
trust_remote_code=trust_remote_code,
|
|
70
|
+
local_files_only=local_files_only,
|
|
71
|
+
quantization_config=quantization_config,
|
|
72
|
+
)
|
|
73
|
+
model = PeftModel.from_pretrained(model, model_name_or_path)
|
|
74
|
+
if quantization_config is None:
|
|
75
|
+
model.merge_and_unload()
|
|
76
|
+
model.to(device)
|
|
77
|
+
|
|
78
|
+
processor = AutoProcessor.from_pretrained(
|
|
79
|
+
base_model_path,
|
|
80
|
+
trust_remote_code=trust_remote_code,
|
|
81
|
+
local_files_only=local_files_only,
|
|
82
|
+
use_fast=True,
|
|
83
|
+
)
|
|
84
|
+
else:
|
|
85
|
+
model = PaliGemmaForConditionalGeneration.from_pretrained(
|
|
86
|
+
model_name_or_path,
|
|
87
|
+
dtype=torch_dtype,
|
|
88
|
+
device_map=device,
|
|
89
|
+
trust_remote_code=trust_remote_code,
|
|
90
|
+
local_files_only=local_files_only,
|
|
91
|
+
quantization_config=quantization_config,
|
|
92
|
+
).eval()
|
|
93
|
+
processor = AutoProcessor.from_pretrained(
|
|
94
|
+
model_name_or_path,
|
|
95
|
+
trust_remote_code=trust_remote_code,
|
|
96
|
+
local_files_only=local_files_only,
|
|
97
|
+
use_fast=True,
|
|
98
|
+
)
|
|
99
|
+
return cls(
|
|
100
|
+
model=model,
|
|
101
|
+
processor=processor,
|
|
102
|
+
inference_config=inference_config,
|
|
103
|
+
device=device,
|
|
104
|
+
torch_dtype=torch_dtype,
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
def __init__(
|
|
108
|
+
self,
|
|
109
|
+
model: PaliGemmaForConditionalGeneration,
|
|
110
|
+
processor: AutoProcessor,
|
|
111
|
+
inference_config: Optional[InferenceConfig],
|
|
112
|
+
device: torch.device,
|
|
113
|
+
torch_dtype: torch.dtype,
|
|
114
|
+
):
|
|
115
|
+
self._model = model
|
|
116
|
+
self._processor = processor
|
|
117
|
+
self._inference_config = inference_config
|
|
118
|
+
self._device = device
|
|
119
|
+
self._torch_dtype = torch_dtype
|
|
120
|
+
|
|
121
|
+
def prompt(
|
|
122
|
+
self,
|
|
123
|
+
images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
|
|
124
|
+
prompt: str,
|
|
125
|
+
input_color_format: Optional[ColorFormat] = None,
|
|
126
|
+
max_new_tokens: int = 400,
|
|
127
|
+
do_sample: bool = False,
|
|
128
|
+
skip_special_tokens: bool = True,
|
|
129
|
+
**kwargs,
|
|
130
|
+
) -> List[str]:
|
|
131
|
+
inputs = self.pre_process_generation(
|
|
132
|
+
images=images, prompt=prompt, input_color_format=input_color_format
|
|
133
|
+
)
|
|
134
|
+
generated_ids = self.generate(
|
|
135
|
+
inputs=inputs,
|
|
136
|
+
max_new_tokens=max_new_tokens,
|
|
137
|
+
do_sample=do_sample,
|
|
138
|
+
)
|
|
139
|
+
return self.post_process_generation(
|
|
140
|
+
generated_ids=generated_ids,
|
|
141
|
+
skip_special_tokens=skip_special_tokens,
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
def pre_process_generation(
|
|
145
|
+
self,
|
|
146
|
+
images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
|
|
147
|
+
prompt: str,
|
|
148
|
+
input_color_format: Optional[ColorFormat] = None,
|
|
149
|
+
image_size: Optional[Tuple[int, int]] = None,
|
|
150
|
+
**kwargs,
|
|
151
|
+
) -> dict:
|
|
152
|
+
def _to_tensor(image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
|
|
153
|
+
is_numpy = isinstance(image, np.ndarray)
|
|
154
|
+
if is_numpy:
|
|
155
|
+
tensor_image = torch.from_numpy(image.copy()).permute(2, 0, 1)
|
|
156
|
+
else:
|
|
157
|
+
tensor_image = image
|
|
158
|
+
if input_color_format == "bgr" or (is_numpy and input_color_format is None):
|
|
159
|
+
tensor_image = tensor_image[[2, 1, 0], :, :]
|
|
160
|
+
return tensor_image
|
|
161
|
+
|
|
162
|
+
if self._inference_config is None:
|
|
163
|
+
if isinstance(images, torch.Tensor) and images.ndim > 3:
|
|
164
|
+
image_list = [_to_tensor(img) for img in images]
|
|
165
|
+
elif not isinstance(images, list):
|
|
166
|
+
image_list = [_to_tensor(images)]
|
|
167
|
+
else:
|
|
168
|
+
image_list = [_to_tensor(img) for img in images]
|
|
169
|
+
else:
|
|
170
|
+
images = pre_process_network_input(
|
|
171
|
+
images=images,
|
|
172
|
+
image_pre_processing=self._inference_config.image_pre_processing,
|
|
173
|
+
network_input=self._inference_config.network_input,
|
|
174
|
+
target_device=self._device,
|
|
175
|
+
input_color_format=input_color_format,
|
|
176
|
+
image_size_wh=image_size,
|
|
177
|
+
)[0]
|
|
178
|
+
image_list = [e[0] for e in torch.split(images, 1, dim=0)]
|
|
179
|
+
num_images = len(image_list)
|
|
180
|
+
|
|
181
|
+
if isinstance(prompt, str) and num_images > 1:
|
|
182
|
+
prompt = [prompt] * num_images
|
|
183
|
+
return self._processor(text=prompt, images=image_list, return_tensors="pt").to(
|
|
184
|
+
self._device
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
def generate(
|
|
188
|
+
self,
|
|
189
|
+
inputs: dict,
|
|
190
|
+
max_new_tokens: int = 400,
|
|
191
|
+
do_sample: bool = False,
|
|
192
|
+
**kwargs,
|
|
193
|
+
) -> torch.Tensor:
|
|
194
|
+
with torch.inference_mode():
|
|
195
|
+
generation = self._model.generate(
|
|
196
|
+
**inputs, max_new_tokens=max_new_tokens, do_sample=do_sample
|
|
197
|
+
)
|
|
198
|
+
input_len = inputs["input_ids"].shape[-1]
|
|
199
|
+
return generation[:, input_len:]
|
|
200
|
+
|
|
201
|
+
def post_process_generation(
    self,
    generated_ids: torch.Tensor,
    skip_special_tokens: bool = False,
    **kwargs,
) -> List[str]:
    """Decode generated token ids into text using the model's processor.

    Args:
        generated_ids: Token ids to decode, passed to
            ``self._processor.batch_decode`` unchanged.
        skip_special_tokens: Forwarded to ``batch_decode``.
        **kwargs: Accepted but not used by this method.

    Returns:
        Whatever ``self._processor.batch_decode`` returns for the given ids.
    """
    decoded_texts = self._processor.batch_decode(
        generated_ids,
        skip_special_tokens=skip_special_tokens,
    )
    return decoded_texts
|
|
File without changes
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import Callable, List, Optional, Union
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
import torchvision.transforms as T
|
|
7
|
+
from pydantic import BaseModel, ValidationError
|
|
8
|
+
|
|
9
|
+
import inference_models.models.perception_encoder.vision_encoder.pe as pe
|
|
10
|
+
import inference_models.models.perception_encoder.vision_encoder.transforms as transforms
|
|
11
|
+
from inference_models.configuration import DEFAULT_DEVICE
|
|
12
|
+
from inference_models.entities import ColorFormat
|
|
13
|
+
from inference_models.errors import CorruptedModelPackageError
|
|
14
|
+
from inference_models.models.base.embeddings import TextImageEmbeddingModel
|
|
15
|
+
from inference_models.models.common.model_packages import get_model_package_contents
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class PerceptionEncoderConfig(BaseModel):
    """Schema of the JSON config file read by ``load_config``."""

    # Handed to `pe.CLIP.from_config` as its first argument when the model is built.
    vision_encoder_config: str
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def load_config(config_path: str) -> PerceptionEncoderConfig:
    """Read and validate the perception encoder model package config file.

    Args:
        config_path: Path to the JSON config file.

    Returns:
        The validated ``PerceptionEncoderConfig``.

    Raises:
        CorruptedModelPackageError: If the file cannot be opened, is not valid
            UTF-8, is not valid JSON, or does not satisfy the
            ``PerceptionEncoderConfig`` schema.
    """
    try:
        # Explicit encoding: do not let the platform locale decide how the
        # JSON file is decoded.
        with open(config_path, encoding="utf-8") as f:
            config_data = json.load(f)
    except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e:
        # UnicodeDecodeError is a ValueError but not a JSONDecodeError, so it
        # has to be listed explicitly to be reported as a corrupted package.
        raise CorruptedModelPackageError(
            message=f"Could not load or parse perception encoder model package config file: {config_path}. Details: {e}",
            help_url="https://todo",
        ) from e
    try:
        return PerceptionEncoderConfig.model_validate(config_data)
    except ValidationError as e:
        raise CorruptedModelPackageError(
            message=f"Failed to validate perception encoder model package config file: {config_path}. Details: {e}",
        ) from e
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# based on original implementation using PIL images found in vision_encoder/transforms.py
|
|
42
|
+
# but adjusted to work directly on tensors
|
|
43
|
+
def create_image_resize_transform(
    image_size: int,
    center_crop: bool = False,
    interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
):
    """Compose the resize step that brings images to the model input size.

    Args:
        image_size: Target size; used for both output height and width.
        center_crop: When true, resize with a single-int size and then
            center-crop to ``image_size``; otherwise resize straight to
            ``(image_size, image_size)`` without preserving aspect ratio.
        interpolation: Interpolation mode passed to ``T.Resize``.

    Returns:
        A ``T.Compose`` wrapping the selected transform(s).
    """
    if not center_crop:
        # "Squash": most versatile
        squash = T.Resize(
            (image_size, image_size), interpolation=interpolation, antialias=True
        )
        return T.Compose([squash])

    resize_step = T.Resize(image_size, interpolation=interpolation, antialias=True)
    crop_step = T.CenterCrop(image_size)
    return T.Compose([resize_step, crop_step])
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def create_image_normalize_transform():
    """Build the per-channel normalization transform.

    Uses mean 0.5 and std 0.5 for each of three channels, i.e. ``(x - 0.5) / 0.5``.
    The transform is created with ``inplace=True``.
    """
    channel_mean = [0.5, 0.5, 0.5]
    channel_std = [0.5, 0.5, 0.5]
    return T.Normalize(channel_mean, channel_std, inplace=True)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def create_preprocessor(image_size: int) -> Callable:
    """Build the callable that turns raw images into a normalized input batch.

    Args:
        image_size: Target height and width passed to the resize transform.

    Returns:
        A function ``_preprocess(images, input_color_format=None)`` that
        accepts a single image, a list of images, or a pre-batched tensor and
        returns a 4-D tensor that has been resized and normalized.
    """
    resize_transform = create_image_resize_transform(image_size)
    normalize_transform = create_image_normalize_transform()

    def _preprocess(
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
    ) -> torch.Tensor:
        def _to_tensor(image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
            is_numpy = isinstance(image, np.ndarray)
            if is_numpy:
                # torch.from_numpy cannot wrap negatively strided arrays
                # (e.g. a channel-reversed view such as img[..., ::-1]);
                # materialize those first. Other arrays stay zero-copy.
                if any(stride < 0 for stride in image.strides):
                    image = np.ascontiguousarray(image)
                tensor_image = torch.from_numpy(image).permute(2, 0, 1)
            else:
                tensor_image = image

            # For numpy array inputs, we default to BGR -> RGB conversion for compatibility.
            # For tensor inputs, we only convert if BGR is explicitly specified, otherwise RGB is assumed.
            if input_color_format == "bgr" or (is_numpy and input_color_format is None):
                # BGR -> RGB. Index the channel axis counted from the end so
                # that a pre-batched (B, C, H, W) tensor has its channels
                # swapped rather than its batch entries.
                tensor_image = tensor_image[..., [2, 1, 0], :, :]

            return tensor_image

        received_list = isinstance(images, list)
        if received_list:
            # Resize each image individually, then stack to a batch
            resized_images = [resize_transform(_to_tensor(img)) for img in images]
            tensor_batch = torch.stack(resized_images, dim=0)
        else:
            # Handle single image or pre-batched tensor
            tensor_batch = resize_transform(_to_tensor(images))

        # Ensure there is a batch dimension for single images
        if tensor_batch.ndim == 3:
            tensor_batch = tensor_batch.unsqueeze(0)

        # Perform dtype conversion and normalization on the whole batch for efficiency
        if tensor_batch.dtype == torch.uint8:
            tensor_batch = tensor_batch.to(torch.float32) / 255.0
        elif not received_list:
            # The normalize transform is created with inplace=True. On this
            # path the batch may still share memory with the caller's tensor
            # or array (no copy is guaranteed by the steps above), so work on
            # a private copy instead of overwriting the caller's data.
            # torch.stack already yields fresh memory for list inputs.
            tensor_batch = tensor_batch.clone()

        transformed_batch = normalize_transform(tensor_batch)
        return transformed_batch

    return _preprocess
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class PerceptionEncoderTorch(TextImageEmbeddingModel):
    """Text and image embedding model wrapping a ``pe.CLIP`` network."""

    def __init__(
        self,
        model: pe.CLIP,
        device: torch.device,
    ):
        """Store the model and device and derive the image/text front-ends.

        The image preprocessor is built from ``model.image_size`` and the text
        tokenizer from ``model.context_length``.
        """
        self.model = model
        self.device = device
        self.preprocessor = create_preprocessor(model.image_size)
        self.tokenizer = transforms.get_text_tokenizer(model.context_length)

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, **kwargs
    ) -> "PerceptionEncoderTorch":
        """Load the model from a model package directory.

        Args:
            model_name_or_path: Directory expected to provide ``config.json``
                and ``model.pt``.
            device: Device the model is moved to.
            **kwargs: Accepted but not used.

        Returns:
            An instance holding the loaded model in eval mode.
        """
        # NOTE: the model name used to be derived from the path (its last
        # segment), with the checkpoint read from "<path>/model.pt". That may
        # not match how the registry works, so the name now comes from the
        # config file served as part of the model package.
        package_files = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=["config.json", "model.pt"],
        )
        config_file = package_files["config.json"]
        weights_file = package_files["model.pt"]
        package_config = load_config(config_file)

        clip_model = pe.CLIP.from_config(
            package_config.vision_encoder_config,
            pretrained=True,
            checkpoint_path=weights_file,
        )
        clip_model = clip_model.to(device)
        clip_model.eval()

        return cls(model=clip_model, device=device)

    def embed_images(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Embed one or more images.

        Images go through ``self.preprocessor`` and are moved to
        ``self.device``. Autocast is used on every device type except ``cpu``
        and ``mps``.

        Returns:
            The first element of the model output, converted with ``.float()``.
        """
        batch = self.preprocessor(images, input_color_format=input_color_format)
        batch = batch.to(self.device)

        if self.device.type in ("cpu", "mps"):
            with torch.inference_mode():
                features, _, _ = self.model(batch, None)
                return features.float()

        with torch.inference_mode(), torch.autocast(self.device.type):
            features, _, _ = self.model(batch, None)
            return features.float()

    def embed_text(
        self,
        texts: Union[str, List[str]],
        **kwargs,
    ) -> torch.Tensor:
        """Embed one or more texts.

        A single string is wrapped into a one-element list before tokenizing.

        Returns:
            The second element of the model output, converted with ``.float()``.
        """
        prompts = texts if isinstance(texts, list) else [texts]

        # The original implementation batched here based on CLIP_MAX_BATCH_SIZE;
        # it is unclear how to combine that with a Tensor output, so it is left
        # out for now. See
        # https://github.com/roboflow/inference/blob/main/inference/models/perception_encoder/perception_encoder.py#L227
        token_ids = self.tokenizer(prompts).to(self.device)
        if self.device.type in ("cpu", "mps"):
            with torch.no_grad():
                _, text_features, _ = self.model(None, token_ids)
        else:
            with torch.inference_mode(), torch.autocast(self.device.type):
                _, text_features, _ = self.model(None, token_ids)

        return text_features.float()
|
|
File without changes
|