inference-models 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195):
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,413 @@
1
+ from typing import List, Optional, Union
2
+
3
+ import torch
4
+
5
+ from inference_models.models.auto_loaders.constants import (
6
+ NMS_CLASS_AGNOSTIC_KEY,
7
+ NMS_CONFIDENCE_THRESHOLD_KEY,
8
+ NMS_FUSED_FEATURE,
9
+ NMS_IOU_THRESHOLD_KEY,
10
+ NMS_MAX_DETECTIONS_KEY,
11
+ )
12
+ from inference_models.models.auto_loaders.entities import BackendType
13
+ from inference_models.models.auto_loaders.utils import (
14
+ filter_available_devices_with_selected_device,
15
+ )
16
+ from inference_models.runtime_introspection.core import x_ray_runtime_environment
17
+ from inference_models.weights_providers.entities import (
18
+ JetsonEnvironmentRequirements,
19
+ ModelPackageMetadata,
20
+ Quantization,
21
+ ServerEnvironmentRequirements,
22
+ )
23
+
24
# Relative preference tables consumed by `rank_model_packages`.
# Higher value == more preferred within that single sort dimension.

# Backend ranking: TRT engines first, custom backends last.
# ONNX and TorchScript share the same priority (an intentional tie,
# resolved by later sort-key dimensions).
BACKEND_PRIORITY = {
    BackendType.TRT: 7,
    BackendType.TORCH: 6,
    BackendType.HF: 5,
    BackendType.ONNX: 3,
    BackendType.TORCH_SCRIPT: 3,
    BackendType.MEDIAPIPE: 2,
    BackendType.ULTRALYTICS: 1,
    BackendType.CUSTOM: 0,
}
# Quantization ranking: more aggressive quantization preferred (INT8 first),
# unknown quantization ranked last.
QUANTIZATION_PRIORITY = {
    Quantization.INT8: 4,
    Quantization.FP16: 3,
    Quantization.FP32: 2,
    Quantization.UNKNOWN: 1,
}
# Keys describing how a model package handles batching.
DYNAMIC_BATCH_SIZE_KEY = "dynamic"
STATIC_BATCH_SIZE_KEY = "static"
# Dynamic batching is preferred over a fixed batch size (no input padding needed).
BATCH_SIZE_PRIORITY = {
    DYNAMIC_BATCH_SIZE_KEY: 2,
    STATIC_BATCH_SIZE_KEY: 1,
}
46
+
47
+
48
def rank_model_packages(
    model_packages: List[ModelPackageMetadata],
    selected_device: Optional[torch.device] = None,
    nms_fusion_preferences: Optional[Union[bool, dict]] = None,
) -> List[ModelPackageMetadata]:
    """Order candidate model packages from most to least preferred.

    Builds a multi-dimensional sort key per package (backend, quantization,
    trust, batching, TRT/CUDA/ONNX compatibility scores, ...) and sorts
    descending on it. Package-id rank is the final component, guaranteeing a
    deterministic order when every other criterion ties.
    """
    # NOTE(review): ranking across heterogeneous dimensions is inherently
    # arbitrary - this is likely the most fragile part of model negotiation.
    # CUDA and TRT versions are ordered oldest -> newest under the assumption
    # that incompatible versions were eliminated in an earlier stage; the
    # ranking implicitly tries to match the version closest to the current
    # one. In some cases versions above the current one rank below it, but
    # such versions are assumed compatible - otherwise they should have been
    # discarded before reaching this function.
    cuda_ranking = rank_cuda_versions(model_packages=model_packages)
    trt_ranking = rank_trt_versions(model_packages=model_packages)
    # Deterministic tie-breaker when all other scoring methods tie.
    identifiers_ranking = rank_packages_ids(model_packages=model_packages)
    ranked_entries = []
    for package, cu_rank, trt_rank, id_rank in zip(
        model_packages, cuda_ranking, trt_ranking, identifiers_ranking
    ):
        if package.dynamic_batch_size_supported:
            batch_mode = DYNAMIC_BATCH_SIZE_KEY
        else:
            batch_mode = STATIC_BATCH_SIZE_KEY
        if package.static_batch_size is None:
            static_batch_size_score = 0
        else:
            # The bigger the static batch size, the worse - requires padding.
            static_batch_size_score = -1 * package.static_batch_size
        sort_key = (
            BACKEND_PRIORITY.get(package.backend, 0),
            QUANTIZATION_PRIORITY.get(package.quantization, 0),
            package.trusted_source,
            BATCH_SIZE_PRIORITY[batch_mode],
            static_batch_size_score,
            retrieve_onnx_opset_score(package),  # the higher opset, the better
            retrieve_trt_forward_compatible_match_score(package),  # exact matches first
            retrieve_same_trt_cc_compatibility_score(package),
            retrieve_cuda_device_match_score(
                package, selected_device
            ),  # we like more direct matches
            cu_rank,
            trt_rank,
            retrieve_onnx_incompatible_providers_score(package),
            retrieve_trt_dynamic_batch_size_score(package),
            retrieve_fused_nms_rank(
                package, nms_fusion_preferences=nms_fusion_preferences
            ),
            retrieve_trt_lean_runtime_excluded_score(package),
            retrieve_jetson_device_name_match_score(package),
            retrieve_os_version_match_score(package),
            retrieve_l4t_version_match_score(package),
            retrieve_driver_version_match_score(package),
            id_rank,
        )
        ranked_entries.append((sort_key, package))
    # Stable descending sort on the key alone; packages never get compared.
    ranked_entries.sort(key=lambda entry: entry[0], reverse=True)
    return [package for _, package in ranked_entries]
114
+
115
+
116
def retrieve_onnx_opset_score(model_package: ModelPackageMetadata) -> int:
    """Return the ONNX opset as a score (higher is better); -1 for non-ONNX packages."""
    details = model_package.onnx_package_details
    return -1 if details is None else details.opset
120
+
121
+
122
def retrieve_cuda_device_match_score(
    model_package: ModelPackageMetadata,
    selected_device: Optional[torch.device] = None,
) -> int:
    """Count available CUDA devices matching the package's compilation device.

    Only meaningful for TRT packages declaring server or Jetson environment
    requirements; every other package scores a neutral 0.
    """
    if model_package.backend is not BackendType.TRT:
        return 0
    requirements = model_package.environment_requirements
    if requirements is None:
        return 0
    if not isinstance(
        requirements,
        (JetsonEnvironmentRequirements, ServerEnvironmentRequirements),
    ):
        return 0
    runtime_x_ray = x_ray_runtime_environment()
    # Narrow the device list to the caller-selected device, if any.
    available_devices, _ = filter_available_devices_with_selected_device(
        selected_device=selected_device,
        all_available_cuda_devices=runtime_x_ray.gpu_devices,
        all_available_devices_cc=runtime_x_ray.gpu_devices_cc,
    )
    compilation_device = requirements.cuda_device_name
    return sum(device == compilation_device for device in available_devices)
143
+
144
+
145
def retrieve_same_trt_cc_compatibility_score(
    model_package: ModelPackageMetadata,
) -> int:
    """Prefer TRT engines NOT built in same-compute-capability compatibility mode.

    Packages without TRT details (non-TRT backends) score 1, on par with
    exact-build engines; compatibility-mode builds score 0.
    """
    details = model_package.trt_package_details
    if details is None:
        return 1
    return 0 if details.same_cc_compatible else 1
151
+
152
+
153
def retrieve_trt_forward_compatible_match_score(
    model_package: ModelPackageMetadata,
) -> int:
    """Rank exact TRT version matches above forward-compatible engines.

    Packages without TRT details (non-TRT backends) score 1, like exact matches.
    """
    details = model_package.trt_package_details
    if details is None:
        return 1
    return 0 if details.trt_forward_compatible else 1
159
+
160
+
161
def retrieve_onnx_incompatible_providers_score(
    model_package: ModelPackageMetadata,
) -> int:
    """Penalize ONNX packages for conflicts with local execution providers.

    Returns the negated count of locally available ONNX execution providers
    that the package declares incompatible (0 = no declared conflicts).
    """
    details = model_package.onnx_package_details
    if details is None or not details.incompatible_providers:
        return 0
    runtime_x_ray = x_ray_runtime_environment()
    available_providers = set(
        runtime_x_ray.available_onnx_execution_providers or []
    )
    conflicting = available_providers.intersection(details.incompatible_providers)
    return -len(conflicting)
177
+
178
+
179
def retrieve_trt_dynamic_batch_size_score(model_package: ModelPackageMetadata) -> int:
    """Score TRT packages by the width of their dynamic batch-size window.

    Packages without TRT details, or with either window bound missing,
    score a neutral 0.
    """
    details = model_package.trt_package_details
    if details is None:
        return 0
    min_batch = details.min_dynamic_batch_size
    max_batch = details.max_dynamic_batch_size
    if min_batch is None or max_batch is None:
        return 0
    return max_batch - min_batch
194
+
195
+
196
def retrieve_trt_lean_runtime_excluded_score(
    model_package: ModelPackageMetadata,
) -> int:
    """Score 1 for TRT packages whose lean runtime is NOT excluded.

    Non-TRT packages (no TRT details) score 0 - the dimension is irrelevant.
    """
    details = model_package.trt_package_details
    if details is None:
        return 0
    return 0 if details.trt_lean_runtime_excluded else 1
202
+
203
+
204
def retrieve_os_version_match_score(model_package: ModelPackageMetadata) -> int:
    """Score 1 when a TRT package's declared OS version matches the runtime.

    Non-TRT packages, packages without server environment requirements, and
    packages without a declared OS version all score a neutral 0.
    """
    if model_package.backend is not BackendType.TRT:
        # irrelevant for non-TRT backends
        return 0
    requirements = model_package.environment_requirements
    if requirements is None:
        return 0
    if not isinstance(requirements, ServerEnvironmentRequirements):
        return 0
    if not requirements.os_version:
        return 0
    runtime_x_ray = x_ray_runtime_environment()
    return int(runtime_x_ray.os_version == requirements.os_version)
220
+
221
+
222
def retrieve_l4t_version_match_score(model_package: ModelPackageMetadata) -> int:
    """Score 1 when a Jetson TRT package's L4T version matches the runtime.

    Non-TRT packages and non-Jetson requirement variants score a neutral 0.
    """
    if model_package.backend is not BackendType.TRT:
        # irrelevant for non-TRT backends
        return 0
    requirements = model_package.environment_requirements
    if requirements is None:
        return 0
    if not isinstance(requirements, JetsonEnvironmentRequirements):
        return 0
    runtime_x_ray = x_ray_runtime_environment()
    return int(runtime_x_ray.l4t_version == requirements.l4t_version)
236
+
237
+
238
def retrieve_driver_version_match_score(model_package: ModelPackageMetadata) -> int:
    """Score 1 when the package's declared GPU driver version matches the runtime.

    Only TRT packages (those carrying TRT details) that declare Jetson or
    server environment requirements with a driver version participate;
    everything else scores a neutral 0.
    """
    if model_package.trt_package_details is None:
        # irrelevant for non-TRT packages
        return 0
    requirements = model_package.environment_requirements
    if requirements is None:
        return 0
    # Idiom fix: single tuple-based isinstance() instead of two chained
    # `not isinstance(...) and not isinstance(...)` checks.
    if not isinstance(
        requirements,
        (JetsonEnvironmentRequirements, ServerEnvironmentRequirements),
    ):
        return 0
    if not requirements.driver_version:
        return 0
    runtime_x_ray = x_ray_runtime_environment()
    return int(runtime_x_ray.driver_version == requirements.driver_version)
257
+
258
+
259
def retrieve_jetson_device_name_match_score(model_package: ModelPackageMetadata) -> int:
    """Score 1 when a Jetson TRT package was built for this Jetson module.

    Non-TRT packages and non-Jetson requirement variants score a neutral 0.
    """
    if model_package.trt_package_details is None:
        # irrelevant for non-TRT packages
        return 0
    requirements = model_package.environment_requirements
    if requirements is None:
        return 0
    if not isinstance(requirements, JetsonEnvironmentRequirements):
        return 0
    runtime_x_ray = x_ray_runtime_environment()
    return int(runtime_x_ray.jetson_type == requirements.jetson_product_name)
274
+
275
+
276
def retrieve_fused_nms_rank(
    model_package: ModelPackageMetadata,
    nms_fusion_preferences: Optional[Union[bool, dict]],
) -> Union[float, int]:
    """Score how well a package's fused-NMS settings match caller preferences.

    Args:
        model_package: candidate package; its `model_features` may carry a
            fused-NMS settings dict under `NMS_FUSED_FEATURE`.
        nms_fusion_preferences: None/False -> fused NMS not requested (score 0);
            True -> any fused-NMS package is equally good (score 1);
            dict -> each requested setting (scalar or (min, max) range) adds
            up to 1.0 to the score, highest for values near the range midpoint.

    Returns:
        0 when fused NMS is irrelevant or absent, otherwise the accumulated
        preference-match score.
    """
    if nms_fusion_preferences is None or nms_fusion_preferences is False:
        return 0
    if not model_package.model_features:
        return 0
    nms_fused = model_package.model_features.get(NMS_FUSED_FEATURE)
    if not isinstance(nms_fused, dict):
        return 0
    if nms_fusion_preferences is True:
        # Default values should be passed by filter, so here we treat every
        # fused-NMS package as equally good.
        return 1
    final_score = 0.0
    # The three numeric preferences share identical range-scoring logic -
    # deduplicated into _score_nms_range_preference(). Feature values are now
    # read lazily, so a settings dict missing a key the caller did not
    # request no longer raises KeyError.
    for settings_key in (
        NMS_MAX_DETECTIONS_KEY,
        NMS_CONFIDENCE_THRESHOLD_KEY,
        NMS_IOU_THRESHOLD_KEY,
    ):
        if settings_key in nms_fusion_preferences:
            final_score += _score_nms_range_preference(
                requested=nms_fusion_preferences[settings_key],
                actual=nms_fused[settings_key],
            )
    if NMS_CLASS_AGNOSTIC_KEY in nms_fusion_preferences:
        final_score += float(
            nms_fused[NMS_CLASS_AGNOSTIC_KEY]
            == nms_fusion_preferences[NMS_CLASS_AGNOSTIC_KEY]
        )
    return final_score


def _score_nms_range_preference(requested, actual) -> float:
    """Score `actual` against a requested scalar or (min, max) range preference."""
    if isinstance(requested, (list, tuple)):
        min_value, max_value = requested
    else:
        min_value, max_value = requested, requested
    return score_distance_from_mean(
        min_value=min_value,
        max_value=max_value,
        examined_value=actual,
    )
342
+
343
+
344
def score_distance_from_mean(
    min_value: float, max_value: float, examined_value: float
) -> float:
    """Score how close `examined_value` sits to the midpoint of a range.

    Bounds may be given in either order. Returns 1.0 at the exact midpoint,
    tapering linearly to 0.5 at either edge, and 0.0 outside the range.
    A degenerate range (min == max) scores 1.0 only on a near-exact match.
    """
    low, high = sorted((min_value, max_value))
    if low == high:
        # Degenerate range: exact-match check with a small tolerance.
        return float(abs(examined_value - high) < 1e-5)
    if not (low <= examined_value <= high):
        return 0.0
    span = high - low
    # Normalise into [0, 1]; epsilon keeps the division well-behaved.
    scaled = min(max((examined_value - low) / (span + 1e-6), 0), 1)
    return 1.0 - abs(0.5 - scaled)
355
+
356
+
357
def rank_cuda_versions(model_packages: List[ModelPackageMetadata]) -> List[int]:
    """Rank packages by declared CUDA version: newest -> 0, older -> more negative.

    Packages without a declared CUDA version receive the `last_ranking`
    sentinel, placing them below every declared version.

    Returns:
        One rank per input package, in input order.
    """
    package_id_to_cuda_version = {}
    # Sentinel below any ranked version (ranks span 0..-(k-1) for k versions).
    last_ranking = -len(model_packages) + 1
    for package in model_packages:
        # Idiom fix: the original had two duplicated branches with identical
        # bodies for the server and Jetson requirement variants - both expose
        # `cuda_version`, so a single tuple-based isinstance() suffices.
        if isinstance(
            package.environment_requirements,
            (ServerEnvironmentRequirements, JetsonEnvironmentRequirements),
        ):
            package_id_to_cuda_version[package.package_id] = (
                package.environment_requirements.cuda_version
            )
    # Assumes package ids are unique - TODO confirm against models_registry.
    cuda_versions = sorted(set(package_id_to_cuda_version.values()))
    cuda_versions_ranking = {version: -idx for idx, version in enumerate(cuda_versions)}
    return [
        cuda_versions_ranking.get(
            package_id_to_cuda_version.get(package.package_id), last_ranking
        )
        for package in model_packages
    ]
382
+
383
+
384
def rank_trt_versions(model_packages: List[ModelPackageMetadata]) -> List[int]:
    """Rank packages by declared TRT version: newest -> 0, older -> more negative.

    Packages without a declared TRT version receive the `last_ranking`
    sentinel, placing them below every declared version.

    Returns:
        One rank per input package, in input order.
    """
    package_id_to_trt_version = {}
    # Sentinel below any ranked version (ranks span 0..-(k-1) for k versions).
    last_ranking = -len(model_packages) + 1
    for package in model_packages:
        # Idiom fix: the original had two duplicated branches with identical
        # bodies for the server and Jetson requirement variants - both expose
        # `trt_version`, so a single tuple-based isinstance() suffices.
        if isinstance(
            package.environment_requirements,
            (ServerEnvironmentRequirements, JetsonEnvironmentRequirements),
        ):
            package_id_to_trt_version[package.package_id] = (
                package.environment_requirements.trt_version
            )
    # Assumes package ids are unique - TODO confirm against models_registry.
    trt_versions = sorted(set(package_id_to_trt_version.values()))
    trt_versions_ranking = {version: -idx for idx, version in enumerate(trt_versions)}
    return [
        trt_versions_ranking.get(
            package_id_to_trt_version.get(package.package_id), last_ranking
        )
        for package in model_packages
    ]
409
+
410
+
411
def rank_packages_ids(model_packages: List[ModelPackageMetadata]) -> List[int]:
    """Deterministic tie-breaker derived from lexicographic order of package ids.

    NOTE(review): this returns an argsort permutation (the i-th element is
    the index of the i-th smallest id), not each package's own rank - for
    the stated purpose of determinism any fixed injective assignment works,
    but confirm this matches the intent before relying on the ordering.
    """
    indexed_ids = list(enumerate(package.package_id for package in model_packages))
    indexed_ids.sort(key=lambda pair: pair[1])
    return [index for index, _ in indexed_ids]
@@ -0,0 +1,31 @@
1
+ from typing import List, Optional, Tuple
2
+
3
+ import torch
4
+ from packaging.version import Version
5
+
6
+ from inference_models.errors import ModelPackageNegotiationError
7
+
8
+
9
def filter_available_devices_with_selected_device(
    selected_device: Optional[torch.device],
    all_available_cuda_devices: List[str],
    all_available_devices_cc: List[Version],
) -> Tuple[List[str], List[Version]]:
    """Narrow the introspected CUDA device lists to the caller-selected device.

    Behaviour:
        * no selection -> both lists pass through unchanged;
        * non-CUDA selection (e.g. "cpu") -> empty lists, no CUDA device applies;
        * CUDA selection -> single-element lists for the selected index.

    Raises:
        ModelPackageNegotiationError: when the selected CUDA index falls
            outside the introspected device lists.
    """
    if selected_device is None:
        return all_available_cuda_devices, all_available_devices_cc
    if selected_device.type != "cuda":
        return [], []
    # torch.device("cuda") carries no index - default to the first GPU.
    device_index = selected_device.index or 0
    known_devices = min(len(all_available_cuda_devices), len(all_available_devices_cc))
    if device_index >= known_devices:
        raise ModelPackageNegotiationError(
            message=f"Model Package Negotiation algorithm received selected device: {selected_device} which "
            f"does not match runtime introspection results. If you selected device to run the model "
            f"manually - verify your choice. Otherwise, this error most likely is a bug. Create new "
            f"issue: https://github.com/roboflow/inference/issues",
            help_url="https://todo",
        )
    return (
        [all_available_cuda_devices[device_index]],
        [all_available_devices_cc[device_index]],
    )
File without changes
@@ -0,0 +1,123 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from typing import Generic, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from inference_models.models.base.types import PreprocessedInputs, RawPrediction
9
+
10
+
11
@dataclass
class ClassificationPrediction:
    """Batched single-label classification output; tensors align on the batch dim."""

    class_id: torch.Tensor  # (bs, ) - predicted class index per image
    confidence: torch.Tensor  # (bs, ) - confidence of the predicted class
    images_metadata: Optional[List[dict]] = None  # if given, list of size equal to bs
16
+
17
+
18
class ClassificationModel(ABC, Generic[PreprocessedInputs, RawPrediction]):
    """Base interface for single-label classification models.

    Concrete backends implement the three pipeline stages; `infer` (and the
    call operator) chains them: pre_process -> forward -> post_process.
    """

    @classmethod
    @abstractmethod
    def from_pretrained(
        cls, model_name_or_path: str, **kwargs
    ) -> "ClassificationModel":
        """Build a ready-to-use model from a registry name or a local path."""
        pass

    @property
    @abstractmethod
    def class_names(self) -> List[str]:
        """Class labels, index-aligned with predicted class ids."""
        pass

    def infer(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        **kwargs,
    ) -> ClassificationPrediction:
        """Run the full classification pipeline on a batch of images."""
        inputs = self.pre_process(images, **kwargs)
        raw_predictions = self.forward(inputs, **kwargs)
        return self.post_process(raw_predictions, **kwargs)

    @abstractmethod
    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        **kwargs,
    ) -> PreprocessedInputs:
        """Convert raw images into backend-specific model inputs."""
        pass

    @abstractmethod
    def forward(
        self, pre_processed_images: PreprocessedInputs, **kwargs
    ) -> RawPrediction:
        """Execute the model on pre-processed inputs."""
        pass

    @abstractmethod
    def post_process(
        self, model_results: RawPrediction, **kwargs
    ) -> ClassificationPrediction:
        """Convert raw model outputs into a `ClassificationPrediction`."""
        pass

    def __call__(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        **kwargs,
    ) -> ClassificationPrediction:
        """Alias for `infer`."""
        return self.infer(images, **kwargs)
67
+
68
+
69
@dataclass
class MultiLabelClassificationPrediction:
    """Multi-label classification output for a single image (variable-length)."""

    class_ids: torch.Tensor  # (predicted_labels_ids, ) - ids of predicted labels
    confidence: torch.Tensor  # (predicted_labels_confidence, ) - aligned with class_ids
    image_metadata: Optional[dict] = None
74
+
75
+
76
class MultiLabelClassificationModel(ABC, Generic[PreprocessedInputs, RawPrediction]):
    """Base interface for multi-label classification models.

    Mirrors `ClassificationModel`, but each image yields a variable-length
    set of (class id, confidence) pairs instead of a single top-1 class.
    """

    @classmethod
    # FIX: @abstractmethod was missing - without it, subclasses that forget to
    # override from_pretrained() inherit a concrete method whose body is
    # `pass`, silently returning None. `ClassificationModel.from_pretrained`
    # already declares it abstract; this restores consistency.
    @abstractmethod
    def from_pretrained(
        cls, model_name_or_path: str, **kwargs
    ) -> "MultiLabelClassificationModel":
        """Build a ready-to-use model from a registry name or a local path."""
        pass

    @property
    @abstractmethod
    def class_names(self) -> List[str]:
        """Class labels, index-aligned with predicted class ids."""
        pass

    def infer(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        **kwargs,
    ) -> List[MultiLabelClassificationPrediction]:
        """Run the full pipeline; returns one prediction per input image."""
        pre_processed_images = self.pre_process(images, **kwargs)
        model_results = self.forward(pre_processed_images, **kwargs)
        return self.post_process(model_results, **kwargs)

    @abstractmethod
    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        **kwargs,
    ) -> PreprocessedInputs:
        """Convert raw images into backend-specific model inputs."""
        pass

    @abstractmethod
    def forward(
        self, pre_processed_images: PreprocessedInputs, **kwargs
    ) -> RawPrediction:
        """Execute the model on pre-processed inputs."""
        pass

    @abstractmethod
    def post_process(
        self, model_results: RawPrediction, **kwargs
    ) -> List[MultiLabelClassificationPrediction]:
        """Convert raw model outputs into per-image predictions."""
        pass

    def __call__(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        **kwargs,
    ) -> List[MultiLabelClassificationPrediction]:
        """Alias for `infer`."""
        return self.infer(images, **kwargs)
@@ -0,0 +1,62 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Generic, List, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from inference_models.models.base.types import (
8
+ PreprocessedInputs,
9
+ PreprocessingMetadata,
10
+ RawPrediction,
11
+ )
12
+
13
+
14
class DepthEstimationModel(
    ABC, Generic[PreprocessedInputs, PreprocessingMetadata, RawPrediction]
):
    """Base interface for monocular depth-estimation models.

    `infer` (and the call operator) chains the three stages implemented by
    concrete backends: pre_process -> forward -> post_process. Unlike the
    classification bases, pre-processing also emits metadata that
    post-processing uses to map depth maps back to the input images.
    """

    @classmethod
    @abstractmethod
    def from_pretrained(
        cls, model_name_or_path: str, **kwargs
    ) -> "DepthEstimationModel":
        """Build a ready-to-use model from a registry name or a local path."""
        pass

    def infer(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        **kwargs,
    ) -> List[torch.Tensor]:
        """Run the full pipeline; returns one depth tensor per input image."""
        inputs, metadata = self.pre_process(images, **kwargs)
        raw_predictions = self.forward(inputs, **kwargs)
        return self.post_process(raw_predictions, metadata, **kwargs)

    @abstractmethod
    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        **kwargs,
    ) -> Tuple[PreprocessedInputs, PreprocessingMetadata]:
        """Convert raw images into model inputs plus restoration metadata."""
        pass

    @abstractmethod
    def forward(
        self, pre_processed_images: PreprocessedInputs, **kwargs
    ) -> RawPrediction:
        """Execute the model on pre-processed inputs."""
        pass

    @abstractmethod
    def post_process(
        self,
        model_results: RawPrediction,
        pre_processing_meta: PreprocessingMetadata,
        **kwargs,
    ) -> List[torch.Tensor]:
        """Convert raw outputs into per-image depth tensors."""
        pass

    def __call__(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        **kwargs,
    ) -> List[torch.Tensor]:
        """Alias for `infer`."""
        return self.infer(images, **kwargs)