inference-models 0.18.3 (inference_models-0.18.3-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0

inference_models/__init__.py
@@ -0,0 +1,36 @@
+import os
+
+if os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK") is None:
+    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+if os.environ.get("TOKENIZERS_PARALLELISM") is None:
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+from inference_models.entities import ColorFormat
+from inference_models.model_pipelines.auto_loaders.core import AutoModelPipeline
+from inference_models.models.auto_loaders.core import AutoModel
+from inference_models.models.base.classification import (
+    ClassificationModel,
+    ClassificationPrediction,
+    MultiLabelClassificationModel,
+    MultiLabelClassificationPrediction,
+)
+from inference_models.models.base.depth_estimation import DepthEstimationModel
+from inference_models.models.base.documents_parsing import (
+    StructuredOCRModel,
+    TextOnlyOCRModel,
+)
+from inference_models.models.base.embeddings import TextImageEmbeddingModel
+from inference_models.models.base.instance_segmentation import (
+    InstanceDetections,
+    InstanceSegmentationModel,
+)
+from inference_models.models.base.keypoints_detection import (
+    KeyPoints,
+    KeyPointsDetectionModel,
+)
+from inference_models.models.base.object_detection import (
+    Detections,
+    ObjectDetectionModel,
+    OpenVocabularyObjectDetectionModel,
+)
+from inference_models.models.base.semantic_segmentation import SemanticSegmentationModel
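
The top-level `__init__.py` sets two conservative environment defaults (PyTorch MPS fallback, tokenizers parallelism off) and re-exports the auto-loaders plus the base model and prediction types, so downstream code can import everything directly from `inference_models`. A minimal usage sketch under stated assumptions: the model id below is a hypothetical placeholder, reading the image with OpenCV is the caller's choice rather than a package requirement, and calling the loaded model directly generalizes from how the pipeline code later in this diff invokes its models.

# Hedged sketch of the public API re-exported above.
import cv2  # assumption: any source of numpy BGR images works here

from inference_models import AutoModel

model = AutoModel.from_pretrained(model_id_or_path="rfdetr-base")  # placeholder id; auto-negotiates the backend
image = cv2.imread("example.jpg")                                  # numpy array in BGR order
predictions = model(image)                                         # assumed callable, like the pipeline models below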

inference_models/configuration.py
@@ -0,0 +1,72 @@
+import os
+
+import torch
+
+from inference_models.utils.environment import parse_comma_separated_values, str2bool
+
+ONNXRUNTIME_EXECUTION_PROVIDERS = parse_comma_separated_values(
+    values=os.getenv(
+        "ONNXRUNTIME_EXECUTION_PROVIDERS",
+        "CUDAExecutionProvider,OpenVINOExecutionProvider,CoreMLExecutionProvider,CPUExecutionProvider",
+    )
+    .strip("[")
+    .strip("]")
+)
+DEFAULT_DEVICE_STR = os.getenv(
+    "DEFAULT_DEVICE",
+    ("cuda" if torch.cuda.is_available() else "cpu"),
+)
+DEFAULT_DEVICE = torch.device(DEFAULT_DEVICE_STR)
+ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY")
+API_CALLS_TIMEOUT = int(os.getenv("API_CALLS_TIMEOUT", "5"))
+API_CALLS_MAX_TRIES = int(os.getenv("API_CALLS_MAX_TRIES", "3"))
+IDEMPOTENT_API_REQUEST_CODES_TO_RETRY = set(
+    int(e.strip())
+    for e in os.getenv(
+        "IDEMPOTENT_API_REQUEST_CODES_TO_RETRY", "408,429,502,503,504"
+    ).split(",")
+)
+ROBOFLOW_ENVIRONMENT = os.getenv("ROBOFLOW_ENVIRONMENT", "prod")
+ROBOFLOW_API_HOST = os.getenv(
+    "ROBOFLOW_API_HOST",
+    (
+        "https://api.roboflow.com"
+        if ROBOFLOW_ENVIRONMENT.lower() == "prod"
+        else "https://api.roboflow.one"
+    ),
+)
+RUNNING_ON_JETSON = os.getenv("RUNNING_ON_JETSON")
+L4T_VERSION = os.getenv("L4T_VERSION")
+INFERENCE_HOME = os.getenv("INFERENCE_HOME", "/tmp/cache")
+DISABLE_INTERACTIVE_PROGRESS_BARS = str2bool(
+    os.getenv("DISABLE_INTERACTIVE_PROGRESS_BARS", "False")
+)
+LOG_LEVEL = os.getenv("LOG_LEVEL", "WARNING")
+VERBOSE_LOG_LEVEL = os.getenv("VERBOSE_LOG_LEVEL", "INFO")
+DISABLE_VERBOSE_LOGGER = str2bool(os.getenv("DISABLE_VERBOSE_LOGGER", "False"))
+AUTO_LOADER_CACHE_EXPIRATION_MINUTES = int(
+    os.getenv("AUTO_LOADER_CACHE_EXPIRATION_MINUTES", "1440")
+)
+ALLOW_URL_INPUT = str2bool(os.getenv("ALLOW_URL_INPUT", True))
+ALLOW_NON_HTTPS_URL_INPUT = str2bool(os.getenv("ALLOW_NON_HTTPS_URL_INPUT", False))
+ALLOW_URL_INPUT_WITHOUT_FQDN = str2bool(
+    os.getenv("ALLOW_URL_INPUT_WITHOUT_FQDN", False)
+)
+WHITELISTED_DESTINATIONS_FOR_URL_INPUT = os.getenv(
+    "WHITELISTED_DESTINATIONS_FOR_URL_INPUT"
+)
+if WHITELISTED_DESTINATIONS_FOR_URL_INPUT is not None:
+    WHITELISTED_DESTINATIONS_FOR_URL_INPUT = parse_comma_separated_values(
+        WHITELISTED_DESTINATIONS_FOR_URL_INPUT
+    )
+
+BLACKLISTED_DESTINATIONS_FOR_URL_INPUT = os.getenv(
+    "BLACKLISTED_DESTINATIONS_FOR_URL_INPUT"
+)
+if BLACKLISTED_DESTINATIONS_FOR_URL_INPUT is not None:
+    BLACKLISTED_DESTINATIONS_FOR_URL_INPUT = parse_comma_separated_values(
+        BLACKLISTED_DESTINATIONS_FOR_URL_INPUT
+    )
+ALLOW_LOCAL_STORAGE_ACCESS_FOR_REFERENCE_DATA = os.getenv(
+    "ALLOW_LOCAL_STORAGE_ACCESS_FOR_REFERENCE_DATA"
+)
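
Every constant in `configuration.py` is resolved from the process environment once, at import time, so overrides must be in place before `inference_models` (or its `configuration` module) is first imported. A minimal sketch, assuming environment variables are the only configuration mechanism and using a hypothetical cache path:

# Hedged sketch: set overrides before the first import, because
# configuration.py calls os.getenv() at module import time.
import os

os.environ["ONNXRUNTIME_EXECUTION_PROVIDERS"] = "CPUExecutionProvider"  # drop the GPU providers
os.environ["DEFAULT_DEVICE"] = "cpu"                                     # skip the CUDA auto-detection
os.environ["INFERENCE_HOME"] = "/data/inference-cache"                   # hypothetical cache location
os.environ["LOG_LEVEL"] = "INFO"

from inference_models import configuration

print(configuration.DEFAULT_DEVICE)                   # torch.device("cpu")
print(configuration.ONNXRUNTIME_EXECUTION_PROVIDERS)  # parsed from the override above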

inference_models/errors.py
@@ -0,0 +1,137 @@
+from typing import Optional
+
+
+class BaseInferenceError(Exception):
+
+    def __init__(self, message: str, help_url: Optional[str] = None):
+        super().__init__(message)
+        self._help_url = help_url
+
+    @property
+    def help_url(self) -> Optional[str]:
+        return self._help_url
+
+    def __str__(self) -> str:
+        if self._help_url is None:
+            return super().__str__()
+        return f"{super().__str__()} - VISIT {self._help_url} FOR FURTHER SUPPORT"
+
+
+class AssumptionError(BaseInferenceError):
+    pass
+
+
+class EnvironmentConfigurationError(BaseInferenceError):
+    pass
+
+
+class ModelRuntimeError(BaseInferenceError):
+    pass
+
+
+class ModelInputError(BaseInferenceError):
+    pass
+
+
+class RetryError(BaseInferenceError):
+    pass
+
+
+class ModelRetrievalError(BaseInferenceError):
+    pass
+
+
+class UntrustedFileError(BaseInferenceError):
+    pass
+
+
+class FileHashSumMissmatch(BaseInferenceError):
+    pass
+
+
+class UnauthorizedModelAccessError(ModelRetrievalError):
+    pass
+
+
+class ModelMetadataConsistencyError(ModelRetrievalError):
+    pass
+
+
+class ModelMetadataHandlerNotImplementedError(ModelRetrievalError):
+    pass
+
+
+class InvalidEnvVariable(BaseInferenceError):
+    pass
+
+
+class ModelPackageNegotiationError(BaseInferenceError):
+    pass
+
+
+class UnknownBackendTypeError(ModelPackageNegotiationError):
+    pass
+
+
+class UnknownQuantizationError(ModelPackageNegotiationError):
+    pass
+
+
+class InvalidRequestedBatchSizeError(ModelPackageNegotiationError):
+    pass
+
+
+class RuntimeIntrospectionError(ModelPackageNegotiationError):
+    pass
+
+
+class JetsonTypeResolutionError(RuntimeIntrospectionError):
+    pass
+
+
+class NoModelPackagesAvailableError(ModelPackageNegotiationError):
+    pass
+
+
+class AmbiguousModelPackageResolutionError(ModelPackageNegotiationError):
+    pass
+
+
+class ModelLoadingError(BaseInferenceError):
+    pass
+
+
+class InsecureModelIdentifierError(ModelLoadingError):
+    pass
+
+
+class DirectLocalStorageAccessError(ModelLoadingError):
+    pass
+
+
+class ModelImplementationLoaderError(ModelLoadingError):
+    pass
+
+
+class CorruptedModelPackageError(ModelLoadingError):
+    pass
+
+
+class MissingDependencyError(BaseInferenceError):
+    pass
+
+
+class InvalidParameterError(BaseInferenceError):
+    pass
+
+
+class DependencyModelParametersValidationError(ModelLoadingError):
+    pass
+
+
+class ModelPipelineInitializationError(ModelLoadingError):
+    pass
+
+
+class ModelPipelineNotFound(ModelPipelineInitializationError):
+    pass
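
All exceptions above derive from `BaseInferenceError`, which stores an optional `help_url` and appends it to the exception's string form, so callers can catch a single base type and still surface a support link. A short sketch of that contract from the caller's side (the model id is a placeholder):

# Hedged sketch: one except clause covers the whole hierarchy defined above.
from inference_models import AutoModel
from inference_models.errors import BaseInferenceError

try:
    model = AutoModel.from_pretrained(model_id_or_path="workspace/some-model")  # placeholder id
except BaseInferenceError as error:
    # str(error) already ends with "VISIT <help_url> FOR FURTHER SUPPORT" when a URL was attached
    print(f"inference-models error: {error}")
    if error.help_url is not None:
        print(f"support page: {error.help_url}")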

inference_models/logger.py
@@ -0,0 +1,52 @@
+import logging
+
+from inference_models.configuration import (
+    DISABLE_VERBOSE_LOGGER,
+    LOG_LEVEL,
+    VERBOSE_LOG_LEVEL,
+)
+
+
+def configure_log_level(
+    logger: logging.Logger, log_level: str, fallback_level: int
+) -> None:
+    log_level = getattr(logging, log_level, fallback_level)
+    logger.setLevel(log_level)
+    if not logger.handlers:
+        handler = logging.StreamHandler()
+        formatter = logging.Formatter("%(message)s")
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+    for handler in logger.handlers:
+        handler.setLevel(log_level)
+    logger.propagate = False
+
+
+LOGGER = logging.getLogger("inference-models")
+configure_log_level(logger=LOGGER, log_level=LOG_LEVEL, fallback_level=logging.WARNING)
+VERBOSE_LOGGER = logging.getLogger("inference-models-verbose")
+configure_log_level(
+    logger=VERBOSE_LOGGER, log_level=VERBOSE_LOG_LEVEL, fallback_level=logging.INFO
+)
+
+
+def verbose_info(
+    message: str,
+    verbose_requested: bool = True,
+) -> None:
+    if DISABLE_VERBOSE_LOGGER:
+        return None
+    if not verbose_requested:
+        return None
+    VERBOSE_LOGGER.info(message)
+
+
+def verbose_debug(
+    message: str,
+    verbose_requested: bool = True,
+) -> None:
+    if DISABLE_VERBOSE_LOGGER:
+        return None
+    if not verbose_requested:
+        return None
+    VERBOSE_LOGGER.debug(message)
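
Both loggers are configured once at import time from `LOG_LEVEL` (default WARNING) and `VERBOSE_LOG_LEVEL` (default INFO), and the `verbose_info` / `verbose_debug` helpers are gated globally by `DISABLE_VERBOSE_LOGGER` and per call by `verbose_requested`. A small sketch of the intended call pattern, mirroring how the pipeline loader below forwards its `verbose` flag:

# Hedged sketch of the two logging paths exposed by this module.
from inference_models.logger import LOGGER, verbose_info

verbose_info(message="Resolving model package...", verbose_requested=True)   # emitted via the verbose logger
verbose_info(message="Resolving model package...", verbose_requested=False)  # skipped without touching log levels
LOGGER.warning("Falling back to CPU execution")                              # regular library logger, WARNING by default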

inference_models/model_pipelines/__init__.py
File without changes

inference_models/model_pipelines/auto_loaders/__init__.py
File without changes

inference_models/model_pipelines/auto_loaders/core.py
@@ -0,0 +1,120 @@
+from typing import List, Optional, Union
+
+from rich.console import Console
+from rich.tree import Tree
+
+from inference_models.errors import ModelPipelineInitializationError
+from inference_models.logger import verbose_info
+from inference_models.model_pipelines.auto_loaders.pipelines_registry import (
+    REGISTERED_PIPELINES,
+    get_default_pipeline_parameters,
+    resolve_pipeline_class,
+)
+from inference_models.models.auto_loaders.access_manager import ModelAccessManager
+from inference_models.models.auto_loaders.auto_resolution_cache import (
+    AutoResolutionCache,
+)
+from inference_models.models.auto_loaders.core import AutoModel
+from inference_models.models.auto_loaders.dependency_models import (
+    DependencyModelParameters,
+    prepare_dependency_model_parameters,
+)
+from inference_models.models.auto_loaders.entities import AnyModel
+
+
+class AutoModelPipeline:
+
+    @classmethod
+    def list_available_pipelines(cls) -> None:
+        console = Console()
+        tree = Tree("Available Model Pipelines:")
+        for pipeline_id in sorted(REGISTERED_PIPELINES):
+            tree.add(pipeline_id)
+        console.print(tree)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pipline_id: str,
+        models_parameters: Optional[
+            List[Optional[Union[str, dict, DependencyModelParameters]]]
+        ] = None,
+        weights_provider: str = "roboflow",
+        api_key: Optional[str] = None,
+        max_package_loading_attempts: Optional[int] = None,
+        verbose: bool = False,
+        model_download_file_lock_acquire_timeout: int = 10,
+        allow_untrusted_packages: bool = False,
+        trt_engine_host_code_allowed: bool = True,
+        allow_local_code_packages: bool = True,
+        verify_hash_while_download: bool = True,
+        download_files_without_hash: bool = False,
+        use_auto_resolution_cache: bool = True,
+        auto_resolution_cache: Optional[AutoResolutionCache] = None,
+        allow_direct_local_storage_loading: bool = True,
+        model_access_manager: Optional[ModelAccessManager] = None,
+        **kwargs,
+    ) -> AnyModel:
+        pipeline_class = resolve_pipeline_class(pipline_id=pipline_id)
+        models = []
+        verbose_info(
+            message=f"Initializing models for pipeline `{pipline_id}`",
+            verbose_requested=verbose,
+        )
+        default_parameters = get_default_pipeline_parameters(pipline_id=pipline_id)
+        if models_parameters is None and default_parameters is None:
+            raise ModelPipelineInitializationError(
+                message=f"Could not initialize model pipeline `{pipline_id}` - models parameters not provided and "
+                f"default values not registered in the library. If you run locally, please verify your "
+                f"integration - it must specify the models to be used by the pipeline. If you use Roboflow "
+                f"hosted solution, contact us to get help.",
+                help_url="https://todo",
+            )
+        if models_parameters is None:
+            models_parameters = default_parameters
+        if default_parameters is None:
+            default_parameters = [None] * len(models_parameters)
+        for idx, model_parameters in enumerate(models_parameters):
+            if model_parameters is None:
+                parameters_to_be_used = (
+                    default_parameters[idx] if idx < len(default_parameters) else None
+                )
+            else:
+                parameters_to_be_used = model_parameters
+            resolved_model_parameters = prepare_dependency_model_parameters(
+                model_parameters=parameters_to_be_used
+            )
+            verbose_info(
+                message=f"Initializing model: `{resolved_model_parameters.model_id_or_path}`",
+                verbose_requested=verbose,
+            )
+            model = AutoModel.from_pretrained(
+                model_id_or_path=resolved_model_parameters.model_id_or_path,
+                weights_provider=weights_provider,
+                api_key=api_key,
+                model_package_id=resolved_model_parameters.model_package_id,
+                backend=resolved_model_parameters.backend,
+                batch_size=resolved_model_parameters.batch_size,
+                quantization=resolved_model_parameters.quantization,
+                onnx_execution_providers=resolved_model_parameters.onnx_execution_providers,
+                device=resolved_model_parameters.device,
+                default_onnx_trt_options=resolved_model_parameters.default_onnx_trt_options,
+                max_package_loading_attempts=max_package_loading_attempts,
+                verbose=verbose,
+                model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
+                allow_untrusted_packages=allow_untrusted_packages,
+                trt_engine_host_code_allowed=trt_engine_host_code_allowed,
+                allow_local_code_packages=allow_local_code_packages,
+                verify_hash_while_download=verify_hash_while_download,
+                download_files_without_hash=download_files_without_hash,
+                use_auto_resolution_cache=use_auto_resolution_cache,
+                auto_resolution_cache=auto_resolution_cache,
+                allow_direct_local_storage_loading=allow_direct_local_storage_loading,
+                model_access_manager=model_access_manager,
+                nms_fusion_preferences=resolved_model_parameters.nms_fusion_preferences,
+                model_type=resolved_model_parameters.model_type,
+                task_type=resolved_model_parameters.task_type,
+                **resolved_model_parameters.kwargs,
+            )
+            models.append(model)
+        return pipeline_class.with_models(models, **kwargs)
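
`AutoModelPipeline.from_pretrained` resolves the pipeline class from the registry, loads every dependency model through `AutoModel.from_pretrained` (falling back to the registered default model list when no explicit parameters are given), and hands the loaded models to the pipeline's `with_models` constructor. Note that the first parameter is spelled `pipline_id` in this release. A minimal sketch of the entry point, assuming the Roboflow weights provider accepts the placeholder API key:

# Hedged sketch: load the only pipeline registered in this release with its defaults.
from inference_models import AutoModelPipeline

AutoModelPipeline.list_available_pipelines()  # prints a rich Tree of registered pipeline ids

pipeline = AutoModelPipeline.from_pretrained(
    "face-and-gaze-detection",        # keyword form would be pipline_id=... in this release
    api_key="YOUR_ROBOFLOW_API_KEY",  # placeholder credential
    verbose=True,
)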

inference_models/model_pipelines/auto_loaders/pipelines_registry.py
@@ -0,0 +1,36 @@
+from typing import Dict, List, Optional, Union
+
+from inference_models.errors import ModelPipelineNotFound
+from inference_models.utils.imports import LazyClass
+
+REGISTERED_PIPELINES: Dict[str, LazyClass] = {
+    "face-and-gaze-detection": LazyClass(
+        module_name="inference_models.model_pipelines.face_and_gaze_detection.mediapipe_l2cs",
+        class_name="FaceAndGazeDetectionMPAndL2CS",
+    )
+}
+
+DEFAULT_PIPELINES_PARAMETERS: Dict[str, List[Union[str, dict]]] = {
+    "face-and-gaze-detection": [
+        "mediapipe/face-detector",
+        "l2cs-net/rn50",
+    ]
+}
+
+
+def resolve_pipeline_class(pipline_id: str) -> type:
+    if pipline_id not in REGISTERED_PIPELINES:
+        raise ModelPipelineNotFound(
+            message=f"Could not find model pipeline with id: `{pipline_id}`. "
+            f"Registered pipelines: {list(REGISTERED_PIPELINES.keys())}. This error may be caused by a typo "
+            f"in the identifier, or the pipeline is not registered / not supported in the environment you try "
+            f"to run it in.",
+            help_url="https://todo",
+        )
+    return REGISTERED_PIPELINES[pipline_id].resolve()
+
+
+def get_default_pipeline_parameters(
+    pipline_id: str,
+) -> Optional[List[Union[str, dict]]]:
+    return DEFAULT_PIPELINES_PARAMETERS.get(pipline_id)
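
Pipeline lookup is lazy: `REGISTERED_PIPELINES` maps a pipeline id to a `LazyClass`, so the implementing module is only imported when `resolve_pipeline_class` calls `.resolve()`. A hedged sketch of how an additional pipeline could be registered in-process; nothing in this release documents this as a supported extension point, and the module, class, and model ids below are hypothetical:

# Hedged sketch: mirrors the dictionaries defined above, with made-up entries.
from inference_models.model_pipelines.auto_loaders.pipelines_registry import (
    DEFAULT_PIPELINES_PARAMETERS,
    REGISTERED_PIPELINES,
)
from inference_models.utils.imports import LazyClass

REGISTERED_PIPELINES["my-custom-pipeline"] = LazyClass(
    module_name="my_package.my_pipeline",  # hypothetical module
    class_name="MyCustomPipeline",         # hypothetical class exposing with_models(...)
)
DEFAULT_PIPELINES_PARAMETERS["my-custom-pipeline"] = [
    "my-workspace/detector",               # hypothetical default model ids
    "my-workspace/classifier",
]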

inference_models/model_pipelines/face_and_gaze_detection/__init__.py
File without changes

inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py
@@ -0,0 +1,200 @@
+from typing import Any, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from inference_models import Detections, KeyPoints
+from inference_models.configuration import DEFAULT_DEVICE
+from inference_models.entities import ColorFormat
+from inference_models.errors import ModelPipelineInitializationError, ModelRuntimeError
+from inference_models.models.l2cs.l2cs_onnx import (
+    DEFAULT_GAZE_MAX_BATCH_SIZE,
+    L2CSGazeDetection,
+    L2CSNetOnnx,
+)
+from inference_models.models.mediapipe_face_detection.face_detection import (
+    MediaPipeFaceDetector,
+)
+
+
+class FaceAndGazeDetectionMPAndL2CS:
+
+    @classmethod
+    def with_models(
+        cls, models: List[Any], **kwargs
+    ) -> "FaceAndGazeDetectionMPAndL2CS":
+        if len(models) != 2:
+            raise ModelPipelineInitializationError(
+                message="Model pipeline `face-and-gaze-detection` requires two models to run - face detector "
+                "and gaze detector. If you run `inference` locally, verify the parameters of the pipeline loader "
+                "to make sure that two models' parameters are provided. If you use Roboflow hosted solution, "
+                "contact us to get help.",
+                help_url="https://todo",
+            )
+        face_detector, gaze_detector = models
+        if not isinstance(face_detector, MediaPipeFaceDetector):
+            raise ModelPipelineInitializationError(
+                message="Model pipeline `face-and-gaze-detection` requires first model to be `MediaPipeFaceDetector` - "
+                "if you run `inference` locally, make sure that you initialized the pipeline pointing to a model of "
+                "matching type.",
+                help_url="https://todo",
+            )
+        if not isinstance(gaze_detector, L2CSNetOnnx):
+            raise ModelPipelineInitializationError(
+                message="Model pipeline `face-and-gaze-detection` requires second model to be `L2CSNet` - "
+                "if you run `inference` locally, make sure that you initialized the pipeline pointing to a model of "
+                "matching type.",
+                help_url="https://todo",
+            )
+        return FaceAndGazeDetectionMPAndL2CS.from_pretrained(
+            face_detector=face_detector, gaze_detector=gaze_detector, **kwargs
+        )
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        face_detector: Union[str, MediaPipeFaceDetector],
+        gaze_detector: Union[str, L2CSNetOnnx],
+        onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+        default_onnx_trt_options: bool = True,
+        device: torch.device = DEFAULT_DEVICE,
+        max_batch_size: int = DEFAULT_GAZE_MAX_BATCH_SIZE,
+        **kwargs,
+    ) -> "FaceAndGazeDetectionMPAndL2CS":
+        if isinstance(face_detector, str):
+            face_detector = MediaPipeFaceDetector.from_pretrained(
+                model_name_or_path=face_detector
+            )
+        if isinstance(gaze_detector, str):
+            gaze_detector = L2CSNetOnnx.from_pretrained(
+                model_name_or_path=gaze_detector,
+                onnx_execution_providers=onnx_execution_providers,
+                default_onnx_trt_options=default_onnx_trt_options,
+                device=device,
+                max_batch_size=max_batch_size,
+            )
+        return cls(
+            face_detector=face_detector,
+            gaze_detector=gaze_detector,
+        )
+
+    def __init__(
+        self,
+        face_detector: MediaPipeFaceDetector,
+        gaze_detector: L2CSNetOnnx,
+    ):
+        self._face_detector = face_detector
+        self._gaze_detector = gaze_detector
+
+    @property
+    def class_names(self) -> List[str]:
+        return self._face_detector.class_names
+
+    @property
+    def key_points_classes(self) -> List[List[str]]:
+        return self._face_detector.key_points_classes
+
+    @property
+    def skeletons(self) -> List[List[Tuple[int, int]]]:
+        return [[(5, 1), (1, 2), (4, 0), (0, 2), (2, 3)]]
+
+    def infer(
+        self,
+        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+        input_color_format: Optional[ColorFormat] = None,
+        conf_threshold: float = 0.25,
+        **kwargs,
+    ) -> Tuple[List[KeyPoints], List[Detections], List[L2CSGazeDetection]]:
+        key_points, detections = self._face_detector(
+            images,
+            input_color_format=input_color_format,
+            conf_thresh=conf_threshold,
+            **kwargs,
+        )
+        crops, crops_images_bounds = crop_images_to_detections(
+            images=images,
+            detections=detections,
+            device=self._gaze_detector.device,
+        )
+        gaze_detections = self._gaze_detector(crops, input_color_format="rgb", **kwargs)
+        gaze_detections_dispatched = []
+        for start, end in crops_images_bounds:
+            gaze_detections_dispatched.append(
+                L2CSGazeDetection(
+                    yaw=gaze_detections.yaw[start:end],
+                    pitch=gaze_detections.pitch[start:end],
+                )
+            )
+        return key_points, detections, gaze_detections_dispatched
+
+    def __call__(
+        self,
+        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+        input_color_format: Optional[ColorFormat] = None,
+        conf_threshold: float = 0.25,
+        **kwargs,
+    ) -> Tuple[List[KeyPoints], List[Detections], List[L2CSGazeDetection]]:
+        return self.infer(
+            images=images,
+            input_color_format=input_color_format,
+            conf_threshold=conf_threshold,
+            **kwargs,
+        )
+
+
+def crop_images_to_detections(
+    images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+    detections: List[Detections],
+    device: torch.device,
+    input_color_format: Optional[ColorFormat] = None,
+) -> Tuple[List[torch.Tensor], List[Tuple[int, int]]]:
+    if isinstance(images, np.ndarray):
+        input_color_format = input_color_format or "bgr"
+        if input_color_format != "rgb":
+            images = np.ascontiguousarray(images[:, :, ::-1])
+        prepared_images = [torch.from_numpy(images).permute(2, 0, 1).to(device)]
+    elif isinstance(images, torch.Tensor):
+        input_color_format = input_color_format or "rgb"
+        images = images.to(device)
+        if len(images.shape) == 3:
+            images = images.unsqueeze(dim=0)
+        if input_color_format != "rgb":
+            images = images[:, [2, 1, 0], :, :]
+        prepared_images = [i for i in images]
+    elif isinstance(images, list) and len(images) == 0:
+        raise ModelRuntimeError(
+            message="Detected empty input to the model",
+            help_url="https://todo",
+        )
+    elif isinstance(images, list) and isinstance(images[0], np.ndarray):
+        prepared_images = []
+        input_color_format = input_color_format or "bgr"
+        for image in images:
+            if input_color_format != "rgb":
+                image = np.ascontiguousarray(image[:, :, ::-1])
+            prepared_images.append(torch.from_numpy(image).permute(2, 0, 1).to(device))
+    elif isinstance(images, list) and isinstance(images[0], torch.Tensor):
+        prepared_images = []
+        input_color_format = input_color_format or "rgb"
+        for image in images:
+            if input_color_format != "rgb":
+                image = image[[2, 1, 0], :, :]
+            prepared_images.append(image.to(device))
+    else:
+        raise ModelRuntimeError(
+            message=f"Detected unknown input batch element: {type(images)}",
+            help_url="https://todo",
+        )
+    crops = []
+    crops_images_bounds = []
+    for image, image_detections in zip(prepared_images, detections):
+        start_bound = len(crops)
+        for xyxy in image_detections.xyxy:
+            x_min, y_min, x_max, y_max = xyxy.tolist()
+            crop = image[:, y_min:y_max, x_min:x_max]
+            if crop.numel() == 0:
+                continue
+            crops.append(crop)
+        end_bound = len(crops)
+        crops_images_bounds.append((start_bound, end_bound))
+    return crops, crops_images_bounds
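
The pipeline runs MediaPipe face detection first, crops each detected face out of the prepared images, batches the crops through the L2CS-Net gaze model, and maps the gaze angles back to their source images via `crops_images_bounds`. A short sketch of direct use, with the two model ids taken from the defaults registered in `pipelines_registry.py`; reading the frame with OpenCV is an assumption about the caller, not a package requirement:

# Hedged sketch: instantiating the pipeline class directly with the default model ids.
import cv2

from inference_models.model_pipelines.face_and_gaze_detection.mediapipe_l2cs import (
    FaceAndGazeDetectionMPAndL2CS,
)

pipeline = FaceAndGazeDetectionMPAndL2CS.from_pretrained(
    face_detector="mediapipe/face-detector",
    gaze_detector="l2cs-net/rn50",
)
frame = cv2.imread("people.jpg")  # numpy BGR image - the default input_color_format for arrays
key_points, detections, gazes = pipeline(frame, conf_threshold=0.25)
for image_gaze in gazes:          # one L2CSGazeDetection per input image
    print(image_gaze.yaw, image_gaze.pitch)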

inference_models/models/__init__.py
File without changes

inference_models/models/auto_loaders/__init__.py
File without changes