inference-models 0.18.3 (inference_models-0.18.3-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,36 @@ inference_models/__init__.py
+ import os
+
+ if os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK") is None:
+     os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+ if os.environ.get("TOKENIZERS_PARALLELISM") is None:
+     os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ from inference_models.entities import ColorFormat
+ from inference_models.model_pipelines.auto_loaders.core import AutoModelPipeline
+ from inference_models.models.auto_loaders.core import AutoModel
+ from inference_models.models.base.classification import (
+     ClassificationModel,
+     ClassificationPrediction,
+     MultiLabelClassificationModel,
+     MultiLabelClassificationPrediction,
+ )
+ from inference_models.models.base.depth_estimation import DepthEstimationModel
+ from inference_models.models.base.documents_parsing import (
+     StructuredOCRModel,
+     TextOnlyOCRModel,
+ )
+ from inference_models.models.base.embeddings import TextImageEmbeddingModel
+ from inference_models.models.base.instance_segmentation import (
+     InstanceDetections,
+     InstanceSegmentationModel,
+ )
+ from inference_models.models.base.keypoints_detection import (
+     KeyPoints,
+     KeyPointsDetectionModel,
+ )
+ from inference_models.models.base.object_detection import (
+     Detections,
+     ObjectDetectionModel,
+     OpenVocabularyObjectDetectionModel,
+ )
+ from inference_models.models.base.semantic_segmentation import SemanticSegmentationModel
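The exports in `inference_models/__init__.py` define the package's public import surface. As a rough usage sketch (the model ID below is hypothetical, and the call convention assumes the `infer`/`__call__` pattern the package's models use elsewhere in this diff):

    import numpy as np
    from inference_models import AutoModel

    model = AutoModel.from_pretrained("rfdetr-base")  # hypothetical model ID
    image = np.zeros((640, 640, 3), dtype=np.uint8)   # placeholder BGR frame
    detections = model(image)                          # Detections for a detection model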
@@ -0,0 +1,72 @@ inference_models/configuration.py
+ import os
+
+ import torch
+
+ from inference_models.utils.environment import parse_comma_separated_values, str2bool
+
+ ONNXRUNTIME_EXECUTION_PROVIDERS = parse_comma_separated_values(
+     values=os.getenv(
+         "ONNXRUNTIME_EXECUTION_PROVIDERS",
+         "CUDAExecutionProvider,OpenVINOExecutionProvider,CoreMLExecutionProvider,CPUExecutionProvider",
+     )
+     .strip("[")
+     .strip("]")
+ )
+ DEFAULT_DEVICE_STR = os.getenv(
+     "DEFAULT_DEVICE",
+     ("cuda" if torch.cuda.is_available() else "cpu"),
+ )
+ DEFAULT_DEVICE = torch.device(DEFAULT_DEVICE_STR)
+ ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY")
+ API_CALLS_TIMEOUT = int(os.getenv("API_CALLS_TIMEOUT", "5"))
+ API_CALLS_MAX_TRIES = int(os.getenv("API_CALLS_MAX_TRIES", "3"))
+ IDEMPOTENT_API_REQUEST_CODES_TO_RETRY = set(
+     int(e.strip())
+     for e in os.getenv(
+         "IDEMPOTENT_API_REQUEST_CODES_TO_RETRY", "408,429,502,503,504"
+     ).split(",")
+ )
+ ROBOFLOW_ENVIRONMENT = os.getenv("ROBOFLOW_ENVIRONMENT", "prod")
+ ROBOFLOW_API_HOST = os.getenv(
+     "ROBOFLOW_API_HOST",
+     (
+         "https://api.roboflow.com"
+         if ROBOFLOW_ENVIRONMENT.lower() == "prod"
+         else "https://api.roboflow.one"
+     ),
+ )
+ RUNNING_ON_JETSON = os.getenv("RUNNING_ON_JETSON")
+ L4T_VERSION = os.getenv("L4T_VERSION")
+ INFERENCE_HOME = os.getenv("INFERENCE_HOME", "/tmp/cache")
+ DISABLE_INTERACTIVE_PROGRESS_BARS = str2bool(
+     os.getenv("DISABLE_INTERACTIVE_PROGRESS_BARS", "False")
+ )
+ LOG_LEVEL = os.getenv("LOG_LEVEL", "WARNING")
+ VERBOSE_LOG_LEVEL = os.getenv("VERBOSE_LOG_LEVEL", "INFO")
+ DISABLE_VERBOSE_LOGGER = str2bool(os.getenv("DISABLE_VERBOSE_LOGGER", "False"))
+ AUTO_LOADER_CACHE_EXPIRATION_MINUTES = int(
+     os.getenv("AUTO_LOADER_CACHE_EXPIRATION_MINUTES", "1440")
+ )
+ ALLOW_URL_INPUT = str2bool(os.getenv("ALLOW_URL_INPUT", "True"))
+ ALLOW_NON_HTTPS_URL_INPUT = str2bool(os.getenv("ALLOW_NON_HTTPS_URL_INPUT", "False"))
+ ALLOW_URL_INPUT_WITHOUT_FQDN = str2bool(
+     os.getenv("ALLOW_URL_INPUT_WITHOUT_FQDN", "False")
+ )
+ WHITELISTED_DESTINATIONS_FOR_URL_INPUT = os.getenv(
+     "WHITELISTED_DESTINATIONS_FOR_URL_INPUT"
+ )
+ if WHITELISTED_DESTINATIONS_FOR_URL_INPUT is not None:
+     WHITELISTED_DESTINATIONS_FOR_URL_INPUT = parse_comma_separated_values(
+         WHITELISTED_DESTINATIONS_FOR_URL_INPUT
+     )
+
+ BLACKLISTED_DESTINATIONS_FOR_URL_INPUT = os.getenv(
+     "BLACKLISTED_DESTINATIONS_FOR_URL_INPUT"
+ )
+ if BLACKLISTED_DESTINATIONS_FOR_URL_INPUT is not None:
+     BLACKLISTED_DESTINATIONS_FOR_URL_INPUT = parse_comma_separated_values(
+         BLACKLISTED_DESTINATIONS_FOR_URL_INPUT
+     )
+ ALLOW_LOCAL_STORAGE_ACCESS_FOR_REFERENCE_DATA = os.getenv(
+     "ALLOW_LOCAL_STORAGE_ACCESS_FOR_REFERENCE_DATA"
+ )
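`inference_models/configuration.py` reads these settings once at import time, so environment overrides must be in place before the first `inference_models` import; a minimal sketch (the cache path is an arbitrary example):

    import os

    os.environ["ONNXRUNTIME_EXECUTION_PROVIDERS"] = "CPUExecutionProvider"
    os.environ["INFERENCE_HOME"] = "/opt/inference-cache"  # example cache location
    os.environ["LOG_LEVEL"] = "DEBUG"

    import inference_models.configuration as cfg  # values above are now baked in
    print(cfg.ONNXRUNTIME_EXECUTION_PROVIDERS)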
@@ -0,0 +1,2 @@ inference_models/constants.py
+ HTTP_CODES_TO_RETRY = {408, 429, 502, 503, 504}
+ DOWNLOAD_CHUNK_SIZE = 8 * 1024
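A plausible consumption pattern for these constants — the downloader below is a hypothetical sketch, not the actual `utils/download.py` implementation:

    import requests

    from inference_models.constants import DOWNLOAD_CHUNK_SIZE, HTTP_CODES_TO_RETRY

    def fetch(url: str, target_path: str) -> None:
        # Stream the payload in DOWNLOAD_CHUNK_SIZE (8 KiB) chunks to bound memory use.
        response = requests.get(url, stream=True, timeout=5)
        if response.status_code in HTTP_CODES_TO_RETRY:
            raise RuntimeError(f"transient HTTP error {response.status_code} - retry")
        response.raise_for_status()
        with open(target_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
                f.write(chunk)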
@@ -0,0 +1,5 @@ inference_models/entities.py
+ from collections import namedtuple
+ from typing import Literal
+
+ ImageDimensions = namedtuple("ImageDimensions", ["height", "width"])
+ ColorFormat = Literal["rgb", "bgr"]
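Both entities are lightweight typing helpers, used roughly like:

    from inference_models.entities import ColorFormat, ImageDimensions

    dims = ImageDimensions(height=480, width=640)
    assert dims.height == 480
    fmt: ColorFormat = "bgr"  # only "rgb" or "bgr" type-check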
@@ -0,0 +1,137 @@ inference_models/errors.py
+ from typing import Optional
+
+
+ class BaseInferenceError(Exception):
+
+     def __init__(self, message: str, help_url: Optional[str] = None):
+         super().__init__(message)
+         self._help_url = help_url
+
+     @property
+     def help_url(self) -> Optional[str]:
+         return self._help_url
+
+     def __str__(self) -> str:
+         if self._help_url is None:
+             return super().__str__()
+         return f"{super().__str__()} - VISIT {self._help_url} FOR FURTHER SUPPORT"
+
+
+ class AssumptionError(BaseInferenceError):
+     pass
+
+
+ class EnvironmentConfigurationError(BaseInferenceError):
+     pass
+
+
+ class ModelRuntimeError(BaseInferenceError):
+     pass
+
+
+ class ModelInputError(BaseInferenceError):
+     pass
+
+
+ class RetryError(BaseInferenceError):
+     pass
+
+
+ class ModelRetrievalError(BaseInferenceError):
+     pass
+
+
+ class UntrustedFileError(BaseInferenceError):
+     pass
+
+
+ class FileHashSumMismatch(BaseInferenceError):
+     pass
+
+
+ class UnauthorizedModelAccessError(ModelRetrievalError):
+     pass
+
+
+ class ModelMetadataConsistencyError(ModelRetrievalError):
+     pass
+
+
+ class ModelMetadataHandlerNotImplementedError(ModelRetrievalError):
+     pass
+
+
+ class InvalidEnvVariable(BaseInferenceError):
+     pass
+
+
+ class ModelPackageNegotiationError(BaseInferenceError):
+     pass
+
+
+ class UnknownBackendTypeError(ModelPackageNegotiationError):
+     pass
+
+
+ class UnknownQuantizationError(ModelPackageNegotiationError):
+     pass
+
+
+ class InvalidRequestedBatchSizeError(ModelPackageNegotiationError):
+     pass
+
+
+ class RuntimeIntrospectionError(ModelPackageNegotiationError):
+     pass
+
+
+ class JetsonTypeResolutionError(RuntimeIntrospectionError):
+     pass
+
+
+ class NoModelPackagesAvailableError(ModelPackageNegotiationError):
+     pass
+
+
+ class AmbiguousModelPackageResolutionError(ModelPackageNegotiationError):
+     pass
+
+
+ class ModelLoadingError(BaseInferenceError):
+     pass
+
+
+ class InsecureModelIdentifierError(ModelLoadingError):
+     pass
+
+
+ class DirectLocalStorageAccessError(ModelLoadingError):
+     pass
+
+
+ class ModelImplementationLoaderError(ModelLoadingError):
+     pass
+
+
+ class CorruptedModelPackageError(ModelLoadingError):
+     pass
+
+
+ class MissingDependencyError(BaseInferenceError):
+     pass
+
+
+ class InvalidParameterError(BaseInferenceError):
+     pass
+
+
+ class DependencyModelParametersValidationError(ModelLoadingError):
+     pass
+
+
+ class ModelPipelineInitializationError(ModelLoadingError):
+     pass
+
+
+ class ModelPipelineNotFound(ModelPipelineInitializationError):
+     pass
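Every exception derives from `BaseInferenceError`, which appends the optional `help_url` to the rendered message, so callers can catch the whole hierarchy at once; a small sketch (the raise is illustrative):

    from inference_models.errors import BaseInferenceError, ModelLoadingError

    try:
        raise ModelLoadingError(
            message="could not load weights", help_url="https://example.com/help"
        )
    except BaseInferenceError as error:
        print(error)           # "could not load weights - VISIT https://example.com/help FOR FURTHER SUPPORT"
        print(error.help_url)  # "https://example.com/help"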
@@ -0,0 +1,52 @@ inference_models/logger.py
+ import logging
+
+ from inference_models.configuration import (
+     DISABLE_VERBOSE_LOGGER,
+     LOG_LEVEL,
+     VERBOSE_LOG_LEVEL,
+ )
+
+
+ def configure_log_level(
+     logger: logging.Logger, log_level: str, fallback_level: int
+ ) -> None:
+     log_level = getattr(logging, log_level, fallback_level)
+     logger.setLevel(log_level)
+     if not logger.handlers:
+         handler = logging.StreamHandler()
+         formatter = logging.Formatter("%(message)s")
+         handler.setFormatter(formatter)
+         logger.addHandler(handler)
+     for handler in logger.handlers:
+         handler.setLevel(log_level)
+     logger.propagate = False
+
+
+ LOGGER = logging.getLogger("inference-models")
+ configure_log_level(logger=LOGGER, log_level=LOG_LEVEL, fallback_level=logging.WARNING)
+ VERBOSE_LOGGER = logging.getLogger("inference-models-verbose")
+ configure_log_level(
+     logger=VERBOSE_LOGGER, log_level=VERBOSE_LOG_LEVEL, fallback_level=logging.INFO
+ )
+
+
+ def verbose_info(
+     message: str,
+     verbose_requested: bool = True,
+ ) -> None:
+     if DISABLE_VERBOSE_LOGGER:
+         return None
+     if not verbose_requested:
+         return None
+     VERBOSE_LOGGER.info(message)
+
+
+ def verbose_debug(
+     message: str,
+     verbose_requested: bool = True,
+ ) -> None:
+     if DISABLE_VERBOSE_LOGGER:
+         return None
+     if not verbose_requested:
+         return None
+     VERBOSE_LOGGER.debug(message)
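The verbose logger is gated twice - globally via `DISABLE_VERBOSE_LOGGER` and per call via `verbose_requested` - which is how callers such as `AutoModelPipeline.from_pretrained` plumb their `verbose` flag through; for example:

    from inference_models.logger import verbose_info

    verbose_info("resolving model package...", verbose_requested=True)   # emitted at INFO level
    verbose_info("resolving model package...", verbose_requested=False)  # silently dropped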
File without changes
@@ -0,0 +1,120 @@ inference_models/model_pipelines/auto_loaders/core.py
+ from typing import List, Optional, Union
+
+ from rich.console import Console
+ from rich.tree import Tree
+
+ from inference_models.errors import ModelPipelineInitializationError
+ from inference_models.logger import verbose_info
+ from inference_models.model_pipelines.auto_loaders.pipelines_registry import (
+     REGISTERED_PIPELINES,
+     get_default_pipeline_parameters,
+     resolve_pipeline_class,
+ )
+ from inference_models.models.auto_loaders.access_manager import ModelAccessManager
+ from inference_models.models.auto_loaders.auto_resolution_cache import (
+     AutoResolutionCache,
+ )
+ from inference_models.models.auto_loaders.core import AutoModel
+ from inference_models.models.auto_loaders.dependency_models import (
+     DependencyModelParameters,
+     prepare_dependency_model_parameters,
+ )
+ from inference_models.models.auto_loaders.entities import AnyModel
+
+
+ class AutoModelPipeline:
+
+     @classmethod
+     def list_available_pipelines(cls) -> None:
+         console = Console()
+         tree = Tree("Available Model Pipelines:")
+         for pipeline_id in sorted(REGISTERED_PIPELINES):
+             tree.add(pipeline_id)
+         console.print(tree)
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         pipeline_id: str,
+         models_parameters: Optional[
+             List[Optional[Union[str, dict, DependencyModelParameters]]]
+         ] = None,
+         weights_provider: str = "roboflow",
+         api_key: Optional[str] = None,
+         max_package_loading_attempts: Optional[int] = None,
+         verbose: bool = False,
+         model_download_file_lock_acquire_timeout: int = 10,
+         allow_untrusted_packages: bool = False,
+         trt_engine_host_code_allowed: bool = True,
+         allow_local_code_packages: bool = True,
+         verify_hash_while_download: bool = True,
+         download_files_without_hash: bool = False,
+         use_auto_resolution_cache: bool = True,
+         auto_resolution_cache: Optional[AutoResolutionCache] = None,
+         allow_direct_local_storage_loading: bool = True,
+         model_access_manager: Optional[ModelAccessManager] = None,
+         **kwargs,
+     ) -> AnyModel:
+         pipeline_class = resolve_pipeline_class(pipeline_id=pipeline_id)
+         models = []
+         verbose_info(
+             message=f"Initializing models for pipeline `{pipeline_id}`",
+             verbose_requested=verbose,
+         )
+         default_parameters = get_default_pipeline_parameters(pipeline_id=pipeline_id)
+         if models_parameters is None and default_parameters is None:
+             raise ModelPipelineInitializationError(
+                 message=f"Could not initialize model pipeline `{pipeline_id}` - model parameters were not "
+                 f"provided and default values are not registered in the library. If you run locally, please "
+                 f"verify your integration - it must specify the models to be used by the pipeline. If you use "
+                 f"the Roboflow hosted solution, contact us to get help.",
+                 help_url="https://todo",
+             )
+         if models_parameters is None:
+             models_parameters = default_parameters
+         if default_parameters is None:
+             default_parameters = [None] * len(models_parameters)
+         for idx, model_parameters in enumerate(models_parameters):
+             if model_parameters is None:
+                 parameters_to_be_used = (
+                     default_parameters[idx] if idx < len(default_parameters) else None
+                 )
+             else:
+                 parameters_to_be_used = model_parameters
+             resolved_model_parameters = prepare_dependency_model_parameters(
+                 model_parameters=parameters_to_be_used
+             )
+             verbose_info(
+                 message=f"Initializing model: `{resolved_model_parameters.model_id_or_path}`",
+                 verbose_requested=verbose,
+             )
+             model = AutoModel.from_pretrained(
+                 model_id_or_path=resolved_model_parameters.model_id_or_path,
+                 weights_provider=weights_provider,
+                 api_key=api_key,
+                 model_package_id=resolved_model_parameters.model_package_id,
+                 backend=resolved_model_parameters.backend,
+                 batch_size=resolved_model_parameters.batch_size,
+                 quantization=resolved_model_parameters.quantization,
+                 onnx_execution_providers=resolved_model_parameters.onnx_execution_providers,
+                 device=resolved_model_parameters.device,
+                 default_onnx_trt_options=resolved_model_parameters.default_onnx_trt_options,
+                 max_package_loading_attempts=max_package_loading_attempts,
+                 verbose=verbose,
+                 model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
+                 allow_untrusted_packages=allow_untrusted_packages,
+                 trt_engine_host_code_allowed=trt_engine_host_code_allowed,
+                 allow_local_code_packages=allow_local_code_packages,
+                 verify_hash_while_download=verify_hash_while_download,
+                 download_files_without_hash=download_files_without_hash,
+                 use_auto_resolution_cache=use_auto_resolution_cache,
+                 auto_resolution_cache=auto_resolution_cache,
+                 allow_direct_local_storage_loading=allow_direct_local_storage_loading,
+                 model_access_manager=model_access_manager,
+                 nms_fusion_preferences=resolved_model_parameters.nms_fusion_preferences,
+                 model_type=resolved_model_parameters.model_type,
+                 task_type=resolved_model_parameters.task_type,
+                 **resolved_model_parameters.kwargs,
+             )
+             models.append(model)
+         return pipeline_class.with_models(models, **kwargs)
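Taken together with the registry below, this gives a one-line entry point for composite pipelines; a sketch assuming the default `roboflow` weights provider and a configured `ROBOFLOW_API_KEY`:

    from inference_models import AutoModelPipeline

    AutoModelPipeline.list_available_pipelines()  # prints a rich tree of registered IDs
    pipeline = AutoModelPipeline.from_pretrained("face-and-gaze-detection")  # loads the registered default models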
@@ -0,0 +1,36 @@ inference_models/model_pipelines/auto_loaders/pipelines_registry.py
+ from typing import Dict, List, Optional, Union
+
+ from inference_models.errors import ModelPipelineNotFound
+ from inference_models.utils.imports import LazyClass
+
+ REGISTERED_PIPELINES: Dict[str, LazyClass] = {
+     "face-and-gaze-detection": LazyClass(
+         module_name="inference_models.model_pipelines.face_and_gaze_detection.mediapipe_l2cs",
+         class_name="FaceAndGazeDetectionMPAndL2CS",
+     )
+ }
+
+ DEFAULT_PIPELINES_PARAMETERS: Dict[str, List[Union[str, dict]]] = {
+     "face-and-gaze-detection": [
+         "mediapipe/face-detector",
+         "l2cs-net/rn50",
+     ]
+ }
+
+
+ def resolve_pipeline_class(pipeline_id: str) -> type:
+     if pipeline_id not in REGISTERED_PIPELINES:
+         raise ModelPipelineNotFound(
+             message=f"Could not find model pipeline with id: `{pipeline_id}`. "
+             f"Registered pipelines: {list(REGISTERED_PIPELINES.keys())}. This error may be caused by a typo "
+             f"in the identifier, or the pipeline may not be registered / not supported in the environment "
+             f"you are trying to run it in.",
+             help_url="https://todo",
+         )
+     return REGISTERED_PIPELINES[pipeline_id].resolve()
+
+
+ def get_default_pipeline_parameters(
+     pipeline_id: str,
+ ) -> Optional[List[Union[str, dict]]]:
+     return DEFAULT_PIPELINES_PARAMETERS.get(pipeline_id)
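Because registration is plain dictionary state keyed by pipeline ID, third-party pipelines can in principle be added the same way; a hypothetical sketch (module, class, and model IDs below are made up):

    from inference_models.model_pipelines.auto_loaders.pipelines_registry import (
        DEFAULT_PIPELINES_PARAMETERS,
        REGISTERED_PIPELINES,
    )
    from inference_models.utils.imports import LazyClass

    REGISTERED_PIPELINES["my-pipeline"] = LazyClass(
        module_name="my_package.pipelines",  # hypothetical module
        class_name="MyPipeline",             # hypothetical class exposing with_models()
    )
    DEFAULT_PIPELINES_PARAMETERS["my-pipeline"] = ["my-org/model-a", "my-org/model-b"]  # hypothetical model IDs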
@@ -0,0 +1,200 @@ inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py
+ from typing import Any, List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from inference_models import Detections, KeyPoints
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import ModelPipelineInitializationError, ModelRuntimeError
+ from inference_models.models.l2cs.l2cs_onnx import (
+     DEFAULT_GAZE_MAX_BATCH_SIZE,
+     L2CSGazeDetection,
+     L2CSNetOnnx,
+ )
+ from inference_models.models.mediapipe_face_detection.face_detection import (
+     MediaPipeFaceDetector,
+ )
+
+
+ class FaceAndGazeDetectionMPAndL2CS:
+
+     @classmethod
+     def with_models(
+         cls, models: List[Any], **kwargs
+     ) -> "FaceAndGazeDetectionMPAndL2CS":
+         if len(models) != 2:
+             raise ModelPipelineInitializationError(
+                 message="Model pipeline `face-and-gaze-detection` requires two models to run - a face detector "
+                 "and a gaze detector. If you run `inference` locally, verify the parameters of the pipeline "
+                 "loader to make sure that parameters for both models are provided. If you use the Roboflow "
+                 "hosted solution, contact us to get help.",
+                 help_url="https://todo",
+             )
+         face_detector, gaze_detector = models
+         if not isinstance(face_detector, MediaPipeFaceDetector):
+             raise ModelPipelineInitializationError(
+                 message="Model pipeline `face-and-gaze-detection` requires the first model to be "
+                 "`MediaPipeFaceDetector` - if you run `inference` locally, make sure that you initialized "
+                 "the pipeline pointing to a model of matching type.",
+                 help_url="https://todo",
+             )
+         if not isinstance(gaze_detector, L2CSNetOnnx):
+             raise ModelPipelineInitializationError(
+                 message="Model pipeline `face-and-gaze-detection` requires the second model to be `L2CSNet` - "
+                 "if you run `inference` locally, make sure that you initialized the pipeline pointing to a "
+                 "model of matching type.",
+                 help_url="https://todo",
+             )
+         return cls.from_pretrained(
+             face_detector=face_detector, gaze_detector=gaze_detector, **kwargs
+         )
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         face_detector: Union[str, MediaPipeFaceDetector],
+         gaze_detector: Union[str, L2CSNetOnnx],
+         onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+         default_onnx_trt_options: bool = True,
+         device: torch.device = DEFAULT_DEVICE,
+         max_batch_size: int = DEFAULT_GAZE_MAX_BATCH_SIZE,
+         **kwargs,
+     ) -> "FaceAndGazeDetectionMPAndL2CS":
+         if isinstance(face_detector, str):
+             face_detector = MediaPipeFaceDetector.from_pretrained(
+                 model_name_or_path=face_detector
+             )
+         if isinstance(gaze_detector, str):
+             gaze_detector = L2CSNetOnnx.from_pretrained(
+                 model_name_or_path=gaze_detector,
+                 onnx_execution_providers=onnx_execution_providers,
+                 default_onnx_trt_options=default_onnx_trt_options,
+                 device=device,
+                 max_batch_size=max_batch_size,
+             )
+         return cls(
+             face_detector=face_detector,
+             gaze_detector=gaze_detector,
+         )
+
+     def __init__(
+         self,
+         face_detector: MediaPipeFaceDetector,
+         gaze_detector: L2CSNetOnnx,
+     ):
+         self._face_detector = face_detector
+         self._gaze_detector = gaze_detector
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._face_detector.class_names
+
+     @property
+     def key_points_classes(self) -> List[List[str]]:
+         return self._face_detector.key_points_classes
+
+     @property
+     def skeletons(self) -> List[List[Tuple[int, int]]]:
+         return [[(5, 1), (1, 2), (4, 0), (0, 2), (2, 3)]]
+
+     def infer(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         conf_threshold: float = 0.25,
+         **kwargs,
+     ) -> Tuple[List[KeyPoints], List[Detections], List[L2CSGazeDetection]]:
+         key_points, detections = self._face_detector(
+             images,
+             input_color_format=input_color_format,
+             conf_thresh=conf_threshold,
+             **kwargs,
+         )
+         crops, crops_images_bounds = crop_images_to_detections(
+             images=images,
+             detections=detections,
+             device=self._gaze_detector.device,
+             input_color_format=input_color_format,
+         )
+         gaze_detections = self._gaze_detector(crops, input_color_format="rgb", **kwargs)
+         gaze_detections_dispatched = []
+         for start, end in crops_images_bounds:
+             gaze_detections_dispatched.append(
+                 L2CSGazeDetection(
+                     yaw=gaze_detections.yaw[start:end],
+                     pitch=gaze_detections.pitch[start:end],
+                 )
+             )
+         return key_points, detections, gaze_detections_dispatched
+
+     def __call__(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         conf_threshold: float = 0.25,
+         **kwargs,
+     ) -> Tuple[List[KeyPoints], List[Detections], List[L2CSGazeDetection]]:
+         return self.infer(
+             images=images,
+             input_color_format=input_color_format,
+             conf_threshold=conf_threshold,
+             **kwargs,
+         )
+
+
+ def crop_images_to_detections(
+     images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+     detections: List[Detections],
+     device: torch.device,
+     input_color_format: Optional[ColorFormat] = None,
+ ) -> Tuple[List[torch.Tensor], List[Tuple[int, int]]]:
+     if isinstance(images, np.ndarray):
+         input_color_format = input_color_format or "bgr"
+         if input_color_format != "rgb":
+             images = np.ascontiguousarray(images[:, :, ::-1])
+         prepared_images = [torch.from_numpy(images).permute(2, 0, 1).to(device)]
+     elif isinstance(images, torch.Tensor):
+         input_color_format = input_color_format or "rgb"
+         images = images.to(device)
+         if len(images.shape) == 3:
+             images = images.unsqueeze(dim=0)
+         if input_color_format != "rgb":
+             images = images[:, [2, 1, 0], :, :]
+         prepared_images = [i for i in images]
+     elif isinstance(images, list) and len(images) == 0:
+         raise ModelRuntimeError(
+             message="Detected empty input to the model",
+             help_url="https://todo",
+         )
+     elif isinstance(images, list) and isinstance(images[0], np.ndarray):
+         prepared_images = []
+         input_color_format = input_color_format or "bgr"
+         for image in images:
+             if input_color_format != "rgb":
+                 image = np.ascontiguousarray(image[:, :, ::-1])
+             prepared_images.append(torch.from_numpy(image).permute(2, 0, 1).to(device))
+     elif isinstance(images, list) and isinstance(images[0], torch.Tensor):
+         prepared_images = []
+         input_color_format = input_color_format or "rgb"
+         for image in images:
+             if input_color_format != "rgb":
+                 image = image[[2, 1, 0], :, :]
+             prepared_images.append(image.to(device))
+     else:
+         raise ModelRuntimeError(
+             message=f"Detected unknown input batch element: {type(images)}",
+             help_url="https://todo",
+         )
+     crops = []
+     crops_images_bounds = []
+     for image, image_detections in zip(prepared_images, detections):
+         start_bound = len(crops)
+         for xyxy in image_detections.xyxy:
+             x_min, y_min, x_max, y_max = xyxy.tolist()
+             crop = image[:, y_min:y_max, x_min:x_max]
+             if crop.numel() == 0:
+                 continue
+             crops.append(crop)
+         end_bound = len(crops)
+         crops_images_bounds.append((start_bound, end_bound))
+     return crops, crops_images_bounds
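The pipeline detects faces with MediaPipe, crops each face region, batches the crops through L2CS-Net, and re-groups the gaze predictions per input image via `crops_images_bounds`; a usage sketch (weights are fetched via the model IDs registered as the pipeline defaults):

    import numpy as np

    from inference_models.model_pipelines.face_and_gaze_detection.mediapipe_l2cs import (
        FaceAndGazeDetectionMPAndL2CS,
    )

    pipeline = FaceAndGazeDetectionMPAndL2CS.from_pretrained(
        face_detector="mediapipe/face-detector",
        gaze_detector="l2cs-net/rn50",
    )
    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # numpy inputs default to BGR
    key_points, detections, gazes = pipeline(frame)
    for gaze in gazes:
        print(gaze.yaw, gaze.pitch)  # per-image yaw/pitch for each detected face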
File without changes
File without changes