inference-models 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195) hide show
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,601 @@
1
+ import json
2
+ from typing import Annotated, Callable, Dict, List, Literal, Optional, Union
3
+
4
+ import backoff
5
+ import requests
6
+ from packaging.version import InvalidVersion, Version
7
+ from pydantic import BaseModel, Discriminator, Field, ValidationError
8
+ from requests import Response, Timeout
9
+
10
+ from inference_models.configuration import (
11
+ API_CALLS_MAX_TRIES,
12
+ API_CALLS_TIMEOUT,
13
+ IDEMPOTENT_API_REQUEST_CODES_TO_RETRY,
14
+ ROBOFLOW_API_HOST,
15
+ ROBOFLOW_API_KEY,
16
+ )
17
+ from inference_models.errors import (
18
+ BaseInferenceError,
19
+ ModelMetadataConsistencyError,
20
+ ModelMetadataHandlerNotImplementedError,
21
+ ModelRetrievalError,
22
+ RetryError,
23
+ UnauthorizedModelAccessError,
24
+ )
25
+ from inference_models.logger import LOGGER
26
+ from inference_models.models.auto_loaders.entities import BackendType
27
+ from inference_models.weights_providers.entities import (
28
+ FileDownloadSpecs,
29
+ JetsonEnvironmentRequirements,
30
+ ModelDependency,
31
+ ModelMetadata,
32
+ ModelPackageMetadata,
33
+ ONNXPackageDetails,
34
+ Quantization,
35
+ ServerEnvironmentRequirements,
36
+ TorchScriptPackageDetails,
37
+ TRTPackageDetails,
38
+ )
39
+
40
# Safety cap on the pagination loop in `get_model_metadata` when fetching
# model-package pages from the Roboflow API.
MAX_MODEL_PACKAGE_PAGES = 10
# Package manifest types that `inference-models` deliberately does not load;
# packages with these manifests are skipped (debug log only, no warning).
MODEL_PACKAGES_TO_IGNORE = {
    "oak-model-package-v1",
    "tfjs-model-package-v1",
}
45
+
46
+
47
class RoboflowModelPackageFile(BaseModel):
    """Single downloadable artefact belonging to a Roboflow model package."""

    file_handle: str = Field(alias="fileHandle")  # logical file name within the package
    download_url: str = Field(alias="downloadUrl")  # URL to fetch the artefact from
    # Optional integrity checksum; not every file exposes one.
    md5_hash: Optional[str] = Field(alias="md5Hash", default=None)
51
+
52
+
53
class RoboflowModelPackageV1(BaseModel):
    """One model-package entry returned by the Roboflow weights endpoint.

    `package_manifest` is kept as a raw dict and parsed later by the
    manifest-type-specific parser selected in `parse_model_package_metadata`.
    """

    type: Literal["external-model-package-v1"]
    package_id: str = Field(alias="packageId")
    package_manifest: dict = Field(alias="packageManifest")
    model_features: Optional[dict] = Field(alias="modelFeatures", default=None)
    package_files: List[RoboflowModelPackageFile] = Field(alias="packageFiles")
    # Whether the package originates from a source Roboflow marks as trusted.
    trusted_source: bool = Field(alias="trustedSource", default=False)
60
+
61
+
62
class RoboflowModelDependencyV1(BaseModel):
    """Declaration that a model depends on another Roboflow model."""

    type: Literal["model-dependency-v1"]
    name: str  # dependency name as provided by the API
    model_id: str = Field(alias="modelId")
    # Optional pin to a concrete package of the dependency model.
    model_package_id: Optional[str] = Field(alias="modelPackageId", default=None)
67
+
68
+
69
class RoboflowModelMetadata(BaseModel):
    """Top-level model metadata document returned by the Roboflow weights endpoint.

    Represents a single page of results; `next_page` carries the pagination
    cursor when more model packages are available.
    """

    type: Literal["external-model-metadata-v1"]
    model_id: str = Field(alias="modelId")
    model_architecture: str = Field(alias="modelArchitecture")
    model_variant: Optional[str] = Field(alias="modelVariant", default=None)
    task_type: Optional[str] = Field(alias="taskType", default=None)
    # Other models this model depends on; None when no dependencies declared.
    model_dependencies: Optional[List[RoboflowModelDependencyV1]] = Field(
        alias="modelDependencies", default=None
    )
    # Raw dict entries are package types unknown to this library version.
    model_packages: List[Union[RoboflowModelPackageV1, dict]] = Field(
        alias="modelPackages",
    )
    # Pagination cursor; None on the last page.
    next_page: Optional[str] = Field(alias="nextPage", default=None)
82
+
83
+
84
def get_roboflow_model(model_id: str, api_key: Optional[str] = None) -> ModelMetadata:
    """Fetch model metadata from the Roboflow API as internal `ModelMetadata`.

    Model packages that cannot be parsed (unknown or deliberately ignored
    manifest types) are dropped from the result.

    Args:
        model_id: Identifier of the model to retrieve.
        api_key: Optional Roboflow API key; configured default used when None.

    Returns:
        Internal `ModelMetadata` representation of the model.
    """
    raw_metadata = get_model_metadata(model_id=model_id, api_key=api_key)
    # Unparseable packages yield None from the parser and are filtered out.
    parsed_packages = [
        parsed
        for raw_package in raw_metadata.model_packages
        if (parsed := parse_model_package_metadata(metadata=raw_package)) is not None
    ]
    dependencies = None
    if raw_metadata.model_dependencies:
        dependencies = [
            ModelDependency(
                name=declared.name,
                model_id=declared.model_id,
                model_package_id=declared.model_package_id,
            )
            for declared in raw_metadata.model_dependencies
        ]
    return ModelMetadata(
        model_id=raw_metadata.model_id,
        model_architecture=raw_metadata.model_architecture,
        model_packages=parsed_packages,
        task_type=raw_metadata.task_type,
        model_variant=raw_metadata.model_variant,
        model_dependencies=dependencies,
    )
111
+
112
+
113
def get_model_metadata(
    model_id: str,
    api_key: Optional[str],
    max_pages: int = MAX_MODEL_PACKAGE_PAGES,
) -> RoboflowModelMetadata:
    """Fetch all pages of model metadata and merge packages into one document.

    Follows the `next_page` cursor up to `max_pages` pages, then returns the
    last page with its `model_packages` replaced by the packages collected
    from every page.

    Raises:
        ModelRetrievalError: When the API yields no loadable model packages.
    """
    if api_key is None:
        api_key = ROBOFLOW_API_KEY
    pages: List[RoboflowModelMetadata] = []
    cursor: Optional[str] = None
    for _ in range(max_pages):
        page = get_one_page_of_model_metadata(
            model_id=model_id, api_key=api_key, start_after=cursor
        )
        pages.append(page)
        cursor = page.next_page
        if cursor is None:
            break
    collected_packages = [
        package for page in pages for package in page.model_packages
    ]
    if not pages or not collected_packages:
        raise ModelRetrievalError(
            message=f"Could not retrieve model {model_id} from Roboflow API. Backend provided empty list of model "
            f"packages `inference-models` library could load. Contact Roboflow to solve the problem.",
            help_url="https://todo",
        )
    merged = pages[-1]
    merged.model_packages = collected_packages
    return merged
141
+
142
+
143
@backoff.on_exception(
    backoff.expo,
    exception=RetryError,
    max_tries=API_CALLS_MAX_TRIES,
)
def get_one_page_of_model_metadata(
    model_id: str,
    api_key: Optional[str] = None,
    page_size: Optional[int] = None,
    start_after: Optional[str] = None,
) -> RoboflowModelMetadata:
    """Fetch a single page of model metadata from the Roboflow API.

    Retried with exponential backoff on `RetryError`, which is raised for
    connectivity problems and retryable HTTP status codes.

    Args:
        model_id: Identifier of the model to look up.
        api_key: Optional Roboflow API key attached to the request.
        page_size: Optional page size for model-packages pagination.
        start_after: Optional pagination cursor from a previous page.

    Returns:
        Parsed `RoboflowModelMetadata` for this page.

    Raises:
        RetryError: On connectivity failures or retryable response codes.
        UnauthorizedModelAccessError: On 401 / 403 responses.
        ModelRetrievalError: On other HTTP errors or undecodable payloads.
    """
    query = {
        "modelId": model_id,
    }
    if api_key:
        query["api_key"] = api_key
    if page_size:
        query["pageSize"] = page_size
    if start_after:
        query["startAfter"] = start_after
    try:
        response = requests.get(
            f"{ROBOFLOW_API_HOST}/models/v1/external/weights",
            params=query,
            timeout=API_CALLS_TIMEOUT,
        )
    except (OSError, Timeout, requests.exceptions.ConnectionError) as connectivity_error:
        # Chain the original exception so the root cause stays visible in
        # tracebacks (the original raise discarded it).
        raise RetryError(
            message="Connectivity error",
            help_url="https://todo",
        ) from connectivity_error
    handle_response_errors(response=response, operation_name="get model weights")
    try:
        return RoboflowModelMetadata.model_validate(response.json()["modelMetadata"])
    except (ValueError, ValidationError, KeyError) as error:
        # TODO: either handle here or fix the API, which returns 200 with
        #  content {error: "endpoint not found"} when the endpoint is unavailable.
        raise ModelRetrievalError(
            message=f"Could not decode Roboflow API response when trying to retrieve model {model_id}. If that problem "
            f"is not ephemeral - contact Roboflow.",
            help_url="https://todo",
        ) from error
184
+
185
+
186
def handle_response_errors(response: Response, operation_name: str) -> None:
    """Raise an appropriate error for failed Roboflow API responses.

    Returns silently for success status codes; otherwise maps the status code
    to an unauthorised-access error, a retryable error, or a terminal
    retrieval error (in that order of precedence).
    """
    status = response.status_code
    if status in (401, 403):
        raise UnauthorizedModelAccessError(
            message=f"Could not {operation_name}. Request unauthorised. Are you sure you use valid Roboflow API key? "
            "See details here: https://docs.roboflow.com/api-reference/authentication and "
            "export key to `ROBOFLOW_API_KEY` environment variable",
            help_url="https://todo",
        )
    if status in IDEMPOTENT_API_REQUEST_CODES_TO_RETRY:
        # Retryable codes surface as RetryError so backoff decorators can act.
        raise RetryError(
            message=f"Roboflow API returned invalid response code for {operation_name} operation "
            f"{status}. If that problem is not ephemeral - contact Roboflow.",
            help_url="https://todo",
        )
    if status >= 400:
        response_payload = get_error_response_payload(response=response)
        raise ModelRetrievalError(
            message=f"Roboflow API returned invalid response code for {operation_name} operation "
            f"{status}.\n\nResponse:\n{response_payload}",
            help_url="https://todo",
        )
207
+
208
+
209
def get_error_response_payload(response: Response) -> str:
    """Render an error response body for inclusion in exception messages.

    Pretty-prints JSON bodies; falls back to the raw text when the body is
    not valid JSON.
    """
    try:
        decoded = response.json()
        rendered = json.dumps(decoded, indent=4)
    except ValueError:
        return response.text
    return rendered
214
+
215
+
216
def parse_model_package_metadata(
    metadata: Union[RoboflowModelPackageV1, dict],
) -> Optional[ModelPackageMetadata]:
    """Translate one raw model-package entity into `ModelPackageMetadata`.

    Returns None (logging a warning or debug message) for entities this
    library version cannot parse or deliberately ignores.

    Raises:
        ModelMetadataConsistencyError: When a known manifest type fails to
            parse for a reason other than a library-level inference error.
    """
    if isinstance(metadata, dict):
        # A raw dict means pydantic could not match the entity against
        # RoboflowModelPackageV1 - an entity type unknown to this version.
        entity_type = metadata.get("type", "unknown")
        package_id = metadata.get("packageId", "unknown")
        LOGGER.warning(
            "Roboflow API returned entity describing model package which cannot be parsed. This may indicate that "
            f"your `inference-models` package is outdated. "
            f"Debug info - entity type: `{entity_type}`, model package id: {package_id}"
        )
        return None
    manifest_type = metadata.package_manifest.get("type", "unknown")
    if manifest_type in MODEL_PACKAGES_TO_IGNORE:
        LOGGER.debug(
            "Ignoring model package with manifest incompatible with inference."
            f"Debug info - model package id: {metadata.package_id}, manifest type: {manifest_type}."
        )
        return None
    if manifest_type not in MODEL_PACKAGE_PARSERS:
        LOGGER.warning(
            "Roboflow API returned entity describing model package which cannot be parsed. This may indicate that "
            f"your `inference-models` package is outdated. "
            f"Debug info - package manifest type: `{manifest_type}`."
        )
        return None
    parser = MODEL_PACKAGE_PARSERS[manifest_type]
    try:
        return parser(metadata)
    except BaseInferenceError:
        # Library-level errors already carry context - propagate untouched.
        raise
    except Exception as error:
        raise ModelMetadataConsistencyError(
            message="Roboflow API returned model package metadata which cannot be parsed. Contact Roboflow to "
            f"solve the problem. Error details: {error}. Error type: {error.__class__.__name__}",
            help_url="https://todo",
        ) from error
252
+
253
+
254
class OnnxModelPackageV1(BaseModel):
    """Manifest of an ONNX model package (`onnx-model-package-v1`)."""

    type: Literal["onnx-model-package-v1"]
    backend_type: Literal["onnx"] = Field(alias="backendType")
    dynamic_batch_size: bool = Field(alias="dynamicBatchSize", default=False)
    static_batch_size: Optional[int] = Field(alias="staticBatchSize", default=None)
    quantization: Quantization
    opset: int  # ONNX opset the exported graph targets
    # Execution providers declared incompatible with this package.
    incompatible_providers: Optional[List[str]] = Field(
        alias="incompatibleProviders", default=None
    )
264
+
265
+
266
def parse_onnx_model_package(metadata: RoboflowModelPackageV1) -> ModelPackageMetadata:
    """Parse an `onnx-model-package-v1` manifest into `ModelPackageMetadata`."""
    manifest = OnnxModelPackageV1.model_validate(metadata.package_manifest)
    # Dynamic vs static batch settings must be mutually consistent.
    validate_batch_settings(
        dynamic_batch_size=manifest.dynamic_batch_size,
        static_batch_size=manifest.static_batch_size,
    )
    artefacts = parse_package_artefacts(package_artefacts=metadata.package_files)
    onnx_details = ONNXPackageDetails(
        opset=manifest.opset,
        incompatible_providers=manifest.incompatible_providers,
    )
    return ModelPackageMetadata(
        package_id=metadata.package_id,
        backend=BackendType.ONNX,
        quantization=manifest.quantization,
        dynamic_batch_size_supported=manifest.dynamic_batch_size,
        static_batch_size=manifest.static_batch_size,
        package_artefacts=artefacts,
        onnx_package_details=onnx_details,
        trusted_source=metadata.trusted_source,
        model_features=metadata.model_features,
    )
289
+
290
+
291
class JetsonMachineSpecsV1(BaseModel):
    """Specs of the Jetson device a TRT package was compiled on."""

    type: Literal["jetson-machine-specs-v1"]
    l4t_version: str = Field(alias="l4tVersion")  # Jetson Linux (L4T) release
    device_name: str = Field(alias="deviceName")
    driver_version: str = Field(alias="driverVersion")
296
+
297
+
298
class GPUServerSpecsV1(BaseModel):
    """Specs of the GPU server a TRT package was compiled on."""

    type: Literal["gpu-server-specs-v1"]
    driver_version: str = Field(alias="driverVersion")
    os_version: str = Field(alias="osVersion")
302
+
303
+
304
class TrtModelPackageV1(BaseModel):
    """Manifest of a TensorRT model package (`trt-model-package-v1`).

    TRT engines are compiled for a specific machine, so the manifest also
    carries the compile-time environment (CUDA / TRT versions, device specs).
    """

    type: Literal["trt-model-package-v1"]
    backend_type: Literal["trt"] = Field(alias="backendType")
    dynamic_batch_size: bool = Field(alias="dynamicBatchSize", default=False)
    static_batch_size: Optional[int] = Field(alias="staticBatchSize", default=None)
    # min / opt / max are required (as a triplet) when dynamic_batch_size is True -
    # enforced in `parse_trt_model_package`.
    min_batch_size: Optional[int] = Field(alias="minBatchSize", default=None)
    opt_batch_size: Optional[int] = Field(alias="optBatchSize", default=None)
    max_batch_size: Optional[int] = Field(alias="maxBatchSize", default=None)
    quantization: Quantization
    cuda_device_type: str = Field(alias="cudaDeviceType")
    cuda_device_cc: str = Field(alias="cudaDeviceCC")  # CUDA compute capability
    cuda_version: str = Field(alias="cudaVersion")
    trt_version: str = Field(alias="trtVersion")
    same_cc_compatible: bool = Field(alias="sameCCCompatible", default=False)
    trt_forward_compatible: bool = Field(alias="trtForwardCompatible", default=False)
    trt_lean_runtime_excluded: bool = Field(
        alias="trtLeanRuntimeExcluded", default=False
    )
    machine_type: Literal["gpu-server", "jetson"] = Field(alias="machineType")
    # Discriminated union on `type`; must agree with `machine_type`
    # (cross-checked in `parse_trt_model_package`).
    machine_specs: Annotated[
        Union[JetsonMachineSpecsV1, GPUServerSpecsV1],
        Discriminator(discriminator="type"),
    ] = Field(alias="machineSpecs")
327
+
328
+
329
def _resolve_trt_environment_requirements(
    parsed_manifest: TrtModelPackageV1,
) -> Union[ServerEnvironmentRequirements, JetsonEnvironmentRequirements]:
    """Build the runtime-environment requirements for a TRT package.

    Cross-checks that `machine_type` matches the concrete `machine_specs`
    variant and converts version strings via `as_version`.
    """
    if parsed_manifest.machine_type == "gpu-server":
        if not isinstance(parsed_manifest.machine_specs, GPUServerSpecsV1):
            raise ModelMetadataConsistencyError(
                message="While downloading model weights, Roboflow API provided inconsistent metadata "
                "describing model package - expected GPU Server specification for TRT model package registered as "
                "compiled on gpu-server. Contact Roboflow to solve the problem.",
                help_url="https://todo",
            )
        return ServerEnvironmentRequirements(
            cuda_device_cc=as_version(parsed_manifest.cuda_device_cc),
            cuda_device_name=parsed_manifest.cuda_device_type,
            driver_version=as_version(parsed_manifest.machine_specs.driver_version),
            cuda_version=as_version(parsed_manifest.cuda_version),
            trt_version=as_version(parsed_manifest.trt_version),
            os_version=parsed_manifest.machine_specs.os_version,
        )
    if parsed_manifest.machine_type == "jetson":
        if not isinstance(parsed_manifest.machine_specs, JetsonMachineSpecsV1):
            raise ModelMetadataConsistencyError(
                message="While downloading model weights, Roboflow API provided inconsistent metadata "
                "describing model package - expected Jetson Device specification for TRT model package registered as "
                "compiled on Jetson. Contact Roboflow to solve the problem.",
                help_url="https://todo",
            )
        return JetsonEnvironmentRequirements(
            cuda_device_cc=as_version(parsed_manifest.cuda_device_cc),
            cuda_device_name=parsed_manifest.cuda_device_type,
            l4t_version=as_version(parsed_manifest.machine_specs.l4t_version),
            jetson_product_name=parsed_manifest.machine_specs.device_name,
            cuda_version=as_version(parsed_manifest.cuda_version),
            trt_version=as_version(parsed_manifest.trt_version),
            driver_version=as_version(parsed_manifest.machine_specs.driver_version),
        )
    raise ModelMetadataHandlerNotImplementedError(
        message="While downloading model weights, Roboflow API provided metadata which are not handled by current version "
        "of inference detected while parsing TRT model package. This problem may indicate that your inference "
        "package is outdated. Try to upgrade - if that does not help, contact Roboflow to solve the problem.",
        help_url="https://todo",
    )


def parse_trt_model_package(metadata: RoboflowModelPackageV1) -> ModelPackageMetadata:
    """Parse a `trt-model-package-v1` manifest into `ModelPackageMetadata`.

    Raises:
        ModelMetadataConsistencyError: When batch-size settings or machine
            specs are internally inconsistent.
        ModelMetadataHandlerNotImplementedError: When the declared machine
            type is not handled by this library version.
    """
    parsed_manifest = TrtModelPackageV1.model_validate(metadata.package_manifest)
    validate_batch_settings(
        dynamic_batch_size=parsed_manifest.dynamic_batch_size,
        static_batch_size=parsed_manifest.static_batch_size,
    )
    # Dynamic batching requires the full min / opt / max triplet for TRT engines.
    if parsed_manifest.dynamic_batch_size is True and any(
        e is None
        for e in [
            parsed_manifest.min_batch_size,
            parsed_manifest.opt_batch_size,
            parsed_manifest.max_batch_size,
        ]
    ):
        raise ModelMetadataConsistencyError(
            message="While downloading model weights, Roboflow API provided inconsistent metadata "
            "describing model package - TRT package declared support for dynamic batch size, but did not "
            "specify min / opt / max batch size supported which is required.",
            help_url="https://todo",
        )
    environment_requirements = _resolve_trt_environment_requirements(parsed_manifest)
    package_artefacts = parse_package_artefacts(
        package_artefacts=metadata.package_files
    )
    trt_package_details = TRTPackageDetails(
        min_dynamic_batch_size=parsed_manifest.min_batch_size,
        opt_dynamic_batch_size=parsed_manifest.opt_batch_size,
        max_dynamic_batch_size=parsed_manifest.max_batch_size,
        same_cc_compatible=parsed_manifest.same_cc_compatible,
        trt_forward_compatible=parsed_manifest.trt_forward_compatible,
        trt_lean_runtime_excluded=parsed_manifest.trt_lean_runtime_excluded,
    )
    return ModelPackageMetadata(
        package_id=metadata.package_id,
        backend=BackendType.TRT,
        quantization=parsed_manifest.quantization,
        dynamic_batch_size_supported=parsed_manifest.dynamic_batch_size,
        static_batch_size=parsed_manifest.static_batch_size,
        trt_package_details=trt_package_details,
        package_artefacts=package_artefacts,
        environment_requirements=environment_requirements,
        trusted_source=metadata.trusted_source,
        model_features=metadata.model_features,
    )
412
+
413
+
414
class TorchModelPackageV1(BaseModel):
    """Manifest of a PyTorch model package (`torch-model-package-v1`)."""

    type: Literal["torch-model-package-v1"]
    backend_type: Literal["torch"] = Field(alias="backendType")
    dynamic_batch_size: bool = Field(alias="dynamicBatchSize", default=False)
    static_batch_size: Optional[int] = Field(alias="staticBatchSize", default=None)
    quantization: Quantization
420
+
421
+
422
def parse_torch_model_package(metadata: RoboflowModelPackageV1) -> ModelPackageMetadata:
    """Parse a `torch-model-package-v1` manifest into `ModelPackageMetadata`."""
    manifest = TorchModelPackageV1.model_validate(metadata.package_manifest)
    # Dynamic vs static batch settings must be mutually consistent.
    validate_batch_settings(
        dynamic_batch_size=manifest.dynamic_batch_size,
        static_batch_size=manifest.static_batch_size,
    )
    artefacts = parse_package_artefacts(package_artefacts=metadata.package_files)
    return ModelPackageMetadata(
        package_id=metadata.package_id,
        backend=BackendType.TORCH,
        quantization=manifest.quantization,
        dynamic_batch_size_supported=manifest.dynamic_batch_size,
        static_batch_size=manifest.static_batch_size,
        package_artefacts=artefacts,
        trusted_source=metadata.trusted_source,
        model_features=metadata.model_features,
    )
441
+
442
+
443
class HFModelPackageV1(BaseModel):
    """Manifest of a Hugging Face model package (`hf-model-package-v1`)."""

    type: Literal["hf-model-package-v1"]
    backend_type: Literal["hf"] = Field(alias="backendType")
    quantization: Quantization
447
+
448
+
449
def parse_hf_model_package(metadata: RoboflowModelPackageV1) -> ModelPackageMetadata:
    """Translate a Hugging Face package manifest from the Roboflow API into ModelPackageMetadata.

    HF packages carry no batch-size declaration, so only the manifest shape is validated.
    """
    manifest = HFModelPackageV1.model_validate(metadata.package_manifest)
    artefacts = parse_package_artefacts(package_artefacts=metadata.package_files)
    return ModelPackageMetadata(
        package_id=metadata.package_id,
        backend=BackendType.HF,
        quantization=manifest.quantization,
        package_artefacts=artefacts,
        trusted_source=metadata.trusted_source,
        model_features=metadata.model_features,
    )
462
+
463
+
464
def parse_ultralytics_model_package(
    metadata: RoboflowModelPackageV1,
) -> ModelPackageMetadata:
    """Translate an Ultralytics package descriptor from the Roboflow API into ModelPackageMetadata.

    No manifest schema exists for this backend, so quantization is reported as UNKNOWN
    and only the package files are converted.
    """
    artefacts = parse_package_artefacts(package_artefacts=metadata.package_files)
    return ModelPackageMetadata(
        package_id=metadata.package_id,
        backend=BackendType.ULTRALYTICS,
        package_artefacts=artefacts,
        quantization=Quantization.UNKNOWN,
        trusted_source=metadata.trusted_source,
        model_features=metadata.model_features,
    )
478
+
479
+
480
class TorchScriptModelPackageV1(BaseModel):
    """Manifest schema (v1) for a TorchScript model package as returned by the Roboflow API."""

    # Discriminator value used to dispatch to parse_torch_script_model_package.
    type: Literal["torch-script-model-package-v1"]
    backend_type: Literal["torch-script"] = Field(alias="backendType")
    # When False, static_batch_size is expected to declare the only supported batch size
    # (enforced later by validate_batch_settings).
    dynamic_batch_size: bool = Field(alias="dynamicBatchSize", default=False)
    static_batch_size: Optional[int] = Field(alias="staticBatchSize", default=None)
    quantization: Quantization
    # Device types the scripted module was exported for (e.g. CPU/GPU variants);
    # exact vocabulary comes from the API - not visible here.
    supported_device_types: List[str] = Field(alias="supportedDeviceTypes")
    # Version strings; parsed into packaging.Version by as_version downstream.
    torch_version: str = Field(alias="torchVersion")
    torch_vision_version: Optional[str] = Field(
        alias="torchVisionVersion", default=None
    )
491
+
492
+
493
def parse_torch_script_model_package(
    metadata: RoboflowModelPackageV1,
) -> ModelPackageMetadata:
    """Translate a TorchScript package manifest from the Roboflow API into ModelPackageMetadata.

    Validates manifest shape and batch-size consistency, parses the declared
    torch / torchvision versions, and records device support in
    TorchScriptPackageDetails.

    Raises:
        ModelMetadataConsistencyError: on inconsistent batch settings or an
            unparsable version string (via validate_batch_settings / as_version).
    """
    manifest = TorchScriptModelPackageV1.model_validate(metadata.package_manifest)
    validate_batch_settings(
        dynamic_batch_size=manifest.dynamic_batch_size,
        static_batch_size=manifest.static_batch_size,
    )
    artefacts = parse_package_artefacts(package_artefacts=metadata.package_files)
    # torchvision version is optional in the manifest - only parse when present.
    vision_version = (
        as_version(manifest.torch_vision_version)
        if manifest.torch_vision_version is not None
        else None
    )
    details = TorchScriptPackageDetails(
        supported_device_types=set(manifest.supported_device_types),
        torch_version=as_version(manifest.torch_version),
        torch_vision_version=vision_version,
    )
    return ModelPackageMetadata(
        package_id=metadata.package_id,
        backend=BackendType.TORCH_SCRIPT,
        dynamic_batch_size_supported=manifest.dynamic_batch_size,
        static_batch_size=manifest.static_batch_size,
        package_artefacts=artefacts,
        quantization=manifest.quantization,
        trusted_source=metadata.trusted_source,
        model_features=metadata.model_features,
        torch_script_package_details=details,
    )
525
+
526
+
527
class MediapipeModelPackageV1(BaseModel):
    """Manifest schema (v1) for a MediaPipe model package as returned by the Roboflow API."""

    # Discriminator value used to dispatch to parse_mediapipe_model_package.
    type: Literal["mediapipe-model-package-v1"]
    backend_type: Literal["mediapipe"] = Field(alias="backendType")
530
+
531
+
532
def parse_mediapipe_model_package(
    metadata: RoboflowModelPackageV1,
) -> ModelPackageMetadata:
    """Translate a MediaPipe package manifest from the Roboflow API into ModelPackageMetadata.

    The manifest carries no fields beyond its discriminator, so validation is
    performed purely for its side effect of rejecting malformed payloads.
    """
    # Validate shape only - the parsed manifest itself is not needed.
    MediapipeModelPackageV1.model_validate(metadata.package_manifest)
    artefacts = parse_package_artefacts(package_artefacts=metadata.package_files)
    return ModelPackageMetadata(
        package_id=metadata.package_id,
        backend=BackendType.MEDIAPIPE,
        package_artefacts=artefacts,
        quantization=Quantization.UNKNOWN,
        trusted_source=metadata.trusted_source,
        model_features=metadata.model_features,
    )
547
+
548
+
549
def validate_batch_settings(
    dynamic_batch_size: bool, static_batch_size: Optional[int]
) -> None:
    """Ensure a manifest's batch-size declaration is internally consistent.

    Exactly one mode must be declared: either the package supports dynamic
    batch sizes, or it ships a positive static batch size.

    Args:
        dynamic_batch_size: Whether the package declares dynamic batching.
        static_batch_size: Declared fixed batch size, if any.

    Raises:
        ModelMetadataConsistencyError: when neither mode is declared
            (no dynamic support and no positive static size) or both are
            declared at once.
    """
    if not dynamic_batch_size and (static_batch_size is None or static_batch_size <= 0):
        raise ModelMetadataConsistencyError(
            message="While downloading model weights, Roboflow API provided inconsistent metadata "
            "describing model package - model package declared not to support dynamic batch size and "
            "supported static batch size not provided. Contact Roboflow to solve the problem.",
            help_url="https://todo",
        )
    if dynamic_batch_size and static_batch_size is not None:
        # Bug fix: this branch previously raised with a copy-pasted message from the
        # branch above, describing the opposite inconsistency. The actual problem here
        # is that BOTH dynamic batching and a static batch size were declared.
        raise ModelMetadataConsistencyError(
            message="While downloading model weights, Roboflow API provided inconsistent metadata "
            "describing model package - model package declared to support dynamic batch size while "
            "also providing a static batch size. Contact Roboflow to solve the problem.",
            help_url="https://todo",
        )
566
+
567
+
568
def parse_package_artefacts(
    package_artefacts: List[RoboflowModelPackageFile],
) -> List[FileDownloadSpecs]:
    """Convert API-provided package file descriptors into download specifications."""
    specs = []
    for artefact in package_artefacts:
        specs.append(
            FileDownloadSpecs(
                download_url=artefact.download_url,
                file_handle=artefact.file_handle,
                md5_hash=artefact.md5_hash,
            )
        )
    return specs
577
+
578
+
579
# Dispatch table mapping a manifest's "type" discriminator to the parser that
# turns the raw RoboflowModelPackageV1 payload into ModelPackageMetadata.
MODEL_PACKAGE_PARSERS: Dict[
    str, Callable[[RoboflowModelPackageV1], ModelPackageMetadata]
] = {
    "onnx-model-package-v1": parse_onnx_model_package,
    "trt-model-package-v1": parse_trt_model_package,
    "torch-model-package-v1": parse_torch_model_package,
    "hf-model-package-v1": parse_hf_model_package,
    "ultralytics-model-package-v1": parse_ultralytics_model_package,
    "torch-script-model-package-v1": parse_torch_script_model_package,
    "mediapipe-model-package-v1": parse_mediapipe_model_package,
}
590
+
591
+
592
def as_version(value: str) -> Version:
    """Parse a version string from a manifest into a packaging Version.

    Raises:
        ModelMetadataConsistencyError: when the string is not a valid version
            specification (chained from the underlying InvalidVersion).
    """
    try:
        parsed = Version(value)
    except InvalidVersion as error:
        raise ModelMetadataConsistencyError(
            message="Roboflow API returned model package manifest that is expected to provide valid version specification for "
            "one of the field of package manifest, but instead provides value that cannot be parsed. This is most "
            "likely Roboflow API bug - contact Roboflow to solve the problem.",
            help_url="https://todo",
        ) from error
    return parsed