inference-models 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195) hide show
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
File without changes
@@ -0,0 +1,141 @@
1
+ import os.path
2
+ from pickle import UnpicklingError
3
+ from typing import Dict, List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from inference_models import Detections, ObjectDetectionModel
9
+ from inference_models.configuration import DEFAULT_DEVICE
10
+ from inference_models.entities import ImageDimensions
11
+ from inference_models.errors import CorruptedModelPackageError
12
+ from inference_models.models.auto_loaders.entities import AnyModel
13
+ from inference_models.models.base.types import (
14
+ PreprocessedInputs,
15
+ PreprocessingMetadata,
16
+ RawPrediction,
17
+ )
18
+ from inference_models.models.common.model_packages import get_model_package_contents
19
+ from inference_models.models.owlv2.entities import (
20
+ ImageEmbeddings,
21
+ ReferenceExamplesEmbeddings,
22
+ )
23
+ from inference_models.models.owlv2.owlv2_hf import OWLv2HF
24
+
25
+
26
class RoboflowInstantHF(ObjectDetectionModel):
    """Object detection model driven by an OWLv2 feature extractor and stored reference embeddings.

    The model package provides ``weights.pt`` holding the class names and the
    per-class reference data (``train_data_dict``). All heavy lifting
    (embedding images, scoring against class embeddings, post-processing)
    is delegated to the wrapped ``OWLv2HF`` feature extractor.
    """

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        model_dependencies: Optional[Dict[str, AnyModel]] = None,
        **kwargs,
    ) -> "ObjectDetectionModel":
        """Load the model from a local model package directory.

        Args:
            model_name_or_path: Directory of the model package; must contain
                ``weights.pt``.
            device: Device onto which the weights and reference embeddings are
                mapped. It is also forwarded to the feature extractor when the
                extractor has to be loaded here.
            model_dependencies: Optional pre-loaded dependencies. When it
                contains ``"feature_extractor"`` that instance is used as-is;
                otherwise the extractor is loaded from
                ``<model_name_or_path>/model_dependencies/feature_extractor``.
            **kwargs: Forwarded to ``OWLv2HF.from_pretrained`` when the
                extractor is loaded here.

        Returns:
            Initialised ``RoboflowInstantHF`` instance.

        Raises:
            CorruptedModelPackageError: When the weights cannot be
                deserialized, lack the required keys, or cannot be decoded
                into reference embeddings.
        """
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=["weights.pt"],
        )
        model_dependencies = model_dependencies or {}
        if "feature_extractor" in model_dependencies:
            feature_extractor: OWLv2HF = model_dependencies["feature_extractor"]
        else:
            # Forward the requested device so that the extractor and the
            # reference embeddings (mapped to `device` below) live together.
            # NOTE(review): assumes OWLv2HF.from_pretrained accepts `device`
            # like this class does - confirm against owlv2_hf.py.
            feature_extractor = OWLv2HF.from_pretrained(
                os.path.join(
                    model_name_or_path, "model_dependencies", "feature_extractor"
                ),
                device=device,
                **kwargs,
            )
        try:
            # weights_only=True restricts unpickling to plain tensors/containers.
            weights_dict = torch.load(
                model_package_content["weights.pt"],
                map_location=device,
                weights_only=True,
            )
        except UnpicklingError as error:
            raise CorruptedModelPackageError(
                message="Could not deserialize RF Instant model weights. Contact Roboflow to get help.",
                help_url="https://todo",
            ) from error
        # A payload that is not a dict (e.g. a bare tensor) is as corrupted as
        # one with missing keys - report it the same way instead of failing
        # on the membership test.
        if (
            not isinstance(weights_dict, dict)
            or "class_names" not in weights_dict
            or "train_data_dict" not in weights_dict
        ):
            raise CorruptedModelPackageError(
                message="Corrupted weights of Roboflow Instant model detected. Contact Roboflow to get help.",
                help_url="https://todo",
            )
        class_names = weights_dict["class_names"]
        train_data_dict = weights_dict["train_data_dict"]
        try:
            reference_examples_embeddings = (
                ReferenceExamplesEmbeddings.from_class_embeddings_dict(
                    class_embeddings=train_data_dict,
                    device=device,
                )
            )
        except Exception as error:
            # Any failure while decoding is surfaced as a package corruption,
            # with the original exception chained for debugging.
            raise CorruptedModelPackageError(
                message="Could not decode RF Instant model weights. Contact Roboflow to get help.",
                help_url="https://todo",
            ) from error
        return cls(
            feature_extractor=feature_extractor,
            class_names=class_names,
            reference_examples_embeddings=reference_examples_embeddings,
        )

    def __init__(
        self,
        feature_extractor: OWLv2HF,
        class_names: List[str],
        reference_examples_embeddings: ReferenceExamplesEmbeddings,
    ):
        """Store the feature extractor, class names and reference embeddings.

        Args:
            feature_extractor: OWLv2 model used for embedding and inference.
            class_names: Names of the classes the model predicts.
            reference_examples_embeddings: Embeddings of the reference
                examples; their ``class_embeddings`` are used in ``forward``.
        """
        self._feature_extractor = feature_extractor
        self._class_names = class_names
        self._reference_examples_embeddings = reference_examples_embeddings

    @property
    def class_names(self) -> List[str]:
        """Class names loaded from the model weights."""
        return self._class_names

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        max_detections: int = 300,
        **kwargs,
    ) -> Tuple[List[ImageEmbeddings], List[ImageDimensions]]:
        """Embed the input images with the feature extractor.

        Args:
            images: Single image or list of images (torch tensors or numpy arrays).
            max_detections: Passed to ``embed_images`` of the feature extractor.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Tuple of (image embeddings, image dimensions) as produced by the
            feature extractor.
        """
        images_embeddings, images_dimensions = self._feature_extractor.embed_images(
            images=images,
            max_detections=max_detections,
        )
        return images_embeddings, images_dimensions

    def forward(
        self,
        pre_processed_images: List[ImageEmbeddings],
        confidence_threshold: float = 0.99,
        iou_threshold: float = 0.3,
        **kwargs,
    ) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Score pre-computed image embeddings against the reference class embeddings.

        Args:
            pre_processed_images: Embeddings returned by ``pre_process``.
            confidence_threshold: Forwarded to the feature extractor.
            iou_threshold: Forwarded to the feature extractor.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Raw predictions returned by
            ``forward_pass_with_precomputed_embeddings`` of the feature extractor.
        """
        return self._feature_extractor.forward_pass_with_precomputed_embeddings(
            images_embeddings=pre_processed_images,
            class_embeddings=self._reference_examples_embeddings.class_embeddings,
            confidence_threshold=confidence_threshold,
            iou_threshold=iou_threshold,
        )

    def post_process(
        self,
        model_results: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
        pre_processing_meta: List[ImageDimensions],
        max_detections: int = 300,
        iou_threshold: float = 0.3,
        **kwargs,
    ) -> List[Detections]:
        """Convert raw predictions into ``Detections``.

        Args:
            model_results: Output of ``forward``.
            pre_processing_meta: Image dimensions returned by ``pre_process``.
            max_detections: Forwarded to the feature extractor.
            iou_threshold: Forwarded to the feature extractor.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            List of ``Detections`` produced by the feature extractor.
        """
        return (
            self._feature_extractor.post_process_predictions_for_precomputed_embeddings(
                predictions=model_results,
                images_dimensions=pre_processing_meta,
                max_detections=max_detections,
                iou_threshold=iou_threshold,
            )
        )
File without changes
@@ -0,0 +1,147 @@
1
+ from abc import ABC, abstractmethod
2
+ from collections import OrderedDict
3
+ from threading import Lock
4
+ from typing import Optional
5
+
6
+ import torch
7
+
8
+ from inference_models.errors import EnvironmentConfigurationError
9
+ from inference_models.models.sam.entities import SAMImageEmbeddings
10
+
11
+
12
class SamImageEmbeddingsCache(ABC):
    """Interface of a key-value cache for SAM image embeddings."""

    @abstractmethod
    def retrieve_embeddings(self, key: str) -> Optional[SAMImageEmbeddings]:
        """Return the embeddings stored under ``key``, or ``None`` when absent."""
        ...

    @abstractmethod
    def save_embeddings(self, key: str, embeddings: SAMImageEmbeddings) -> None:
        """Store ``embeddings`` under ``key``."""
        ...
21
+
22
+
23
class SamImageEmbeddingsCacheNullObject(SamImageEmbeddingsCache):
    """Cache implementation that stores nothing (null-object variant)."""

    def retrieve_embeddings(self, key: str) -> Optional[SAMImageEmbeddings]:
        """Always report a cache miss."""
        return None

    def save_embeddings(self, key: str, embeddings: SAMImageEmbeddings) -> None:
        """Discard the embeddings."""
        return None
30
+
31
+
32
class SamImageEmbeddingsInMemoryCache(SamImageEmbeddingsCache):
    """In-memory cache of SAM image embeddings with optional size bound.

    Entries live in an ``OrderedDict``; when the cache is full the entry that
    was inserted first is evicted (reads do not refresh an entry's position).
    Writes are serialised with a lock; reads are not locked.
    """

    @classmethod
    def init(
        cls, size_limit: Optional[int], send_to_cpu: bool = True
    ) -> "SamImageEmbeddingsInMemoryCache":
        """Create an empty cache.

        Args:
            size_limit: Maximum number of entries; ``None`` means unbounded.
            send_to_cpu: When true, embeddings are moved to CPU before storing.
        """
        return cls(
            state=OrderedDict(),
            size_limit=size_limit,
            send_to_cpu=send_to_cpu,
        )

    def __init__(
        self,
        state: OrderedDict,
        size_limit: Optional[int],
        send_to_cpu: bool = True,
    ):
        """Wrap an existing ordered mapping as the cache state.

        Args:
            state: Mapping used as the backing store (oldest entry first).
            size_limit: Maximum number of entries; ``None`` means unbounded.
            send_to_cpu: When true, embeddings are moved to CPU before storing.
        """
        self._state = state
        self._size_limit = size_limit
        self._send_to_cpu = send_to_cpu
        self._state_lock = Lock()

    def retrieve_embeddings(self, key: str) -> Optional[SAMImageEmbeddings]:
        """Return the embeddings stored under ``key``, or ``None`` on a miss."""
        return self._state.get(key)

    def save_embeddings(self, key: str, embeddings: SAMImageEmbeddings) -> None:
        """Store ``embeddings`` under ``key`` unless the key is already cached.

        Raises:
            EnvironmentConfigurationError: When the configured size limit is
                an integer lower than 1.
        """
        with self._state_lock:
            if key in self._state:
                # First write wins - existing entries are never overwritten.
                return None
            self._ensure_cache_has_capacity()
            if self._send_to_cpu:
                embeddings = embeddings.to(device=torch.device("cpu"))
            self._state[key] = embeddings

    def _ensure_cache_has_capacity(self) -> None:
        """Evict the oldest entries so that one more entry fits within the limit.

        Must be called with ``self._state_lock`` held.

        Raises:
            EnvironmentConfigurationError: When the size limit is lower than 1.
        """
        if self._size_limit is None:
            # No limit configured - the cache is unbounded. This check must
            # come before any comparison, since `None < 1` raises TypeError.
            return None
        if self._size_limit < 1:
            raise EnvironmentConfigurationError(
                message="In memory cache size for SAM embeddings was set to invalid value. "
                "If you are running inference locally - adjust settings of your deployment. If you see this "
                "error running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
        # `>=` (not `>`): this runs BEFORE the insertion, so room for the new
        # entry must be made to keep the cache at most `size_limit` entries.
        while len(self._state) >= self._size_limit:
            self._state.popitem(last=False)
79
+
80
+
81
class SamLowResolutionMasksCache(ABC):
    """Interface of a key-value cache for SAM low resolution masks."""

    @abstractmethod
    def retrieve_mask(self, key: str) -> Optional[torch.Tensor]:
        """Return the mask stored under ``key``, or ``None`` when absent."""
        ...

    @abstractmethod
    def save_mask(self, key: str, mask: torch.Tensor) -> None:
        """Store ``mask`` under ``key``."""
        ...
90
+
91
+
92
class SamLowResolutionMasksCacheNullObject(SamLowResolutionMasksCache):
    """Cache implementation that stores nothing (null-object variant)."""

    def retrieve_mask(self, key: str) -> Optional[torch.Tensor]:
        """Always report a cache miss."""
        return None

    def save_mask(self, key: str, mask: torch.Tensor) -> None:
        """Discard the mask."""
        return None
99
+
100
+
101
class SamLowResolutionMasksInMemoryCache(SamLowResolutionMasksCache):
    """In-memory cache of SAM low resolution masks with optional size bound.

    Entries live in an ``OrderedDict``; when the cache is full the entry that
    was inserted first is evicted (reads do not refresh an entry's position).
    Writes are serialised with a lock; reads are not locked.
    """

    @classmethod
    def init(
        cls, size_limit: Optional[int], send_to_cpu: bool = True
    ) -> "SamLowResolutionMasksInMemoryCache":
        """Create an empty cache.

        Args:
            size_limit: Maximum number of entries; ``None`` means unbounded.
            send_to_cpu: When true, masks are moved to CPU before storing.
        """
        return cls(
            state=OrderedDict(),
            size_limit=size_limit,
            send_to_cpu=send_to_cpu,
        )

    def __init__(
        self,
        state: OrderedDict,
        size_limit: Optional[int],
        send_to_cpu: bool = True,
    ):
        """Wrap an existing ordered mapping as the cache state.

        Args:
            state: Mapping used as the backing store (oldest entry first).
            size_limit: Maximum number of entries; ``None`` means unbounded.
            send_to_cpu: When true, masks are moved to CPU before storing.
        """
        self._state = state
        self._size_limit = size_limit
        self._send_to_cpu = send_to_cpu
        self._state_lock = Lock()

    def retrieve_mask(self, key: str) -> Optional[torch.Tensor]:
        """Return the mask stored under ``key``, or ``None`` on a miss."""
        return self._state.get(key)

    def save_mask(self, key: str, mask: torch.Tensor) -> None:
        """Store ``mask`` under ``key`` unless the key is already cached.

        Raises:
            EnvironmentConfigurationError: When the configured size limit is
                an integer lower than 1.
        """
        with self._state_lock:
            if key in self._state:
                # First write wins - existing entries are never overwritten.
                return None
            self._ensure_cache_has_capacity()
            if self._send_to_cpu:
                mask = mask.to(device=torch.device("cpu"))
            self._state[key] = mask

    def _ensure_cache_has_capacity(self) -> None:
        """Evict the oldest entries so that one more entry fits within the limit.

        Must be called with ``self._state_lock`` held.

        Raises:
            EnvironmentConfigurationError: When the size limit is lower than 1.
        """
        if self._size_limit is None:
            # No limit configured - the cache is unbounded. This check must
            # come before any comparison, since `None < 1` raises TypeError.
            return None
        if self._size_limit < 1:
            # Message names the masks cache (it previously said "SAM
            # embeddings", copied from the embeddings cache).
            raise EnvironmentConfigurationError(
                message="In memory cache size for SAM low resolution masks was set to invalid value. "
                "If you are running inference locally - adjust settings of your deployment. If you see this "
                "error running on Roboflow platform - contact us to get help.",
                help_url="https://todo",
            )
        # `>=` (not `>`): this runs BEFORE the insertion, so room for the new
        # entry must be made to keep the cache at most `size_limit` entries.
        while len(self._state) >= self._size_limit:
            self._state.popitem(last=False)
@@ -0,0 +1,25 @@
1
+ from dataclasses import dataclass
2
+ from typing import Tuple
3
+
4
+ import torch
5
+
6
+
7
@dataclass(frozen=True)
class SAMImageEmbeddings:
    """Immutable bundle of SAM embeddings for a single image."""

    # Hash identifying the source image.
    image_hash: str
    # Image size as a 2-tuple; the field name suggests (height, width) order.
    image_size_hw: Tuple[int, int]
    # Embeddings tensor computed for the image.
    embeddings: torch.Tensor

    def to(self, device: torch.device) -> "SAMImageEmbeddings":
        """Return a new instance whose embeddings tensor is placed on ``device``.

        The image hash and size are carried over unchanged; the original
        instance is left untouched.
        """
        relocated = self.embeddings.to(device=device)
        return SAMImageEmbeddings(
            image_hash=self.image_hash,
            image_size_hw=self.image_size_hw,
            embeddings=relocated,
        )
19
+
20
+
21
@dataclass(frozen=True)
class SAMPrediction:
    """Immutable container for the tensors making up one SAM prediction."""

    # Predicted masks.
    masks: torch.Tensor
    # Scores accompanying the masks.
    scores: torch.Tensor
    # Logits accompanying the masks.
    logits: torch.Tensor