inference-models 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1329 @@
+ from collections import Counter
+ from dataclasses import dataclass
+ from functools import cache
+ from typing import List, Optional, Set, Tuple, Union
+
+ import torch
+ from packaging.version import Version
+
+ from inference_models.errors import (
+     AmbiguousModelPackageResolutionError,
+     AssumptionError,
+     InvalidRequestedBatchSizeError,
+     ModelPackageNegotiationError,
+     NoModelPackagesAvailableError,
+     UnknownBackendTypeError,
+     UnknownQuantizationError,
+ )
+ from inference_models.logger import verbose_info
+ from inference_models.models.auto_loaders.constants import (
+     NMS_CLASS_AGNOSTIC_KEY,
+     NMS_CONFIDENCE_THRESHOLD_KEY,
+     NMS_FUSED_FEATURE,
+     NMS_IOU_THRESHOLD_KEY,
+     NMS_MAX_DETECTIONS_KEY,
+ )
+ from inference_models.models.auto_loaders.entities import (
+     BackendType,
+     ModelArchitecture,
+     TaskType,
+ )
+ from inference_models.models.auto_loaders.models_registry import (
+     model_implementation_exists,
+ )
+ from inference_models.models.auto_loaders.ranking import rank_model_packages
+ from inference_models.models.auto_loaders.utils import (
+     filter_available_devices_with_selected_device,
+ )
+ from inference_models.runtime_introspection.core import (
+     RuntimeXRayResult,
+     x_ray_runtime_environment,
+ )
+ from inference_models.utils.onnx_introspection import (
+     get_selected_onnx_execution_providers,
+ )
+ from inference_models.weights_providers.entities import (
+     JetsonEnvironmentRequirements,
+     ModelPackageMetadata,
+     Quantization,
+     ServerEnvironmentRequirements,
+ )
+
+
+ @dataclass(frozen=True)
+ class DiscardedPackage:
+     package_id: str
+     reason: str
+
+
+ def negotiate_model_packages(
+     model_architecture: ModelArchitecture,
+     task_type: Optional[TaskType],
+     model_packages: List[ModelPackageMetadata],
+     requested_model_package_id: Optional[str] = None,
+     requested_backends: Optional[
+         Union[str, BackendType, List[Union[str, BackendType]]]
+     ] = None,
+     requested_batch_size: Optional[Union[int, Tuple[int, int]]] = None,
+     requested_quantization: Optional[
+         Union[str, Quantization, List[Union[str, Quantization]]]
+     ] = None,
+     device: Optional[torch.device] = None,
+     onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+     allow_untrusted_packages: bool = False,
+     trt_engine_host_code_allowed: bool = True,
+     nms_fusion_preferences: Optional[Union[bool, dict]] = None,
+     verbose: bool = False,
+ ) -> List[ModelPackageMetadata]:
+     verbose_info(
+         "The following model packages were exposed by weights provider:",
+         verbose_requested=verbose,
+     )
+     print_model_packages(model_packages=model_packages, verbose=verbose)
+     if not model_packages:
+         raise NoModelPackagesAvailableError(
+             message=f"Could not find any model package announced by the weights provider. If you see this error "
+             f"using the Roboflow platform, your model may not be ready - if the problem persists, "
+             f"contact us to get help. If you use a weights provider other than Roboflow - this is most likely "
+             f"the root cause of the error.",
+             help_url="https://todo",
+         )
+     if requested_model_package_id is not None:
+         return [
+             select_model_package_by_id(
+                 model_packages=model_packages,
+                 requested_model_package_id=requested_model_package_id,
+                 verbose=verbose,
+             )
+         ]
+     model_packages, discarded_packages = remove_packages_not_matching_implementation(
+         model_architecture=model_architecture,
+         task_type=task_type,
+         model_packages=model_packages,
+     )
+     if not allow_untrusted_packages:
+         model_packages, discarded_untrusted_packages = remove_untrusted_packages(
+             model_packages=model_packages,
+             verbose=verbose,
+         )
+         discarded_packages.extend(discarded_untrusted_packages)
+     if requested_backends is not None:
+         model_packages, discarded_by_not_matching_backend = (
+             filter_model_packages_by_requested_backend(
+                 model_packages=model_packages,
+                 requested_backends=requested_backends,
+                 verbose=verbose,
+             )
+         )
+         discarded_packages.extend(discarded_by_not_matching_backend)
+     if requested_batch_size is not None:
+         model_packages, discarded_by_batch_size = (
+             filter_model_packages_by_requested_batch_size(
+                 model_packages=model_packages,
+                 requested_batch_size=requested_batch_size,
+                 verbose=verbose,
+             )
+         )
+         discarded_packages.extend(discarded_by_batch_size)
+     default_quantization = False
+     if requested_quantization is None:
+         default_quantization = True
+         requested_quantization = determine_default_allowed_quantization(device=device)
+     if requested_quantization:
+         model_packages, discarded_by_quantization = (
+             filter_model_packages_by_requested_quantization(
+                 model_packages=model_packages,
+                 requested_quantization=requested_quantization,
+                 default_quantization_used=default_quantization,
+                 verbose=verbose,
+             )
+         )
+         discarded_packages.extend(discarded_by_quantization)
+     model_packages, discarded_by_env_matching = (
+         filter_model_packages_matching_runtime_environment(
+             model_packages=model_packages,
+             device=device,
+             onnx_execution_providers=onnx_execution_providers,
+             trt_engine_host_code_allowed=trt_engine_host_code_allowed,
+             verbose=verbose,
+         )
+     )
+     discarded_packages.extend(discarded_by_env_matching)
+     model_packages, discarded_by_model_features = (
+         filter_model_packages_based_on_model_features(
+             model_packages=model_packages,
+             nms_fusion_preferences=nms_fusion_preferences,
+             model_architecture=model_architecture,
+             task_type=task_type,
+         )
+     )
+     if not model_packages:
+         rejections_summary = summarise_discarded_packages(
+             discarded_packages=discarded_packages
+         )
+         raise NoModelPackagesAvailableError(
+             message=f"Auto-negotiation protocol could not select model packages. This situation may be caused by "
+             f"several issues, with the most common being missing dependencies or too strict requirements "
+             f"stated as parameters of the loading function. Below you can find reasons why specific model "
+             f"packages were rejected:\n{rejections_summary}\n",
+             help_url="https://todo",
+         )
+     model_packages = rank_model_packages(
+         model_packages=model_packages,
+         selected_device=device,
+         nms_fusion_preferences=nms_fusion_preferences,
+     )
+     verbose_info("Eligible packages ranked:", verbose_requested=verbose)
+     print_model_packages(model_packages=model_packages, verbose=verbose)
+     return model_packages
+
+
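(Not part of the packaged file - a minimal illustrative sketch of how the entry point above might be called once package metadata has been fetched from a weights provider. The module path follows the file list above; the best-first ordering of the returned list is an assumption.)

    import torch

    from inference_models.models.auto_loaders.auto_negotiation import (
        negotiate_model_packages,
    )
    from inference_models.weights_providers.entities import Quantization

    def pick_best_package(model_architecture, task_type, model_packages):
        # `model_packages` is the List[ModelPackageMetadata] announced by the weights provider.
        ranked = negotiate_model_packages(
            model_architecture=model_architecture,
            task_type=task_type,
            model_packages=model_packages,
            requested_batch_size=(1, 8),
            requested_quantization=[Quantization.FP32, Quantization.FP16],
            device=torch.device("cuda:0") if torch.cuda.is_available() else None,
            verbose=True,
        )
        return ranked[0]  # assumed: rank_model_packages() orders eligible packages best-first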
+ def summarise_discarded_packages(discarded_packages: List[DiscardedPackage]) -> str:
+     reasons_and_counts = Counter()
+     for package in discarded_packages:
+         reasons_and_counts[package.reason] += 1
+     reasons_stats = reasons_and_counts.most_common()
+     result = []
+     for reason, count in reasons_stats:
+         package_str = "package" if count < 2 else "packages"
+         result.append(f"\t* {count} {package_str} with the following note: {reason}")
+     return "\n".join(result)
+
+
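(Illustrative sketch of the rejection summary built above - both names are defined in this module.)

    from inference_models.models.auto_loaders.auto_negotiation import (
        DiscardedPackage,
        summarise_discarded_packages,
    )

    discarded = [
        DiscardedPackage(package_id="pkg-1", reason="Ultralytics backend not installed"),
        DiscardedPackage(package_id="pkg-2", reason="Ultralytics backend not installed"),
        DiscardedPackage(package_id="pkg-3", reason="Mediapipe backend not installed"),
    ]
    print(summarise_discarded_packages(discarded_packages=discarded))
    # 	* 2 packages with the following note: Ultralytics backend not installed
    # 	* 1 package with the following note: Mediapipe backend not installed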
+ @cache
+ def determine_default_allowed_quantization(
+     device: Optional[torch.device] = None,
+ ) -> List[Quantization]:
+     if device is not None:
+         if device.type == "cpu":
+             return [
+                 Quantization.UNKNOWN,
+                 Quantization.FP32,
+                 Quantization.BF16,
+             ]
+         return [
+             Quantization.UNKNOWN,
+             Quantization.FP32,
+             Quantization.FP16,
+         ]
+     runtime_x_ray = x_ray_runtime_environment()
+     if runtime_x_ray.gpu_devices:
+         return [
+             Quantization.UNKNOWN,
+             Quantization.FP32,
+             Quantization.FP16,
+         ]
+     return [
+         Quantization.UNKNOWN,
+         Quantization.FP32,
+         Quantization.FP16,
+         Quantization.BF16,
+     ]
+
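(Illustrative sketch of the default quantization allow-lists produced by the helper above.)

    import torch

    from inference_models.models.auto_loaders.auto_negotiation import (
        determine_default_allowed_quantization,
    )
    from inference_models.weights_providers.entities import Quantization

    cpu_defaults = determine_default_allowed_quantization(device=torch.device("cpu"))
    assert cpu_defaults == [Quantization.UNKNOWN, Quantization.FP32, Quantization.BF16]
    gpu_defaults = determine_default_allowed_quantization(device=torch.device("cuda:0"))
    assert gpu_defaults == [Quantization.UNKNOWN, Quantization.FP32, Quantization.FP16]
    # With device=None the helper falls back to x_ray_runtime_environment() to decide.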
+
+ def print_model_packages(
+     model_packages: List[ModelPackageMetadata], verbose: bool
+ ) -> None:
+     if not model_packages:
+         verbose_info(message="No model packages.", verbose_requested=verbose)
+         return None
+     contents = []
+     for i, model_package in enumerate(model_packages):
+         contents.append(f"{i+1}. {model_package.get_summary()}")
+     verbose_info(message="\n".join(contents), verbose_requested=verbose)
+
+
+ def remove_packages_not_matching_implementation(
+     model_architecture: ModelArchitecture,
+     task_type: Optional[TaskType],
+     model_packages: List[ModelPackageMetadata],
+     verbose: bool = False,
+ ) -> Tuple[List[ModelPackageMetadata], List[DiscardedPackage]]:
+     result, discarded = [], []
+     for model_package in model_packages:
+         if not model_implementation_exists(
+             model_architecture=model_architecture,
+             task_type=task_type,
+             backend=model_package.backend,
+             model_features=model_package.model_features,
+         ):
+             verbose_info(
+                 message=f"Model package with id `{model_package.package_id}` is filtered out as `inference-models` "
+                 f"does not provide implementation for the model architecture {model_architecture} with "
+                 f"task type: {task_type}, backend {model_package.backend} and requested model "
+                 f"features {model_package.model_features}.",
+                 verbose_requested=verbose,
+             )
+             if model_package.model_features:
+                 model_features_infix = (
+                     f" (and requested model features {model_package.model_features})"
+                 )
+             else:
+                 model_features_infix = ""
+             discarded.append(
+                 DiscardedPackage(
+                     package_id=model_package.package_id,
+                     reason=f"`inference-models` does not provide implementation for the model {model_architecture} "
+                     f"({task_type}) with backend {model_package.backend.value}{model_features_infix}. "
+                     "If a newer version of the package is available, consider upgrading - we may already have this "
+                     "package supported.",
+                 )
+             )
+             continue
+         result.append(model_package)
+     return result, discarded
+
+
+ def remove_untrusted_packages(
+     model_packages: List[ModelPackageMetadata],
+     verbose: bool = False,
+ ) -> Tuple[List[ModelPackageMetadata], List[DiscardedPackage]]:
+     result, discarded_packages = [], []
+     for model_package in model_packages:
+         if not model_package.trusted_source:
+             verbose_info(
+                 message=f"Model package with id `{model_package.package_id}` is filtered out as it comes from "
+                 f"an untrusted source.",
+                 verbose_requested=verbose,
+             )
+             discarded_packages.append(
+                 DiscardedPackage(
+                     package_id=model_package.package_id,
+                     reason="Package is marked as `untrusted` and auto-loader was used with "
+                     "`allow_untrusted_packages=False`",
+                 )
+             )
+             continue
+         result.append(model_package)
+     return result, discarded_packages
+
+
+ def select_model_package_by_id(
+     model_packages: List[ModelPackageMetadata],
+     requested_model_package_id: str,
+     verbose: bool = False,
+ ) -> ModelPackageMetadata:
+     matching_packages = [
+         p for p in model_packages if p.package_id == requested_model_package_id
+     ]
+     if not matching_packages:
+         raise NoModelPackagesAvailableError(
+             message=f"Requested model package ID: {requested_model_package_id} cannot be resolved among "
+             f"the model packages announced by the weights provider. This may indicate either "
+             f"a typo in the identifier or a change in the set of model packages announced by the provider.",
+             help_url="https://todo",
+         )
+     if len(matching_packages) > 1:
+         raise AmbiguousModelPackageResolutionError(
+             message=f"Requested model package ID: {requested_model_package_id} resolved to {len(matching_packages)} "
+             f"different packages announced by weights provider. That is most likely weights provider "
+             f"error, as it is supposed to provide unique identifiers for each model package.",
+             help_url="https://todo",
+         )
+     verbose_info(
+         message=f"Model package matching requested package id: {matching_packages[0].get_summary()}",
+         verbose_requested=verbose,
+     )
+     return matching_packages[0]
+
+
+ def filter_model_packages_by_requested_backend(
+     model_packages: List[ModelPackageMetadata],
+     requested_backends: Union[str, BackendType, List[Union[str, BackendType]]],
+     verbose: bool = False,
+ ) -> Tuple[List[ModelPackageMetadata], List[DiscardedPackage]]:
+     if not isinstance(requested_backends, list):
+         requested_backends = [requested_backends]
+     requested_backends_set = set()
+     for requested_backend in requested_backends:
+         if isinstance(requested_backend, str):
+             requested_backend = parse_backend_type(value=requested_backend)
+         requested_backends_set.add(requested_backend)
+     verbose_info(
+         message=f"Filtering model packages by requested backends: {requested_backends_set}",
+         verbose_requested=verbose,
+     )
+     requested_backends_serialised = [b.value for b in requested_backends_set]
+     filtered_packages, discarded_packages = [], []
+     for model_package in model_packages:
+         if model_package.backend not in requested_backends_set:
+             verbose_info(
+                 message=f"Model package with id `{model_package.package_id}` does not match requested backends.",
+                 verbose_requested=verbose,
+             )
+             discarded_packages.append(
+                 DiscardedPackage(
+                     package_id=model_package.package_id,
+                     reason=f"Package backend {model_package.backend.value} does not match requested backends: "
+                     f"{requested_backends_serialised}",
+                 )
+             )
+             continue
+         filtered_packages.append(model_package)
+     return filtered_packages, discarded_packages
+
+
+ def filter_model_packages_by_requested_batch_size(
+     model_packages: List[ModelPackageMetadata],
+     requested_batch_size: Union[int, Tuple[int, int]],
+     verbose: bool = False,
+ ) -> Tuple[List[ModelPackageMetadata], List[DiscardedPackage]]:
+     min_batch_size, max_batch_size = parse_batch_size(
+         requested_batch_size=requested_batch_size
+     )
+     verbose_info(
+         message=f"Filtering model packages by supported batch sizes min={min_batch_size} max={max_batch_size}",
+         verbose_requested=verbose,
+     )
+     filtered_packages, discarded_packages = [], []
+     for model_package in model_packages:
+         if not model_package_matches_batch_size_request(
+             model_package=model_package,
+             min_batch_size=min_batch_size,
+             max_batch_size=max_batch_size,
+             verbose=verbose,
+         ):
+             verbose_info(
+                 message=f"Model package with id `{model_package.package_id}` does not match requested batch "
+                 f"size <{min_batch_size}, {max_batch_size}>.",
+                 verbose_requested=verbose,
+             )
+             discarded_packages.append(
+                 DiscardedPackage(
+                     package_id=model_package.package_id,
+                     reason=f"Package batch size does not match requested batch size <{min_batch_size}, {max_batch_size}>",
+                 )
+             )
+             continue
+         filtered_packages.append(model_package)
+     return filtered_packages, discarded_packages
+
+
+ def filter_model_packages_by_requested_quantization(
+     model_packages: List[ModelPackageMetadata],
+     requested_quantization: Union[str, Quantization, List[Union[str, Quantization]]],
+     default_quantization_used: bool,
+     verbose: bool = False,
+ ) -> Tuple[List[ModelPackageMetadata], List[DiscardedPackage]]:
+     requested_quantization = parse_requested_quantization(value=requested_quantization)
+     requested_quantization_str = [e.value for e in requested_quantization]
+     verbose_info(
+         message=f"Filtering model packages by quantization - allowed values: {requested_quantization_str}",
+         verbose_requested=verbose,
+     )
+     default_quantization_used_str = (
+         " (which was selected by default)."
+         if default_quantization_used
+         else " (which was selected by caller)."
+     )
+     filtered_packages, discarded_packages = [], []
+     for model_package in model_packages:
+         if model_package.quantization not in requested_quantization:
+             verbose_info(
+                 message=f"Model package with id `{model_package.package_id}` does not match requested quantization "
+                 f"{requested_quantization_str}{default_quantization_used_str}",
+                 verbose_requested=verbose,
+             )
+             discarded_packages.append(
+                 DiscardedPackage(
+                     package_id=model_package.package_id,
+                     reason=f"Package does not match requested quantization {requested_quantization_str}"
+                     f"{default_quantization_used_str}",
+                 )
+             )
+             continue
+         filtered_packages.append(model_package)
+     return filtered_packages, discarded_packages
+
+
+ def model_package_matches_batch_size_request(
+     model_package: ModelPackageMetadata,
+     min_batch_size: int,
+     max_batch_size: int,
+     verbose: bool = False,
+ ) -> bool:
+     if model_package.dynamic_batch_size_supported:
+         if model_package.specifies_dynamic_batch_boundaries():
+             declared_min_batch_size, declared_max_batch_size = (
+                 model_package.get_dynamic_batch_boundaries()
+             )
+             ranges_match = range_within_other(
+                 external_range=(declared_min_batch_size, declared_max_batch_size),
+                 internal_range=(min_batch_size, max_batch_size),
+             )
+             if not ranges_match:
+                 verbose_info(
+                     message=f"Model package with id `{model_package.package_id}` declared to support dynamic batch sizes: "
+                     f"[{declared_min_batch_size}, {declared_max_batch_size}] and requested batch size was: "
+                     f"[{min_batch_size}, {max_batch_size}] - package does not match criteria.",
+                     verbose_requested=verbose,
+                 )
+             return ranges_match
+         return True
+     if min_batch_size <= model_package.static_batch_size <= max_batch_size:
+         return True
+     verbose_info(
+         message=f"Model package with id `{model_package.package_id}` filtered out, as static batch size does not "
+         f"match requested values: ({min_batch_size}, {max_batch_size}). "
+         f"If you see this error on Roboflow platform - contact us to get help. "
+         f"Otherwise, consider adjusting requested batch size.",
+         verbose_requested=verbose,
+     )
+     return False
+
+
+ def filter_model_packages_matching_runtime_environment(
+     model_packages: List[ModelPackageMetadata],
+     device: torch.device,
+     onnx_execution_providers: Optional[List[Union[str, tuple]]],
+     trt_engine_host_code_allowed: bool,
+     verbose: bool = False,
+ ) -> Tuple[List[ModelPackageMetadata], List[DiscardedPackage]]:
+     runtime_x_ray = x_ray_runtime_environment()
+     verbose_info(
+         message=f"Selecting model packages matching to runtime: {runtime_x_ray}",
+         verbose_requested=verbose,
+     )
+     results, discarded_packages = [], []
+     for model_package in model_packages:
+         matches, reason = model_package_matches_runtime_environment(
+             model_package=model_package,
+             runtime_x_ray=runtime_x_ray,
+             device=device,
+             onnx_execution_providers=onnx_execution_providers,
+             trt_engine_host_code_allowed=trt_engine_host_code_allowed,
+             verbose=verbose,
+         )
+         if not matches:
+             discarded_packages.append(
+                 DiscardedPackage(
+                     package_id=model_package.package_id,
+                     reason=reason,
+                 )
+             )
+             continue
+         results.append(model_package)
+     return results, discarded_packages
+
+
+ def filter_model_packages_based_on_model_features(
+     model_packages: List[ModelPackageMetadata],
+     nms_fusion_preferences: Optional[Union[bool, dict]],
+     model_architecture: ModelArchitecture,
+     task_type: Optional[TaskType],
+ ) -> Tuple[List[ModelPackageMetadata], List[DiscardedPackage]]:
+     results, discarded_packages = [], []
+     for model_package in model_packages:
+         if not model_package.model_features:
+             results.append(model_package)
+             continue
+         eliminated_by_nms_fusion_preferences, reason = (
+             should_model_package_be_filtered_out_based_on_nms_fusion_preferences(
+                 model_package=model_package,
+                 nms_fusion_preferences=nms_fusion_preferences,
+                 model_architecture=model_architecture,
+                 task_type=task_type,
+             )
+         )
+         if eliminated_by_nms_fusion_preferences:
+             if reason is None:
+                 raise AssumptionError(
+                     message="Detected bug in `inference` - "
+                     "`should_model_package_be_filtered_out_based_on_nms_fusion_preferences()` returned malformed "
+                     "result. Please raise the issue: https://github.com/roboflow/inference/issues",
+                     help_url="https://todo",
+                 )
+             discarded_packages.append(
+                 DiscardedPackage(package_id=model_package.package_id, reason=reason)
+             )
+             continue
+         results.append(model_package)
+     return results, discarded_packages
+
+
+ def should_model_package_be_filtered_out_based_on_nms_fusion_preferences(
+     model_package: ModelPackageMetadata,
+     nms_fusion_preferences: Optional[Union[bool, dict]],
+     model_architecture: ModelArchitecture,
+     task_type: Optional[TaskType],
+ ) -> Tuple[bool, Optional[str]]:
+     nms_fused_config = model_package.model_features.get(NMS_FUSED_FEATURE)
+     if nms_fused_config is None:
+         return False, None
+     if nms_fusion_preferences is None or nms_fusion_preferences is False:
+         return (
+             True,
+             "Package specifies NMS fusion, but auto-loading was used with `nms_fusion_preferences`=None, rejecting such packages.",
+         )
+     if nms_fusion_preferences is True:
+         nms_fusion_preferences = get_default_nms_settings(
+             model_architecture=model_architecture,
+             task_type=task_type,
+         )
+     try:
+         actual_max_detections = nms_fused_config[NMS_MAX_DETECTIONS_KEY]
+         actual_confidence_threshold = nms_fused_config[NMS_CONFIDENCE_THRESHOLD_KEY]
+         actual_iou_threshold = nms_fused_config[NMS_IOU_THRESHOLD_KEY]
+         actual_class_agnostic = nms_fused_config[NMS_CLASS_AGNOSTIC_KEY]
+     except KeyError as error:
+         return (
+             True,
+             f"Package specifies malformed `{NMS_FUSED_FEATURE}` property in model features - missing key: {error}",
+         )
+     if NMS_MAX_DETECTIONS_KEY in nms_fusion_preferences:
+         requested_max_detections = nms_fusion_preferences[NMS_MAX_DETECTIONS_KEY]
+         if isinstance(requested_max_detections, (list, tuple)):
+             min_detections, max_detections = requested_max_detections
+         else:
+             min_detections, max_detections = (
+                 requested_max_detections,
+                 requested_max_detections,
+             )
+         min_detections, max_detections = min(min_detections, max_detections), max(
+             min_detections, max_detections
+         )
+         if not min_detections <= actual_max_detections <= max_detections:
+             return (
+                 True,
+                 f"Package specifies NMS fusion with `{NMS_MAX_DETECTIONS_KEY}` not matching the preference passed to "
+                 f"auto-loading.",
+             )
+     if NMS_CONFIDENCE_THRESHOLD_KEY in nms_fusion_preferences:
+         requested_confidence = nms_fusion_preferences[NMS_CONFIDENCE_THRESHOLD_KEY]
+         if isinstance(requested_confidence, (list, tuple)):
+             min_confidence, max_confidence = requested_confidence
+         else:
+             min_confidence, max_confidence = (
+                 requested_confidence,
+                 requested_confidence,
+             )
+         min_confidence, max_confidence = min(min_confidence, max_confidence), max(
+             min_confidence, max_confidence
+         )
+         if not min_confidence <= actual_confidence_threshold <= max_confidence:
+             return (
+                 True,
+                 f"Package specifies NMS fusion with `{NMS_CONFIDENCE_THRESHOLD_KEY}` not matching the preference passed to "
+                 f"auto-loading.",
+             )
+
+     if NMS_IOU_THRESHOLD_KEY in nms_fusion_preferences:
+         requested_iou_threshold = nms_fusion_preferences[NMS_IOU_THRESHOLD_KEY]
+         if isinstance(requested_iou_threshold, (list, tuple)):
+             min_iou_threshold, max_iou_threshold = requested_iou_threshold
+         else:
+             min_iou_threshold, max_iou_threshold = (
+                 requested_iou_threshold,
+                 requested_iou_threshold,
+             )
+         min_iou_threshold, max_iou_threshold = min(
+             min_iou_threshold, max_iou_threshold
+         ), max(min_iou_threshold, max_iou_threshold)
+         if not min_iou_threshold <= actual_iou_threshold <= max_iou_threshold:
+             return (
+                 True,
+                 f"Package specifies NMS fusion with `{NMS_IOU_THRESHOLD_KEY}` not matching the preference passed to "
+                 f"auto-loading.",
+             )
+     if NMS_CLASS_AGNOSTIC_KEY in nms_fusion_preferences:
+         if actual_class_agnostic != nms_fusion_preferences[NMS_CLASS_AGNOSTIC_KEY]:
+             return (
+                 True,
+                 f"Package specifies NMS fusion with `{NMS_CLASS_AGNOSTIC_KEY}` not matching the preference passed to "
+                 f"auto-loading.",
+             )
+     return False, None
+
+
+ def get_default_nms_settings(
+     model_architecture: ModelArchitecture,
+     task_type: Optional[TaskType],
+ ) -> dict:
+     # TODO - over time it may change - but please keep it specific, without ranges (for the sake of simplicity
+     #  of ranking)
+     return {
+         NMS_MAX_DETECTIONS_KEY: 300,
+         NMS_CONFIDENCE_THRESHOLD_KEY: 0.25,
+         NMS_IOU_THRESHOLD_KEY: 0.7,
+         NMS_CLASS_AGNOSTIC_KEY: False,
+     }
+
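(Illustrative sketch of the `nms_fusion_preferences` shape accepted by the NMS filter above - each key takes either a scalar or a (min, max) range; passing `True` falls back to the defaults returned by `get_default_nms_settings()`.)

    from inference_models.models.auto_loaders.constants import (
        NMS_CLASS_AGNOSTIC_KEY,
        NMS_CONFIDENCE_THRESHOLD_KEY,
        NMS_IOU_THRESHOLD_KEY,
        NMS_MAX_DETECTIONS_KEY,
    )

    nms_fusion_preferences = {
        NMS_MAX_DETECTIONS_KEY: (100, 300),       # accept packages fused with 100-300 max detections
        NMS_CONFIDENCE_THRESHOLD_KEY: 0.25,       # exact confidence threshold required
        NMS_IOU_THRESHOLD_KEY: (0.45, 0.7),       # accepted IoU threshold range
        NMS_CLASS_AGNOSTIC_KEY: False,            # require class-aware NMS
    }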
+
+ def model_package_matches_runtime_environment(
+     model_package: ModelPackageMetadata,
+     runtime_x_ray: RuntimeXRayResult,
+     device: Optional[torch.device] = None,
+     onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+     trt_engine_host_code_allowed: bool = True,
+     verbose: bool = False,
+ ) -> Tuple[bool, Optional[str]]:
+     if model_package.backend not in MODEL_TO_RUNTIME_COMPATIBILITY_MATCHERS:
+         raise ModelPackageNegotiationError(
+             message=f"Model package negotiation protocol not implemented for model backend {model_package.backend}. "
+             f"This is an `inference-models` bug - raise an issue: https://github.com/roboflow/inference/issues",
+             help_url="https://todo",
+         )
+     return MODEL_TO_RUNTIME_COMPATIBILITY_MATCHERS[model_package.backend](
+         model_package,
+         runtime_x_ray,
+         device,
+         onnx_execution_providers,
+         trt_engine_host_code_allowed,
+         verbose,
+     )
+
+
+ ONNX_RUNTIME_OPSET_COMPATIBILITY = {
+     Version("1.15"): 19,
+     Version("1.16"): 19,
+     Version("1.17"): 20,
+     Version("1.18"): 21,
+     Version("1.19"): 21,
+     Version("1.20"): 21,
+     Version("1.21"): 22,
+     Version("1.22"): 23,
+ }
+
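(Illustrative sketch of how the table above bounds the opset a given onnxruntime build can execute, mirroring the check performed in `onnx_package_matches_runtime_environment` below.)

    from packaging.version import Version

    from inference_models.models.auto_loaders.auto_negotiation import (
        ONNX_RUNTIME_OPSET_COMPATIBILITY,
    )

    onnxruntime_version = Version("1.19.2")
    simple_version = Version(f"{onnxruntime_version.major}.{onnxruntime_version.minor}")
    # unknown onnxruntime versions fall back to the opset limit of 1.15, i.e. 19
    max_supported_opset = ONNX_RUNTIME_OPSET_COMPATIBILITY.get(simple_version, 19)
    package_opset = 21
    compatible = package_opset <= max_supported_opset  # True for onnxruntime 1.19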
+
+ def onnx_package_matches_runtime_environment(
+     model_package: ModelPackageMetadata,
+     runtime_x_ray: RuntimeXRayResult,
+     device: Optional[torch.device] = None,
+     onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+     trt_engine_host_code_allowed: bool = True,
+     verbose: bool = False,
+ ) -> Tuple[bool, Optional[str]]:
+     if (
+         not runtime_x_ray.onnxruntime_version
+         or not runtime_x_ray.available_onnx_execution_providers
+     ):
+         verbose_info(
+             message=f"Model package with id '{model_package.package_id}' filtered out as onnxruntime not detected",
+             verbose_requested=verbose,
+         )
+         return False, (
+             "ONNX backend not installed - consider installing relevant ONNX extras: "
+             "`onnx-cpu`, `onnx-cu118`, `onnx-cu12`, `onnx-jp6-cu126` depending on the hardware you run `inference-models` on"
+         )
+     if model_package.onnx_package_details is None:
+         verbose_info(
+             message=f"Model package with id '{model_package.package_id}' filtered out as onnxruntime specification "
+             f"not provided by weights provider.",
+             verbose_requested=verbose,
+         )
+         return (
+             False,
+             "Model package metadata delivered by weights provider lacks required ONNX package details",
+         )
+     providers_auto_selected = False
+     if not onnx_execution_providers:
+         providers_auto_selected = True
+         onnx_execution_providers = get_selected_onnx_execution_providers()
+     onnx_execution_providers = [
+         provider
+         for provider in onnx_execution_providers
+         if provider in runtime_x_ray.available_onnx_execution_providers
+     ]
+     if not onnx_execution_providers:
+         if providers_auto_selected:
+             reason = (
+                 "Incorrect ONNX backend installation - none of the default ONNX Execution Providers "
+                 "are available in the environment"
+             )
+         else:
+             reason = (
+                 "None of the selected ONNX Execution Providers were detected in the runtime environment - consider "
+                 "adjusting the settings"
+             )
+         verbose_info(
+             message=f"Model package with id '{model_package.package_id}' filtered out as `inference-models` could not find "
+             f"matching execution providers that are available in runtime to run a model.",
+             verbose_requested=verbose,
+         )
+         return False, reason
+     incompatible_providers = model_package.onnx_package_details.incompatible_providers
+     if incompatible_providers is None:
+         incompatible_providers = []
+     incompatible_providers = set(incompatible_providers)
+     if onnx_execution_providers[0] in incompatible_providers:
+         # checking the first one only - this is kind of a heuristic, as
+         # there may be a fallback - so theoretically it is possible
+         # for this function to claim that a package is compatible, but a specific
+         # operation in the graph may fall back to another EP - that's
+         # a situation deeply specific to a model and if we see this being
+         # problematic, we will implement a solution. So far, to counter-act errors
+         # which can only be determined at runtime, we may either expect the model implementation
+         # to run a test inference in init or the user to define a specific model package ID
+         # to run. Not great, not terrible, yet I can expect this to be a basis of heated
+         # debate some time in the future :)
+         verbose_info(
+             message=f"Model package with id '{model_package.package_id}' filtered out as execution provider "
+             f"which is selected as primary one ('{onnx_execution_providers[0]}') is listed as incompatible "
+             f"for model package.",
+             verbose_requested=verbose,
+         )
+         return (
+             False,
+             f"Model package cannot be run with default ONNX Execution Provider: {onnx_execution_providers[0]}",
+         )
+     package_opset = model_package.onnx_package_details.opset
+     onnx_runtime_simple_version = Version(
+         f"{runtime_x_ray.onnxruntime_version.major}.{runtime_x_ray.onnxruntime_version.minor}"
+     )
+     if onnx_runtime_simple_version not in ONNX_RUNTIME_OPSET_COMPATIBILITY:
+         if package_opset <= ONNX_RUNTIME_OPSET_COMPATIBILITY[Version("1.15")]:
+             return True, None
+         verbose_info(
+             message=f"Model package with id '{model_package.package_id}' filtered out as onnxruntime version "
+             f"detected ({runtime_x_ray.onnxruntime_version}) could not be resolved with the matching "
+             f"onnx opset. The auto-negotiation assumes that in such a case, the maximum supported opset is 19.",
+             verbose_requested=verbose,
+         )
+         return (
+             False,
+             "ONNX model package was compiled with opset higher than supported for installed ONNX backend",
+         )
+     max_supported_opset = ONNX_RUNTIME_OPSET_COMPATIBILITY[onnx_runtime_simple_version]
+     if package_opset > max_supported_opset:
+         verbose_info(
+             message=f"Model package with id '{model_package.package_id}' filtered out as onnxruntime version "
+             f"detected ({runtime_x_ray.onnxruntime_version}) can only run onnx packages with opset "
+             f"up to {max_supported_opset}, but the package opset is {package_opset}.",
+             verbose_requested=verbose,
+         )
+         return (
+             False,
+             "ONNX model package was compiled with opset higher than supported for installed ONNX backend",
+         )
+     return True, None
+
+
+ def torch_package_matches_runtime_environment(
+     model_package: ModelPackageMetadata,
+     runtime_x_ray: RuntimeXRayResult,
+     device: Optional[torch.device] = None,
+     onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+     trt_engine_host_code_allowed: bool = True,
+     verbose: bool = False,
+ ) -> Tuple[bool, Optional[str]]:
+     if not runtime_x_ray.torch_available:
+         verbose_info(
+             message=f"Model package with id '{model_package.package_id}' filtered out as torch not detected",
+             verbose_requested=verbose,
+         )
+         return (
+             False,
+             "Torch backend not installed - consider installing relevant torch extras: "
+             "`torch-cpu`, `torch-cu118`, `torch-cu124`, `torch-cu126`, `torch-cu128` or `torch-jp6-cu126` "
+             "depending on the hardware you run `inference-models` on",
+         )
+     return True, None
+
+
+ def hf_transformers_package_matches_runtime_environment(
+     model_package: ModelPackageMetadata,
+     runtime_x_ray: RuntimeXRayResult,
+     device: Optional[torch.device] = None,
+     onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+     trt_engine_host_code_allowed: bool = True,
+     verbose: bool = False,
+ ) -> Tuple[bool, Optional[str]]:
+     if not runtime_x_ray.hf_transformers_available:
+         verbose_info(
+             message=f"Model package with id '{model_package.package_id}' filtered out as transformers not detected",
+             verbose_requested=verbose,
+         )
+         return False, (
+             "Transformers backend not installed - this package should be installed by default and probably "
+             "was accidentally deleted - install the `inference-models` package again."
+         )
+     return True, None
+
+
+ def ultralytics_package_matches_runtime_environment(
+     model_package: ModelPackageMetadata,
+     runtime_x_ray: RuntimeXRayResult,
+     device: Optional[torch.device] = None,
+     onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+     trt_engine_host_code_allowed: bool = True,
+     verbose: bool = False,
+ ) -> Tuple[bool, Optional[str]]:
+     if not runtime_x_ray.ultralytics_available:
+         verbose_info(
+             message=f"Model package with id '{model_package.package_id}' filtered out as ultralytics not detected",
+             verbose_requested=verbose,
+         )
+         return False, "Ultralytics backend not installed"
+     return True, None
+
+
+ def mediapipe_package_matches_runtime_environment(
+     model_package: ModelPackageMetadata,
+     runtime_x_ray: RuntimeXRayResult,
+     device: Optional[torch.device] = None,
+     onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+     trt_engine_host_code_allowed: bool = True,
+     verbose: bool = False,
+ ) -> Tuple[bool, Optional[str]]:
+     if not runtime_x_ray.mediapipe_available:
+         verbose_info(
+             message=f"Model package with id '{model_package.package_id}' filtered out as mediapipe not detected",
+             verbose_requested=verbose,
+         )
+         return False, "Mediapipe backend not installed"
+     return True, None
+
+
+ def trt_package_matches_runtime_environment(
+     model_package: ModelPackageMetadata,
+     runtime_x_ray: RuntimeXRayResult,
+     device: Optional[torch.device] = None,
+     onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+     trt_engine_host_code_allowed: bool = True,
+     verbose: bool = False,
+ ) -> Tuple[bool, Optional[str]]:
+     if not runtime_x_ray.trt_version:
+         verbose_info(
+             message=f"Model package with id '{model_package.package_id}' filtered out as TRT libraries not detected",
+             verbose_requested=verbose,
+         )
+         return False, "TRT backend not installed. Consider installing `trt10` extras."
+     if not runtime_x_ray.trt_python_package_available:
+         verbose_info(
+             message=f"Model package with id '{model_package.package_id}' filtered out as TRT python package not available",
+             verbose_requested=verbose,
+         )
+         return (
+             False,
+             "Model package metadata delivered by weights provider lacks required TRT package details",
+         )
+     if model_package.environment_requirements is None:
+         verbose_info(
+             message=f"Model package with id '{model_package.package_id}' filtered out as environment requirements "
+             f"not provided by backend.",
+             verbose_requested=verbose,
+         )
+         return (
+             False,
+             "Model package metadata delivered by weights provider lacks required TRT package details",
+         )
+     trt_compiled_with_cc_compatibility = False
+     if model_package.trt_package_details is not None:
+         trt_compiled_with_cc_compatibility = (
+             model_package.trt_package_details.same_cc_compatible
+         )
+     trt_forward_compatible = False
+     if model_package.trt_package_details is not None:
+         trt_forward_compatible = (
+             model_package.trt_package_details.trt_forward_compatible
+         )
+     trt_lean_runtime_excluded = False
+     if model_package.trt_package_details is not None:
+         trt_lean_runtime_excluded = (
+             model_package.trt_package_details.trt_lean_runtime_excluded
+         )
+     model_environment = model_package.environment_requirements
+     if isinstance(model_environment, JetsonEnvironmentRequirements):
+         if model_environment.trt_version is None:
+             verbose_info(
+                 message=f"Model package with id '{model_package.package_id}' filtered out as model TRT version not provided by backend",
+                 verbose_requested=verbose,
+             )
+             return (
+                 False,
+                 "Model package metadata delivered by weights provider lacks required TRT package details",
+             )
+         if runtime_x_ray.l4t_version is None:
+             verbose_info(
+                 message=f"Model package with id '{model_package.package_id}' filtered out as runtime environment does not declare L4T version",
+                 verbose_requested=verbose,
+             )
+             return (
+                 False,
+                 "Model package metadata delivered by weights provider lacks required TRT package details",
+             )
+         device_compatibility = verify_trt_package_compatibility_with_cuda_device(
+             all_available_cuda_devices=runtime_x_ray.gpu_devices,
+             all_available_devices_cc=runtime_x_ray.gpu_devices_cc,
+             compilation_device=model_environment.cuda_device_name,
+             compilation_device_cc=model_environment.cuda_device_cc,
+             selected_device=device,
+             trt_compiled_with_cc_compatibility=trt_compiled_with_cc_compatibility,
+         )
+         if not device_compatibility:
+             verbose_info(
+                 message=f"Model package with id '{model_package.package_id}' filtered out due to device incompatibility.",
+                 verbose_requested=verbose,
+             )
+             return False, "TRT model package is incompatible with your hardware"
+         if not verify_versions_up_to_major_and_minor(
+             runtime_x_ray.l4t_version, model_environment.l4t_version
+         ):
+             verbose_info(
+                 message=f"Model package with id '{model_package.package_id}' filtered out as package L4T {model_environment.l4t_version} does not match runtime L4T: {runtime_x_ray.l4t_version}",
+                 verbose_requested=verbose,
+             )
+             return False, "TRT model package is incompatible with installed L4T version"
+         if trt_forward_compatible:
+             if runtime_x_ray.trt_version < model_environment.trt_version:
+                 verbose_info(
+                     message=f"Model package with id '{model_package.package_id}' filtered out as TRT version in "
+                     f"environment ({runtime_x_ray.trt_version}) is older than engine TRT version "
+                     f"({model_environment.trt_version}) - despite engine being forward compatible, "
+                     f"TRT requires that the TRT version available in runtime is higher than or equal "
+                     f"to the one used for compilation.",
+                     verbose_requested=verbose,
+                 )
+                 return (
+                     False,
+                     "TRT model package is incompatible with installed TRT version",
+                 )
+             if trt_lean_runtime_excluded:
+                 # not supported for now
+                 verbose_info(
+                     message=f"Model package with id '{model_package.package_id}' filtered out as it was compiled to "
+                     f"be forward compatible, but with lean runtime excluded from the engine - this mode is "
+                     f"currently not supported in `inference-models`.",
+                     verbose_requested=verbose,
+                 )
+                 return False, "TRT model package is currently not supported"
+         elif not trt_engine_host_code_allowed:
+             verbose_info(
+                 message=f"Model package with id '{model_package.package_id}' filtered out as it contains TRT "
+                 f"Lean Runtime that requires potentially unsafe deserialization which is forbidden "
+                 f"in this configuration of `inference-models`. Set `trt_engine_host_code_allowed=True` if "
+                 f"you want this package to be supported.",
+                 verbose_requested=verbose,
+             )
+             return False, (
+                 "TRT model package cannot run with `trt_engine_host_code_allowed=False` - "
+                 "consider settings adjustment."
+             )
+         elif runtime_x_ray.trt_version != model_environment.trt_version:
+             verbose_info(
+                 message=f"Model package with id '{model_package.package_id}' filtered out as package trt version {model_environment.trt_version} does not match runtime trt version: {runtime_x_ray.trt_version}",
+                 verbose_requested=verbose,
+             )
+             return False, "TRT model package is incompatible with installed TRT version"
+         return True, None
1009
+ if not isinstance(model_environment, ServerEnvironmentRequirements):
1010
+ raise ModelPackageNegotiationError(
1011
+ message=f"Model package negotiation protocol not implemented for environment specification detected "
1012
+ f"in runtime. This is `inference-models` bug - raise issue: https://github.com/roboflow/inference/issues",
1013
+ help_url="https://todo",
1014
+ )
1015
+ if model_environment.trt_version is None:
1016
+ verbose_info(
1017
+ message=f"Mode package with id '{model_package.package_id}' filtered out as model TRT version not provided by backend",
1018
+ verbose_requested=verbose,
1019
+ )
1020
+ return (
1021
+ False,
1022
+ "Model package metadata delivered by weights provider lack required TRT package details",
1023
+ )
1024
+ device_compatibility = verify_trt_package_compatibility_with_cuda_device(
1025
+ all_available_cuda_devices=runtime_x_ray.gpu_devices,
1026
+ all_available_devices_cc=runtime_x_ray.gpu_devices_cc,
1027
+ compilation_device=model_environment.cuda_device_name,
1028
+ compilation_device_cc=model_environment.cuda_device_cc,
1029
+ selected_device=device,
1030
+ trt_compiled_with_cc_compatibility=trt_compiled_with_cc_compatibility,
1031
+ )
1032
+ if not device_compatibility:
1033
+ verbose_info(
1034
+ message=f"Model package with id '{model_package.package_id}' filtered out due to device incompatibility.",
1035
+ verbose_requested=verbose,
1036
+ )
1037
+ return False, "TRT model package is incompatible with your hardware"
1038
+ if trt_forward_compatible:
+ if runtime_x_ray.trt_version < model_environment.trt_version:
+ verbose_info(
+ message=f"Model package with id '{model_package.package_id}' filtered out as TRT version in "
+ f"environment ({runtime_x_ray.trt_version}) is older than engine TRT version "
+ f"({model_environment.trt_version}) - despite the engine being forward compatible, "
+ f"TRT requires that the TRT version available in runtime is greater than or equal "
+ f"to the one used for compilation.",
+ verbose_requested=verbose,
+ )
+ return False, "TRT model package is incompatible with installed TRT version"
+ if trt_lean_runtime_excluded:
+ # not supported for now
+ verbose_info(
+ message=f"Model package with id '{model_package.package_id}' filtered out as it was compiled to "
+ f"be forward compatible, but with lean runtime excluded from the engine - this mode is "
+ f"currently not supported in `inference-models`.",
+ verbose_requested=verbose,
+ )
+ return False, "TRT model package is currently not supported"
+ elif not trt_engine_host_code_allowed:
+ verbose_info(
+ message=f"Model package with id '{model_package.package_id}' filtered out as it contains TRT "
+ f"Lean Runtime that requires potentially unsafe deserialization, which is forbidden "
+ f"in this configuration of `inference-models`. Set `trt_engine_host_code_allowed=True` if "
+ f"you want this package to be supported.",
+ verbose_requested=verbose,
+ )
+ return False, (
+ "TRT model package cannot run with `trt_engine_host_code_allowed=False` - "
+ "consider adjusting the setting."
+ )
+ elif runtime_x_ray.trt_version != model_environment.trt_version:
+ verbose_info(
+ message=f"Model package with id '{model_package.package_id}' filtered out as package TRT version {model_environment.trt_version} does not match runtime TRT version: {runtime_x_ray.trt_version}",
+ verbose_requested=verbose,
+ )
+ return False, "TRT model package is incompatible with installed TRT version"
+ return True, None
+
+
+ def verify_trt_package_compatibility_with_cuda_device(
+ selected_device: Optional[torch.device],
+ all_available_cuda_devices: List[str],
+ all_available_devices_cc: List[Version],
+ compilation_device: str,
+ compilation_device_cc: Version,
+ trt_compiled_with_cc_compatibility: bool,
+ ) -> bool:
+ all_available_cuda_devices, all_available_devices_cc = (
+ filter_available_devices_with_selected_device(
+ selected_device=selected_device,
+ all_available_cuda_devices=all_available_cuda_devices,
+ all_available_devices_cc=all_available_devices_cc,
+ )
+ )
+ if trt_compiled_with_cc_compatibility:
+ return any(cc == compilation_device_cc for cc in all_available_devices_cc)
+ return any(dev == compilation_device for dev in all_available_cuda_devices)
+
+
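An illustrative sketch (not part of the package diff) of the two matching modes in `verify_trt_package_compatibility_with_cuda_device` above: engines built with compute-capability compatibility match on CC, while others require an exact device-name match. It assumes `packaging.version.Version` for compute capabilities; the device names and CC values are hypothetical.

    from packaging.version import Version

    # CC-compatible engine: any runtime GPU whose compute capability equals the
    # compilation device's CC is accepted.
    available_cc = [Version("8.6"), Version("9.0")]
    compiled_cc = Version("8.6")
    assert any(cc == compiled_cc for cc in available_cc)

    # Engine without CC compatibility: requires an exact device-name match.
    available_devices = ["NVIDIA GeForce RTX 4090"]
    compiled_device = "NVIDIA GeForce RTX 4090"
    assert any(dev == compiled_device for dev in available_devices)
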
+ def torch_script_package_matches_runtime_environment(
+ model_package: ModelPackageMetadata,
+ runtime_x_ray: RuntimeXRayResult,
+ device: Optional[torch.device] = None,
+ onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+ trt_engine_host_code_allowed: bool = True,
+ verbose: bool = False,
+ ) -> Tuple[bool, Optional[str]]:
+ if not runtime_x_ray.torch_available:
+ verbose_info(
+ message=f"Model package with id '{model_package.package_id}' filtered out as torch was not detected",
+ verbose_requested=verbose,
+ )
+ return (
+ False,
+ "Torch backend not installed - consider installing relevant torch extras: "
+ "`torch-cpu`, `torch-cu118`, `torch-cu124`, `torch-cu126`, `torch-cu128` or `torch-jp6-cu126`, "
+ "depending on the hardware you run `inference-models` on",
+ )
+ if model_package.torch_script_package_details is None:
+ verbose_info(
+ message=f"Model package with id '{model_package.package_id}' filtered out as TorchScript package details "
+ f"were not provided by the backend.",
+ verbose_requested=verbose,
+ )
+ return (
+ False,
+ "Model package metadata delivered by weights provider lacks required TorchScript package details",
+ )
+ if device is None:
+ verbose_info(
+ message=f"Model package with id '{model_package.package_id}' filtered out as auto-negotiation does not "
+ f"specify the `device` parameter, which makes it impossible to match with compatible devices registered "
+ f"for the package.",
+ verbose_requested=verbose,
+ )
+ return (
+ False,
+ "Auto-negotiation was run with `device=None`, which makes it impossible to match the request with a TorchScript "
+ "model package. Specify the device.",
+ )
+ supported_device_types = (
+ model_package.torch_script_package_details.supported_device_types
+ )
+ if device.type not in supported_device_types:
+ verbose_info(
+ message=f"Model package with id '{model_package.package_id}' filtered out as the requested device type "
+ f"is {device.type}, whereas the model package is compatible with the following devices: "
+ f"{supported_device_types}.",
+ verbose_requested=verbose,
+ )
+ return (
+ False,
+ f"Model package is supported with the following device types: {supported_device_types}, but "
+ f"auto-negotiation requested a model for a device of type: {device.type}",
+ )
+ if not runtime_x_ray.torch_version:
+ verbose_info(
+ message=f"Model package with id '{model_package.package_id}' filtered out as it was not possible "
+ f"to extract torch version from the environment. This may be a problem worth reporting at "
+ f"https://github.com/roboflow/inference/issues/",
+ verbose_requested=verbose,
+ )
+ return (
+ False,
+ f"Model package is not supported when torch version cannot be determined. This may be a bug - "
+ f"please report: https://github.com/roboflow/inference/issues/",
+ )
+ requested_torch_version = model_package.torch_script_package_details.torch_version
+ if runtime_x_ray.torch_version < requested_torch_version:
+ verbose_info(
+ message=f"Model package with id '{model_package.package_id}' filtered out as it requires torch in version "
+ f"at least {requested_torch_version}, but version {runtime_x_ray.torch_version} is installed. "
+ f"Consider upgrading torch.",
+ verbose_requested=verbose,
+ )
+ return (
+ False,
+ f"Model package requires torch in version at least {requested_torch_version}, but your environment "
+ f"has the following version installed: {runtime_x_ray.torch_version} - consider upgrading torch.",
+ )
+ requested_torch_vision_version = (
+ model_package.torch_script_package_details.torch_vision_version
+ )
+ if requested_torch_vision_version is None:
+ return True, None
+ if runtime_x_ray.torchvision_version is None:
+ verbose_info(
+ message=f"Model package with id '{model_package.package_id}' filtered out as it was not possible "
+ f"to extract torchvision version from the environment. This may be a problem worth reporting at "
+ f"https://github.com/roboflow/inference/issues/",
+ verbose_requested=verbose,
+ )
+ return (
+ False,
+ f"Model package is not supported when torchvision version cannot be determined. This may be a bug - "
+ f"please report: https://github.com/roboflow/inference/issues/",
+ )
+ if runtime_x_ray.torchvision_version < requested_torch_vision_version:
+ verbose_info(
+ message=f"Model package with id '{model_package.package_id}' filtered out as it requires torchvision in "
+ f"version at least {requested_torch_vision_version}, but version "
+ f"{runtime_x_ray.torchvision_version} is installed. Consider upgrading torchvision.",
+ verbose_requested=verbose,
+ )
+ return (
+ False,
+ f"Model package requires torchvision in version at least {requested_torch_vision_version}, but your "
+ f"environment has the following version installed: {runtime_x_ray.torchvision_version} - consider "
+ f"upgrading torchvision.",
+ )
+ return True, None
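The torch and torchvision checks above rely on `Version` ordering, so the comparison follows standard packaging semantics. A minimal sketch, assuming `packaging.version.Version` and hypothetical version numbers:

    from packaging.version import Version

    # Runtime torch 2.3.1 satisfies a package that requires at least 2.2.0 ...
    assert Version("2.3.1") >= Version("2.2.0")
    # ... while runtime torch 2.1.0 would cause the package to be filtered out.
    assert Version("2.1.0") < Version("2.2.0")
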
+
+
+ def verify_versions_up_to_major_and_minor(x: Version, y: Version) -> bool:
+ x_simplified = Version(f"{x.major}.{x.minor}")
+ y_simplified = Version(f"{y.major}.{y.minor}")
+ return x_simplified == y_simplified
+
+
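A quick illustration of the helper above: only the major and minor components are compared, so patch releases are treated as equivalent. Values are hypothetical; `Version` is assumed to come from `packaging.version`.

    from packaging.version import Version

    verify_versions_up_to_major_and_minor(Version("1.17.3"), Version("1.17.0"))  # -> True
    verify_versions_up_to_major_and_minor(Version("1.17.3"), Version("1.18.0"))  # -> False
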
+ MODEL_TO_RUNTIME_COMPATIBILITY_MATCHERS = {
+ BackendType.HF: hf_transformers_package_matches_runtime_environment,
+ BackendType.TRT: trt_package_matches_runtime_environment,
+ BackendType.ONNX: onnx_package_matches_runtime_environment,
+ BackendType.TORCH: torch_package_matches_runtime_environment,
+ BackendType.ULTRALYTICS: ultralytics_package_matches_runtime_environment,
+ BackendType.TORCH_SCRIPT: torch_script_package_matches_runtime_environment,
+ BackendType.MEDIAPIPE: mediapipe_package_matches_runtime_environment,
+ }
+
+
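A hypothetical dispatch sketch for the registry above (the real call site lives elsewhere in auto_negotiation.py; the `backend_type` attribute and the assumption that every matcher shares the keyword signature of the TorchScript matcher are illustrative, not confirmed by this diff):

    # Pick the matcher for the package's backend and evaluate the shared
    # (compatible, reason) return contract.
    matcher = MODEL_TO_RUNTIME_COMPATIBILITY_MATCHERS[model_package.backend_type]
    compatible, reason = matcher(
        model_package=model_package,
        runtime_x_ray=runtime_x_ray,
        device=device,
        onnx_execution_providers=onnx_execution_providers,
        trt_engine_host_code_allowed=trt_engine_host_code_allowed,
        verbose=verbose,
    )
    if not compatible:
        # `reason` carries the human-readable rejection message returned by the matcher.
        print(f"Package '{model_package.package_id}' rejected: {reason}")
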
+ def range_within_other(
+ external_range: Tuple[int, int],
+ internal_range: Tuple[int, int],
+ ) -> bool:
+ external_min, external_max = external_range
+ internal_min, internal_max = internal_range
+ return external_min <= internal_min <= internal_max <= external_max
+
+
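Illustrative values for the containment check above; it returns True only when the internal range lies entirely within the external one:

    range_within_other(external_range=(1, 32), internal_range=(4, 16))  # -> True
    range_within_other(external_range=(1, 8), internal_range=(4, 16))   # -> False
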
+ def parse_batch_size(
+ requested_batch_size: Union[int, Tuple[int, int]],
+ ) -> Tuple[int, int]:
+ if isinstance(requested_batch_size, tuple):
+ if len(requested_batch_size) != 2:
+ raise InvalidRequestedBatchSizeError(
+ message="Could not parse the batch size requested in the model package negotiation procedure. "
+ "The batch size request is supposed to be either an integer value or a tuple specifying (min, max) "
+ f"batch size - but a tuple of invalid size ({len(requested_batch_size)}) was detected - this is "
+ f"probably a typo while specifying the requested batch size.",
+ help_url="https://todo",
+ )
+ min_batch_size, max_batch_size = requested_batch_size
+ if not isinstance(min_batch_size, int) or not isinstance(max_batch_size, int):
+ raise InvalidRequestedBatchSizeError(
+ message="Could not parse the batch size requested in the model package negotiation procedure. "
+ "The batch size request is supposed to be either an integer value or a tuple specifying (min, max) "
+ "batch size - but the tuple contains elements which are not integer values - this is "
+ "probably a typo while specifying the requested batch size.",
+ help_url="https://todo",
+ )
+ if max_batch_size < min_batch_size:
+ raise InvalidRequestedBatchSizeError(
+ message="Could not parse the batch size requested in the model package negotiation procedure. "
+ "`max_batch_size` is lower than `min_batch_size` - which is an invalid value - this is "
+ "probably a typo while specifying the requested batch size.",
+ help_url="https://todo",
+ )
+ if max_batch_size <= 0 or min_batch_size <= 0:
+ raise InvalidRequestedBatchSizeError(
+ message="Could not parse the batch size requested in the model package negotiation procedure. "
+ "`min_batch_size` is <= 0 or `max_batch_size` is <= 0 - which is an invalid value - this is "
+ "probably a typo while specifying the requested batch size.",
+ help_url="https://todo",
+ )
+ return min_batch_size, max_batch_size
+ if not isinstance(requested_batch_size, int):
+ raise InvalidRequestedBatchSizeError(
+ message="Could not parse the batch size requested in the model package negotiation procedure. "
+ "The batch size request is supposed to be either an integer value or a tuple specifying (min, max) "
+ f"batch size - but a single value was detected which is not an integer but has type "
+ f"{requested_batch_size.__class__.__name__} - this is "
+ f"probably a typo while specifying the requested batch size.",
+ help_url="https://todo",
+ )
+ if requested_batch_size <= 0:
+ raise InvalidRequestedBatchSizeError(
+ message="Could not parse the batch size requested in the model package negotiation procedure. "
+ "`requested_batch_size` is <= 0, which is an invalid value - this is "
+ "probably a typo while specifying the requested batch size.",
+ help_url="https://todo",
+ )
+ return requested_batch_size, requested_batch_size
+
+
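How `parse_batch_size` above normalises the two accepted input shapes (illustrative values):

    parse_batch_size(8)        # -> (8, 8)
    parse_batch_size((1, 16))  # -> (1, 16)
    # Non-integers, non-positive values, tuples of the wrong length, or a (min, max)
    # tuple with min > max raise InvalidRequestedBatchSizeError.
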
+ def parse_backend_type(value: str) -> BackendType:
+ try:
+ return BackendType(value.lower())
+ except ValueError as error:
+ supported_backends = [e.value for e in BackendType]
+ raise UnknownBackendTypeError(
+ message=f"Requested backend of type '{value}', which is not recognized by `inference-models`. Most likely this "
+ f"error is a result of a typo while specifying the requested backend. Supported backends: "
+ f"{supported_backends}.",
+ help_url="https://todo",
+ ) from error
+
+
+ def parse_requested_quantization(
+ value: Union[str, Quantization, List[Union[str, Quantization]]],
+ ) -> Set[Quantization]:
+ if not isinstance(value, list):
+ value = [value]
+ result = set()
+ for element in value:
+ if isinstance(element, str):
+ element = parse_quantization(value=element)
+ result.add(element)
+ return result
+
+
+ def parse_quantization(value: str) -> Quantization:
+ try:
+ return Quantization(value)
+ except ValueError as error:
+ raise UnknownQuantizationError(
+ message=f"Requested quantization of type '{value}', which is not recognized by `inference-models`. Most likely this "
+ f"error is a result of a typo while specifying the requested quantization. Supported values: "
+ f"{list(Quantization.__members__)}.",
+ help_url="https://todo",
+ ) from error
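A usage sketch for the parsing helpers above; the concrete strings accepted depend on the `BackendType` and `Quantization` enums defined elsewhere in the package, so the values below are assumptions rather than documented inputs:

    parse_backend_type("onnx")              # -> BackendType.ONNX, assuming "onnx" is a declared enum value
    parse_backend_type("no-such-backend")   # raises UnknownBackendTypeError listing supported backends
    parse_requested_quantization(["fp32"])  # -> a set of Quantization members, assuming "fp32" is a declared value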