inference-models 0.18.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1341 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import importlib
|
|
3
|
+
import importlib.util
|
|
4
|
+
import os.path
|
|
5
|
+
import re
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from functools import partial
|
|
8
|
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from argon2 import PasswordHasher
|
|
12
|
+
from filelock import FileLock
|
|
13
|
+
from rich.console import Console
|
|
14
|
+
from rich.text import Text
|
|
15
|
+
|
|
16
|
+
from inference_models.configuration import DEFAULT_DEVICE, INFERENCE_HOME
|
|
17
|
+
from inference_models.errors import (
|
|
18
|
+
CorruptedModelPackageError,
|
|
19
|
+
DirectLocalStorageAccessError,
|
|
20
|
+
InsecureModelIdentifierError,
|
|
21
|
+
ModelLoadingError,
|
|
22
|
+
NoModelPackagesAvailableError,
|
|
23
|
+
UnauthorizedModelAccessError,
|
|
24
|
+
)
|
|
25
|
+
from inference_models.logger import LOGGER, verbose_info
|
|
26
|
+
from inference_models.models.auto_loaders.access_manager import (
|
|
27
|
+
AccessIdentifiers,
|
|
28
|
+
LiberalModelAccessManager,
|
|
29
|
+
ModelAccessManager,
|
|
30
|
+
)
|
|
31
|
+
from inference_models.models.auto_loaders.auto_negotiation import (
|
|
32
|
+
negotiate_model_packages,
|
|
33
|
+
parse_backend_type,
|
|
34
|
+
)
|
|
35
|
+
from inference_models.models.auto_loaders.auto_resolution_cache import (
|
|
36
|
+
AutoResolutionCache,
|
|
37
|
+
AutoResolutionCacheEntry,
|
|
38
|
+
BaseAutoLoadMetadataCache,
|
|
39
|
+
)
|
|
40
|
+
from inference_models.models.auto_loaders.constants import (
|
|
41
|
+
MODEL_DEPENDENCIES_KEY,
|
|
42
|
+
MODEL_DEPENDENCIES_SUB_DIR,
|
|
43
|
+
)
|
|
44
|
+
from inference_models.models.auto_loaders.dependency_models import (
|
|
45
|
+
prepare_dependency_model_parameters,
|
|
46
|
+
)
|
|
47
|
+
from inference_models.models.auto_loaders.entities import (
|
|
48
|
+
MODEL_CONFIG_FILE_NAME,
|
|
49
|
+
AnyModel,
|
|
50
|
+
BackendType,
|
|
51
|
+
InferenceModelConfig,
|
|
52
|
+
ModelArchitecture,
|
|
53
|
+
TaskType,
|
|
54
|
+
)
|
|
55
|
+
from inference_models.models.auto_loaders.models_registry import (
|
|
56
|
+
INSTANCE_SEGMENTATION_TASK,
|
|
57
|
+
OBJECT_DETECTION_TASK,
|
|
58
|
+
resolve_model_class,
|
|
59
|
+
)
|
|
60
|
+
from inference_models.models.auto_loaders.presentation_utils import (
|
|
61
|
+
calculate_artefacts_size,
|
|
62
|
+
calculate_size_of_all_model_packages_artefacts,
|
|
63
|
+
render_model_package_details_table,
|
|
64
|
+
render_runtime_x_ray,
|
|
65
|
+
render_table_with_model_overview,
|
|
66
|
+
render_table_with_model_packages,
|
|
67
|
+
)
|
|
68
|
+
from inference_models.runtime_introspection.core import x_ray_runtime_environment
|
|
69
|
+
from inference_models.utils.download import FileHandle, download_files_to_directory
|
|
70
|
+
from inference_models.utils.file_system import dump_json, read_json
|
|
71
|
+
from inference_models.utils.hashing import hash_dict_content
|
|
72
|
+
from inference_models.weights_providers.core import get_model_from_provider
|
|
73
|
+
from inference_models.weights_providers.entities import (
|
|
74
|
+
ModelDependency,
|
|
75
|
+
ModelPackageMetadata,
|
|
76
|
+
Quantization,
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
# Model type identifiers (all RF-DETR variants) that are treated specially
# as "load from checkpoint" models.
# NOTE(review): the consumer of this set is not visible in this section of the
# file - confirm the exact semantics against its usage site.
MODEL_TYPES_TO_LOAD_FROM_CHECKPOINT = {
    "rfdetr-base", "rfdetr-small", "rfdetr-medium",
    "rfdetr-nano", "rfdetr-large", "rfdetr-seg-preview",
}

# Names of `from_pretrained(**kwargs)` entries that are propagated to the
# dependency models loaded on behalf of the requested model. Used when the
# caller does not pass `forwarded_kwargs` explicitly; a forwarded value never
# overrides one already set in the dependency's own parameters.
DEFAULT_KWARGS_PARAMS_TO_BE_FORWARDED_TO_DEPENDENT_MODELS = [
    "owlv2_enforce_model_compilation",
    "owlv2_class_embeddings_cache",
    "owlv2_images_embeddings_cache",
]
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class AutoModel:
|
|
96
|
+
|
|
97
|
+
@classmethod
def describe_model(
    cls,
    model_id: str,
    weights_provider: str = "roboflow",
    api_key: Optional[str] = None,
    pull_artefacts_size: bool = False,
) -> None:
    """Print an overview of a model and every package registered for it.

    Metadata is fetched from the weights provider and rendered to the terminal
    with ``rich`` as two tables: a model overview followed by a package list.

    Args:
        model_id: Model identifier (or alias) to look up.
        weights_provider: Name of the weights provider to query.
        api_key: Optional API key passed through to the weights provider.
        pull_artefacts_size: If True, collect the artefact size of every
            package (slower) and show it in the packages table; if False,
            print a hint describing how to request sizes instead.
    """
    metadata = get_model_from_provider(
        provider=weights_provider,
        model_id=model_id,
        api_key=api_key,
    )
    sizes = (
        calculate_size_of_all_model_packages_artefacts(
            model_packages=metadata.model_packages
        )
        if pull_artefacts_size
        else None
    )
    console = Console()
    console.print(
        render_table_with_model_overview(
            model_id=metadata.model_id,
            requested_model_id=model_id,
            model_architecture=metadata.model_architecture,
            model_variant=metadata.model_variant,
            task_type=metadata.task_type,
            weights_provider=weights_provider,
            registered_packages=len(metadata.model_packages),
            model_dependencies=metadata.model_dependencies,
        )
    )
    console.print("\n")
    console.print(
        render_table_with_model_packages(
            model_packages=metadata.model_packages,
            model_packages_size=sizes,
        )
    )
    console.print(
        Text.assemble(
            ("\nWant to check more details about specific package?", "bold"),
            "\nUse AutoModel.describe_model_package('model_id', 'package_id').",
        )
    )
    # Sizes were already displayed - no need to advertise how to get them.
    if pull_artefacts_size:
        return
    console.print(
        Text.assemble(
            ("\nWant to verify the size of model package?", "bold"),
            "\nUse AutoModel.describe_model('model_id', pull_artefacts_size=True) - the execution will be "
            "slightly longer, as we must collect the size of all elements of model package.",
        )
    )
|
|
145
|
+
|
|
146
|
+
@classmethod
def describe_model_package(
    cls,
    model_id: str,
    package_id: str,
    weights_provider: str = "roboflow",
    api_key: Optional[str] = None,
    pull_artefacts_size: bool = True,
) -> None:
    """Print a details table for one package of a model.

    Args:
        model_id: Model identifier (or alias) to look up.
        package_id: Identifier of the package (among the model's packages)
            to describe.
        weights_provider: Name of the weights provider to query.
        api_key: Optional API key passed through to the weights provider.
        pull_artefacts_size: If True (default), collect and show the size
            of the package's artefacts; if False, print a hint describing
            how to request it instead.

    Raises:
        NoModelPackagesAvailableError: If no package of the model has
            ``package_id``.
    """
    model_metadata = get_model_from_provider(
        provider=weights_provider,
        model_id=model_id,
        api_key=api_key,
    )
    # Stop at the first package whose id matches instead of scanning all.
    selected_package = next(
        (
            package
            for package in model_metadata.model_packages
            if package.package_id == package_id
        ),
        None,
    )
    if selected_package is None:
        raise NoModelPackagesAvailableError(
            message=f"Selected model package {package_id} does not exist for model {model_id}. Make sure provided "
            "value is valid.",
            help_url="https://todo",
        )
    artefacts_size = None
    if pull_artefacts_size:
        artefacts_size = calculate_artefacts_size(
            package_artefacts=selected_package.package_artefacts
        )
    table = render_model_package_details_table(
        model_id=model_metadata.model_id,
        requested_model_id=model_id,
        artefacts_size=artefacts_size,
        model_package=selected_package,
    )
    console = Console()
    console.print(table)
    if not pull_artefacts_size:
        # The space before "- the execution" was missing, which rendered as
        # "...=True)- the execution"; this now matches the hint in describe_model.
        text = Text.assemble(
            ("\nWant to verify the size of model package?", "bold"),
            "\nUse AutoModel.describe_model_package('model_id', 'package_id', pull_artefacts_size=True) "
            "- the execution will be slightly longer, as we must collect the size of all elements of model package.",
        )
        console.print(text)
|
|
190
|
+
|
|
191
|
+
@classmethod
def describe_compute_environment(cls) -> None:
    """Inspect the local runtime environment and print it as a table."""
    report = x_ray_runtime_environment()
    rendered = render_runtime_x_ray(runtime_x_ray=report)
    Console().print(rendered)
|
|
197
|
+
|
|
198
|
+
@classmethod
|
|
199
|
+
def from_pretrained(
|
|
200
|
+
cls,
|
|
201
|
+
model_id_or_path: str,
|
|
202
|
+
weights_provider: str = "roboflow",
|
|
203
|
+
api_key: Optional[str] = None,
|
|
204
|
+
model_package_id: Optional[str] = None,
|
|
205
|
+
backend: Optional[
|
|
206
|
+
Union[str, BackendType, List[Union[str, BackendType]]]
|
|
207
|
+
] = None,
|
|
208
|
+
batch_size: Optional[Union[int, Tuple[int, int]]] = None,
|
|
209
|
+
quantization: Optional[
|
|
210
|
+
Union[str, Quantization, List[Union[str, Quantization]]]
|
|
211
|
+
] = None,
|
|
212
|
+
onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
|
|
213
|
+
device: torch.device = DEFAULT_DEVICE,
|
|
214
|
+
default_onnx_trt_options: bool = True,
|
|
215
|
+
max_package_loading_attempts: Optional[int] = None,
|
|
216
|
+
verbose: bool = False,
|
|
217
|
+
model_download_file_lock_acquire_timeout: int = 10,
|
|
218
|
+
allow_untrusted_packages: bool = False,
|
|
219
|
+
trt_engine_host_code_allowed: bool = True,
|
|
220
|
+
allow_local_code_packages: bool = True,
|
|
221
|
+
verify_hash_while_download: bool = True,
|
|
222
|
+
download_files_without_hash: bool = False,
|
|
223
|
+
use_auto_resolution_cache: bool = True,
|
|
224
|
+
auto_resolution_cache: Optional[AutoResolutionCache] = None,
|
|
225
|
+
allow_direct_local_storage_loading: bool = True,
|
|
226
|
+
model_access_manager: Optional[ModelAccessManager] = None,
|
|
227
|
+
nms_fusion_preferences: Optional[Union[bool, dict]] = None,
|
|
228
|
+
model_type: Optional[str] = None,
|
|
229
|
+
task_type: Optional[str] = None,
|
|
230
|
+
allow_loading_dependency_models: bool = True,
|
|
231
|
+
dependency_models_params: Optional[dict] = None,
|
|
232
|
+
point_model_directory: Optional[Callable[[str], None]] = None,
|
|
233
|
+
forwarded_kwargs: Optional[List[str]] = None,
|
|
234
|
+
**kwargs,
|
|
235
|
+
) -> AnyModel:
|
|
236
|
+
if model_access_manager is None:
|
|
237
|
+
model_access_manager = LiberalModelAccessManager()
|
|
238
|
+
if model_access_manager.is_model_access_forbidden(
|
|
239
|
+
model_id=model_id_or_path, api_key=api_key
|
|
240
|
+
):
|
|
241
|
+
raise UnauthorizedModelAccessError(
|
|
242
|
+
message=f"Unauthorized not access model with ID: {model_package_id}. Are you sure you use valid "
|
|
243
|
+
f"API key? The default weights provider is Roboflow - see Roboflow authentication details: "
|
|
244
|
+
f"https://docs.roboflow.com/api-reference/authentication "
|
|
245
|
+
f"and export key to `ROBOFLOW_API_KEY` environment variable. If you use custom weights "
|
|
246
|
+
f"provider - verify access constraints relevant for the provider.",
|
|
247
|
+
help_url="https://todo",
|
|
248
|
+
)
|
|
249
|
+
if auto_resolution_cache is None:
|
|
250
|
+
|
|
251
|
+
def register_file_created_for_model_package(
|
|
252
|
+
file_path: str, model_id: str, package_id: str
|
|
253
|
+
) -> None:
|
|
254
|
+
access_identifiers = AccessIdentifiers(
|
|
255
|
+
model_id=model_id,
|
|
256
|
+
package_id=package_id,
|
|
257
|
+
api_key=api_key,
|
|
258
|
+
)
|
|
259
|
+
model_access_manager.on_file_created(
|
|
260
|
+
file_path=file_path,
|
|
261
|
+
access_identifiers=access_identifiers,
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
auto_resolution_cache = BaseAutoLoadMetadataCache(
|
|
265
|
+
file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
|
|
266
|
+
verbose=verbose,
|
|
267
|
+
on_file_created=register_file_created_for_model_package,
|
|
268
|
+
on_file_deleted=model_access_manager.on_file_deleted,
|
|
269
|
+
)
|
|
270
|
+
model_init_kwargs = {
|
|
271
|
+
"onnx_execution_providers": onnx_execution_providers,
|
|
272
|
+
"device": device,
|
|
273
|
+
"default_onnx_trt_options": default_onnx_trt_options,
|
|
274
|
+
"engine_host_code_allowed": trt_engine_host_code_allowed,
|
|
275
|
+
}
|
|
276
|
+
model_init_kwargs.update(kwargs)
|
|
277
|
+
if not os.path.exists(model_id_or_path):
|
|
278
|
+
# QUESTION: is it enough to assume presence of local dir as the intent to load
|
|
279
|
+
# model from disc drive? What if we have clash of model id / model alias with
|
|
280
|
+
# contents of someone's local drive - shall we then try to load from both sources?
|
|
281
|
+
# that still may end up with ambiguous behaviour - probably the solution would be
|
|
282
|
+
# to require prefix like file://... to denote the intent of loading model from local
|
|
283
|
+
# drive?
|
|
284
|
+
if api_key is not None:
|
|
285
|
+
password_hashed = PasswordHasher()
|
|
286
|
+
api_key_hash = password_hashed.hash(api_key)
|
|
287
|
+
else:
|
|
288
|
+
api_key_hash = api_key
|
|
289
|
+
auto_negotiation_hash = hash_dict_content(
|
|
290
|
+
content={
|
|
291
|
+
"provider": weights_provider,
|
|
292
|
+
"model_id": model_id_or_path,
|
|
293
|
+
"api_key": api_key_hash,
|
|
294
|
+
"requested_model_package_id": model_package_id,
|
|
295
|
+
"requested_backends": backend,
|
|
296
|
+
"requested_batch_size": batch_size,
|
|
297
|
+
"requested_quantization": quantization,
|
|
298
|
+
"device": str(device),
|
|
299
|
+
"onnx_execution_providers": onnx_execution_providers,
|
|
300
|
+
"allow_untrusted_packages": allow_untrusted_packages,
|
|
301
|
+
"trt_engine_host_code_allowed": trt_engine_host_code_allowed,
|
|
302
|
+
"nms_fusion_preferences": nms_fusion_preferences,
|
|
303
|
+
}
|
|
304
|
+
)
|
|
305
|
+
model_from_access_manager = model_access_manager.retrieve_model_instance(
|
|
306
|
+
model_id=model_id_or_path,
|
|
307
|
+
package_id=model_package_id,
|
|
308
|
+
api_key=api_key,
|
|
309
|
+
loading_parameter_digest=auto_negotiation_hash,
|
|
310
|
+
)
|
|
311
|
+
if model_from_access_manager:
|
|
312
|
+
return model_from_access_manager
|
|
313
|
+
if forwarded_kwargs is None:
|
|
314
|
+
forwarded_kwargs = (
|
|
315
|
+
DEFAULT_KWARGS_PARAMS_TO_BE_FORWARDED_TO_DEPENDENT_MODELS
|
|
316
|
+
)
|
|
317
|
+
forwarded_kwargs_values = {
|
|
318
|
+
name: kwargs[name] for name in forwarded_kwargs if name in kwargs
|
|
319
|
+
}
|
|
320
|
+
model_from_cache = attempt_loading_model_with_auto_load_cache(
|
|
321
|
+
use_auto_resolution_cache=use_auto_resolution_cache,
|
|
322
|
+
auto_resolution_cache=auto_resolution_cache,
|
|
323
|
+
auto_negotiation_hash=auto_negotiation_hash,
|
|
324
|
+
model_access_manager=model_access_manager,
|
|
325
|
+
model_name_or_path=model_id_or_path,
|
|
326
|
+
model_init_kwargs=model_init_kwargs,
|
|
327
|
+
api_key=api_key,
|
|
328
|
+
allow_loading_dependency_models=allow_loading_dependency_models,
|
|
329
|
+
forwarded_kwargs_values=forwarded_kwargs_values,
|
|
330
|
+
verbose=verbose,
|
|
331
|
+
weights_provider=weights_provider,
|
|
332
|
+
max_package_loading_attempts=max_package_loading_attempts,
|
|
333
|
+
model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
|
|
334
|
+
allow_untrusted_packages=allow_untrusted_packages,
|
|
335
|
+
trt_engine_host_code_allowed=trt_engine_host_code_allowed,
|
|
336
|
+
allow_local_code_packages=allow_local_code_packages,
|
|
337
|
+
verify_hash_while_download=verify_hash_while_download,
|
|
338
|
+
download_files_without_hash=download_files_without_hash,
|
|
339
|
+
allow_direct_local_storage_loading=allow_direct_local_storage_loading,
|
|
340
|
+
dependency_models_params=dependency_models_params,
|
|
341
|
+
)
|
|
342
|
+
if model_from_cache:
|
|
343
|
+
return model_from_cache
|
|
344
|
+
try:
|
|
345
|
+
model_metadata = get_model_from_provider(
|
|
346
|
+
provider=weights_provider,
|
|
347
|
+
model_id=model_id_or_path,
|
|
348
|
+
api_key=api_key,
|
|
349
|
+
)
|
|
350
|
+
if (
|
|
351
|
+
model_metadata.model_dependencies
|
|
352
|
+
and not allow_loading_dependency_models
|
|
353
|
+
):
|
|
354
|
+
raise CorruptedModelPackageError(
|
|
355
|
+
message=f"Could not load model {model_id_or_path} as it defines another models which are "
|
|
356
|
+
f"it's dependency, but the auto-loader prevents loading dependencies at certain "
|
|
357
|
+
f"nesting depth to avoid excessive resolution procedure. This is a limitation of "
|
|
358
|
+
f"current implementation. Provide us the context of your use-case to get help.",
|
|
359
|
+
help_url="https://todo",
|
|
360
|
+
)
|
|
361
|
+
if model_metadata.model_id != model_id_or_path:
|
|
362
|
+
model_access_manager.on_model_alias_discovered(
|
|
363
|
+
alias=model_id_or_path,
|
|
364
|
+
model_id=model_metadata.model_id,
|
|
365
|
+
)
|
|
366
|
+
model_dependencies = model_metadata.model_dependencies or []
|
|
367
|
+
for model_dependency in model_dependencies:
|
|
368
|
+
model_access_manager.on_model_dependency_discovered(
|
|
369
|
+
base_model_id=model_dependency.model_id,
|
|
370
|
+
base_model_package_id=model_dependency.model_package_id,
|
|
371
|
+
dependent_model_id=model_metadata.model_id,
|
|
372
|
+
)
|
|
373
|
+
for model_package in model_metadata.model_packages:
|
|
374
|
+
package_access_identifiers = AccessIdentifiers(
|
|
375
|
+
model_id=model_metadata.model_id,
|
|
376
|
+
package_id=model_package.package_id,
|
|
377
|
+
api_key=api_key,
|
|
378
|
+
)
|
|
379
|
+
model_access_manager.on_model_package_access_granted(
|
|
380
|
+
package_access_identifiers
|
|
381
|
+
)
|
|
382
|
+
except UnauthorizedModelAccessError as error:
|
|
383
|
+
model_access_manager.on_model_access_forbidden(
|
|
384
|
+
model_id=model_id_or_path, api_key=api_key
|
|
385
|
+
)
|
|
386
|
+
raise error
|
|
387
|
+
# here we verify if de-aliasing or access confirmation from auth master changed something
|
|
388
|
+
model_from_access_manager = model_access_manager.retrieve_model_instance(
|
|
389
|
+
model_id=model_id_or_path,
|
|
390
|
+
package_id=model_package_id,
|
|
391
|
+
api_key=api_key,
|
|
392
|
+
loading_parameter_digest=auto_negotiation_hash,
|
|
393
|
+
)
|
|
394
|
+
if model_from_access_manager:
|
|
395
|
+
return model_from_access_manager
|
|
396
|
+
matching_model_packages = negotiate_model_packages(
|
|
397
|
+
model_architecture=model_metadata.model_architecture,
|
|
398
|
+
task_type=model_metadata.task_type,
|
|
399
|
+
model_packages=model_metadata.model_packages,
|
|
400
|
+
requested_model_package_id=model_package_id,
|
|
401
|
+
requested_backends=backend,
|
|
402
|
+
requested_batch_size=batch_size,
|
|
403
|
+
requested_quantization=quantization,
|
|
404
|
+
device=device,
|
|
405
|
+
onnx_execution_providers=onnx_execution_providers,
|
|
406
|
+
allow_untrusted_packages=allow_untrusted_packages,
|
|
407
|
+
trt_engine_host_code_allowed=trt_engine_host_code_allowed,
|
|
408
|
+
nms_fusion_preferences=nms_fusion_preferences,
|
|
409
|
+
verbose=verbose,
|
|
410
|
+
)
|
|
411
|
+
model_dependencies_instances = {}
|
|
412
|
+
model_dependencies_directories = {}
|
|
413
|
+
dependency_models_params = dependency_models_params or {}
|
|
414
|
+
for model_dependency in model_dependencies:
|
|
415
|
+
dependency_params = dependency_models_params.get(
|
|
416
|
+
model_dependency.name, {}
|
|
417
|
+
)
|
|
418
|
+
dependency_params["model_id_or_path"] = model_dependency.model_id
|
|
419
|
+
dependency_params["model_package_id"] = (
|
|
420
|
+
model_dependency.model_package_id
|
|
421
|
+
)
|
|
422
|
+
resolved_model_parameters = prepare_dependency_model_parameters(
|
|
423
|
+
model_parameters=dependency_params
|
|
424
|
+
)
|
|
425
|
+
verbose_info(
|
|
426
|
+
message=f"Initialising dependent model: {model_dependency.model_id}",
|
|
427
|
+
verbose_requested=verbose,
|
|
428
|
+
)
|
|
429
|
+
|
|
430
|
+
def model_directory_pointer(model_dir: str) -> None:
|
|
431
|
+
model_dependencies_directories[model_dependency.name] = model_dir
|
|
432
|
+
|
|
433
|
+
for name, value in forwarded_kwargs_values.items():
|
|
434
|
+
if name not in resolved_model_parameters.model_extra:
|
|
435
|
+
resolved_model_parameters.model_extra[name] = value
|
|
436
|
+
|
|
437
|
+
dependency_instance = AutoModel.from_pretrained(
|
|
438
|
+
model_id_or_path=resolved_model_parameters.model_id_or_path,
|
|
439
|
+
weights_provider=weights_provider,
|
|
440
|
+
api_key=api_key,
|
|
441
|
+
model_package_id=resolved_model_parameters.model_package_id,
|
|
442
|
+
backend=resolved_model_parameters.backend,
|
|
443
|
+
batch_size=resolved_model_parameters.batch_size,
|
|
444
|
+
quantization=resolved_model_parameters.quantization,
|
|
445
|
+
onnx_execution_providers=resolved_model_parameters.onnx_execution_providers,
|
|
446
|
+
device=resolved_model_parameters.device,
|
|
447
|
+
default_onnx_trt_options=resolved_model_parameters.default_onnx_trt_options,
|
|
448
|
+
max_package_loading_attempts=max_package_loading_attempts,
|
|
449
|
+
verbose=verbose,
|
|
450
|
+
model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
|
|
451
|
+
allow_untrusted_packages=allow_untrusted_packages,
|
|
452
|
+
trt_engine_host_code_allowed=trt_engine_host_code_allowed,
|
|
453
|
+
allow_local_code_packages=allow_local_code_packages,
|
|
454
|
+
verify_hash_while_download=verify_hash_while_download,
|
|
455
|
+
download_files_without_hash=download_files_without_hash,
|
|
456
|
+
use_auto_resolution_cache=use_auto_resolution_cache,
|
|
457
|
+
auto_resolution_cache=auto_resolution_cache,
|
|
458
|
+
allow_direct_local_storage_loading=allow_direct_local_storage_loading,
|
|
459
|
+
model_access_manager=model_access_manager,
|
|
460
|
+
nms_fusion_preferences=resolved_model_parameters.nms_fusion_preferences,
|
|
461
|
+
model_type=resolved_model_parameters.model_type,
|
|
462
|
+
task_type=resolved_model_parameters.task_type,
|
|
463
|
+
allow_loading_dependency_models=False,
|
|
464
|
+
dependency_models_params=None,
|
|
465
|
+
point_model_directory=model_directory_pointer,
|
|
466
|
+
**resolved_model_parameters.kwargs,
|
|
467
|
+
)
|
|
468
|
+
model_dependencies_instances[model_dependency.name] = (
|
|
469
|
+
dependency_instance
|
|
470
|
+
)
|
|
471
|
+
|
|
472
|
+
return attempt_loading_matching_model_packages(
|
|
473
|
+
model_id=model_id_or_path,
|
|
474
|
+
model_architecture=model_metadata.model_architecture,
|
|
475
|
+
task_type=model_metadata.task_type,
|
|
476
|
+
matching_model_packages=matching_model_packages,
|
|
477
|
+
model_init_kwargs=model_init_kwargs,
|
|
478
|
+
model_access_manager=model_access_manager,
|
|
479
|
+
auto_negotiation_hash=auto_negotiation_hash,
|
|
480
|
+
api_key=api_key,
|
|
481
|
+
model_dependencies=model_metadata.model_dependencies,
|
|
482
|
+
model_dependencies_instances=model_dependencies_instances,
|
|
483
|
+
model_dependencies_directories=model_dependencies_directories,
|
|
484
|
+
max_package_loading_attempts=max_package_loading_attempts,
|
|
485
|
+
model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
|
|
486
|
+
verify_hash_while_download=verify_hash_while_download,
|
|
487
|
+
download_files_without_hash=download_files_without_hash,
|
|
488
|
+
auto_resolution_cache=auto_resolution_cache,
|
|
489
|
+
use_auto_resolution_cache=use_auto_resolution_cache,
|
|
490
|
+
point_model_directory=point_model_directory,
|
|
491
|
+
verbose=verbose,
|
|
492
|
+
)
|
|
493
|
+
if not allow_direct_local_storage_loading:
|
|
494
|
+
raise DirectLocalStorageAccessError(
|
|
495
|
+
message="Attempted to load model directly pointing local path, rather than model ID. This "
|
|
496
|
+
"operation is forbidden as AutoModel.from_pretrained(...) was used with "
|
|
497
|
+
"`allow_direct_local_storage_loading=False`. If you are running `inference-models` outside Roboflow "
|
|
498
|
+
"hosted solutions - verify your setup. If you see this error on Roboflow platform - this "
|
|
499
|
+
"feature was disabled for security reason. In rare cases when you use valid model ID, the "
|
|
500
|
+
"clash of ID with local path may cause this error - we ask you to report the issue here: "
|
|
501
|
+
"https://github.com/roboflow/inference/issues.",
|
|
502
|
+
help_url="https://todo",
|
|
503
|
+
)
|
|
504
|
+
return attempt_loading_model_from_local_storage(
|
|
505
|
+
model_dir_or_weights_path=model_id_or_path,
|
|
506
|
+
allow_local_code_packages=allow_local_code_packages,
|
|
507
|
+
model_init_kwargs=model_init_kwargs,
|
|
508
|
+
model_type=model_type,
|
|
509
|
+
task_type=task_type,
|
|
510
|
+
backend_type=backend,
|
|
511
|
+
)
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
def attempt_loading_model_with_auto_load_cache(
    use_auto_resolution_cache: bool,
    auto_resolution_cache: AutoResolutionCache,
    auto_negotiation_hash: str,
    model_access_manager: ModelAccessManager,
    model_name_or_path: str,
    model_init_kwargs: dict,
    api_key: Optional[str],
    allow_loading_dependency_models: bool,
    forwarded_kwargs_values: Dict[str, Any],
    verbose: bool = False,
    weights_provider: str = "roboflow",
    max_package_loading_attempts: Optional[int] = None,
    model_download_file_lock_acquire_timeout: int = 10,
    allow_untrusted_packages: bool = False,
    trt_engine_host_code_allowed: bool = True,
    allow_local_code_packages: bool = True,
    verify_hash_while_download: bool = True,
    download_files_without_hash: bool = False,
    allow_direct_local_storage_loading: bool = True,
    dependency_models_params: Optional[dict] = None,
    point_model_directory: Optional[Callable[[str], None]] = None,
) -> Optional[AnyModel]:
    """Try to load a model from a previously registered auto-resolution cache entry.

    The entry is looked up by ``auto_negotiation_hash``. It is only used when
    ``model_access_manager`` grants access to the cached model package and every
    file recorded in the entry still exists on disk. Dependency models recorded
    in the entry are loaded first with ``AutoModel.from_pretrained(...)``, with
    further dependency loading disabled. They are then handed to the model class
    under ``MODEL_DEPENDENCIES_KEY``.

    Args:
        use_auto_resolution_cache: When False, the cache is not consulted at all.
        auto_resolution_cache: Cache to read the entry from (and invalidate on failure).
        auto_negotiation_hash: Key of the cache entry.
        model_init_kwargs: Keyword arguments for ``model_class.from_pretrained``.
            NOTE: this dict is modified in place - ``MODEL_DEPENDENCIES_KEY`` is set.
        allow_loading_dependency_models: If False and the cached entry declares
            dependencies, loading from cache is abandoned.
        forwarded_kwargs_values: Extra parameters applied to each dependency
            unless the dependency's own parameters already define them.
        dependency_models_params: Optional mapping of dependency name -> parameter
            dict. The dicts are copied before being extended, so the caller's
            data is left unchanged.
        point_model_directory: Optional callback invoked with the model package
            cache directory after a successful load.

    Returns:
        The loaded model. Returns None when the cache is disabled, no entry exists,
        access is not granted, some cached files are missing, or loading failed.
        In the last case a warning is logged and the cache entry is invalidated.
    """
    if not use_auto_resolution_cache:
        return None
    verbose_info(
        message=f"Attempt to load model {model_name_or_path} using auto-load cache.",
        verbose_requested=verbose,
    )
    cache_entry = auto_resolution_cache.retrieve(
        auto_negotiation_hash=auto_negotiation_hash
    )
    if cache_entry is None:
        verbose_info(
            message=f"Could not find auto-load cache for model {model_name_or_path}.",
            verbose_requested=verbose,
        )
        return None
    if not model_access_manager.is_model_package_access_granted(
        model_id=cache_entry.model_id,
        package_id=cache_entry.model_package_id,
        api_key=api_key,
    ):
        return None
    if not all_files_exist(files=cache_entry.resolved_files):
        verbose_info(
            message=f"Could not find all required files denoted in auto-load cache for model {model_name_or_path}.",
            verbose_requested=verbose,
        )
        return None
    try:
        model_dependencies = cache_entry.model_dependencies or []
        if model_dependencies and not allow_loading_dependency_models:
            raise CorruptedModelPackageError(
                message=f"Could not load model {cache_entry.model_id} as it defines another models which are "
                f"it's dependency, but the auto-loader prevents loading dependencies at certain "
                f"nesting depth to avoid excessive resolution procedure. This is a limitation of "
                f"current implementation. Provide us the context of your use-case to get help.",
                help_url="https://todo",
            )
        model_dependencies_instances = {}
        dependency_models_params = dependency_models_params or {}
        for model_dependency in model_dependencies:
            # Copy - the dict below gets extended and must not leak back into the
            # caller-provided `dependency_models_params`.
            dependency_params = dict(
                dependency_models_params.get(model_dependency.name) or {}
            )
            dependency_params["model_id_or_path"] = model_dependency.model_id
            dependency_params["model_package_id"] = model_dependency.model_package_id
            resolved_model_parameters = prepare_dependency_model_parameters(
                model_parameters=dependency_params
            )
            # forwarded values never override parameters set for the dependency itself
            for name, value in forwarded_kwargs_values.items():
                if name not in resolved_model_parameters.model_extra:
                    resolved_model_parameters.model_extra[name] = value
            verbose_info(
                message=f"Initialising dependent model: {model_dependency.model_id}",
                verbose_requested=verbose,
            )
            dependency_instance = AutoModel.from_pretrained(
                model_id_or_path=resolved_model_parameters.model_id_or_path,
                weights_provider=weights_provider,
                api_key=api_key,
                model_package_id=resolved_model_parameters.model_package_id,
                backend=resolved_model_parameters.backend,
                batch_size=resolved_model_parameters.batch_size,
                quantization=resolved_model_parameters.quantization,
                onnx_execution_providers=resolved_model_parameters.onnx_execution_providers,
                device=resolved_model_parameters.device,
                default_onnx_trt_options=resolved_model_parameters.default_onnx_trt_options,
                max_package_loading_attempts=max_package_loading_attempts,
                verbose=verbose,
                model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
                allow_untrusted_packages=allow_untrusted_packages,
                trt_engine_host_code_allowed=trt_engine_host_code_allowed,
                allow_local_code_packages=allow_local_code_packages,
                verify_hash_while_download=verify_hash_while_download,
                download_files_without_hash=download_files_without_hash,
                use_auto_resolution_cache=use_auto_resolution_cache,
                auto_resolution_cache=auto_resolution_cache,
                allow_direct_local_storage_loading=allow_direct_local_storage_loading,
                model_access_manager=model_access_manager,
                nms_fusion_preferences=resolved_model_parameters.nms_fusion_preferences,
                model_type=resolved_model_parameters.model_type,
                task_type=resolved_model_parameters.task_type,
                allow_loading_dependency_models=False,
                dependency_models_params=None,
                **resolved_model_parameters.kwargs,
            )
            model_dependencies_instances[model_dependency.name] = dependency_instance
        model_class = resolve_model_class(
            model_architecture=cache_entry.model_architecture,
            task_type=cache_entry.task_type,
            backend=cache_entry.backend_type,
        )
        model_package_cache_dir = generate_model_package_cache_path(
            model_id=cache_entry.model_id,
            package_id=cache_entry.model_package_id,
        )
        model_init_kwargs[MODEL_DEPENDENCIES_KEY] = model_dependencies_instances
        model = model_class.from_pretrained(
            model_package_cache_dir, **model_init_kwargs
        )
    except Exception as error:
        LOGGER.warning(
            f"Encountered error {error} of type {type(error)} when attempted to load model using "
            f"auto-load cache. This may indicate corrupted cache or inference bug. Contact Roboflow submitting "
            f"issue under: https://github.com/roboflow/inference/issues/"
        )
        auto_resolution_cache.invalidate(auto_negotiation_hash=auto_negotiation_hash)
        return None
    # Outside of the `try` block - a failing callback must not invalidate a valid cache entry.
    if point_model_directory is not None:
        point_model_directory(model_package_cache_dir)
    verbose_info(
        message=f"Successfully loaded model {model_name_or_path} using auto-loading cache.",
        verbose_requested=verbose,
    )
    return model
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
def all_files_exist(files: List[str]) -> bool:
    """Return True if every path in ``files`` exists; an empty collection yields True.

    Stops at the first missing path.
    """
    for path in files:
        if not os.path.exists(path):
            return False
    return True
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
def attempt_loading_matching_model_packages(
    model_id: str,
    model_architecture: ModelArchitecture,
    task_type: Optional[TaskType],
    matching_model_packages: List[ModelPackageMetadata],
    model_init_kwargs: dict,
    model_access_manager: ModelAccessManager,
    auto_resolution_cache: AutoResolutionCache,
    auto_negotiation_hash: str,
    api_key: Optional[str],
    model_dependencies: Optional[List[ModelDependency]],
    model_dependencies_instances: Dict[str, AnyModel],
    model_dependencies_directories: Dict[str, str],
    max_package_loading_attempts: Optional[int] = None,
    model_download_file_lock_acquire_timeout: int = 10,
    verbose: bool = True,
    verify_hash_while_download: bool = True,
    download_files_without_hash: bool = False,
    use_auto_resolution_cache: bool = True,
    point_model_directory: Optional[Callable[[str], None]] = None,
) -> AnyModel:
    """Try the candidate model packages in order and return the first that loads.

    Each candidate is downloaded and initialised via ``initialize_model(...)``.
    File and symlink events are reported to ``model_access_manager`` under the
    candidate's ``AccessIdentifiers``. A successful load is reported through
    ``model_access_manager.on_model_loaded(...)``. A failing candidate is logged
    and the next one is tried.

    Args:
        matching_model_packages: Ordered candidates; at most
            ``max_package_loading_attempts`` of them are tried when that is set.
        point_model_directory: Optional callback receiving the cache directory of
            the package that loaded successfully.

    Returns:
        The first successfully loaded model.

    Raises:
        ModelLoadingError: When there are no candidates, or when every candidate
            failed. In the latter case the message summarises the per-package
            errors, and the last error is chained as ``__cause__``.
    """
    if max_package_loading_attempts is not None:
        matching_model_packages = matching_model_packages[:max_package_loading_attempts]
    if not matching_model_packages:
        raise ModelLoadingError(
            message=f"Cannot load model {model_id} - no matching model package candidates for given model "
            f"running in this environment.",
            help_url="https://todo",
        )
    failed_load_attempts: List[Tuple[str, Exception]] = []
    for model_package in matching_model_packages:
        access_identifiers = AccessIdentifiers(
            model_id=model_id,
            package_id=model_package.package_id,
            api_key=api_key,
        )
        verbose_info(
            message=f"Attempt to load model package: {model_package.get_summary()}",
            verbose_requested=verbose,
        )
        try:
            model, model_package_cache_dir = initialize_model(
                model_id=model_id,
                model_architecture=model_architecture,
                task_type=task_type,
                model_package=model_package,
                model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
                model_init_kwargs=model_init_kwargs,
                auto_resolution_cache=auto_resolution_cache,
                auto_negotiation_hash=auto_negotiation_hash,
                model_dependencies=model_dependencies,
                model_dependencies_instances=model_dependencies_instances,
                model_dependencies_directories=model_dependencies_directories,
                verify_hash_while_download=verify_hash_while_download,
                download_files_without_hash=download_files_without_hash,
                on_file_created=partial(
                    model_access_manager.on_file_created,
                    access_identifiers=access_identifiers,
                ),
                on_file_renamed=partial(
                    model_access_manager.on_file_renamed,
                    access_identifiers=access_identifiers,
                ),
                on_symlink_created=partial(
                    model_access_manager.on_symlink_created,
                    access_identifiers=access_identifiers,
                ),
                on_symlink_deleted=model_access_manager.on_symlink_deleted,
                use_auto_resolution_cache=use_auto_resolution_cache,
            )
            model_access_manager.on_model_loaded(
                model=model,
                access_identifiers=access_identifiers,
                model_storage_path=model_package_cache_dir,
            )
        except Exception as error:
            LOGGER.warning(
                f"Model package with id {model_package.package_id} that was selected to be loaded "
                f"failed to load with error: {error} of type {error.__class__.__name__}. This may "
                f"be caused by several issues. If you see this warning after manually specifying model "
                f"package to be loaded - make sure that all required dependencies are installed. If "
                f"that warning is displayed when the model package was auto-selected - there is most "
                f"likely a bug in `inference-models` and you should raise an issue providing full context of "
                f"the event. https://github.com/roboflow/inference/issues"
            )
            failed_load_attempts.append((model_package.package_id, error))
            continue
        # Deliberately outside of the `try` block: an error raised by the callback
        # is not a package loading failure and must not trigger other candidates.
        if point_model_directory:
            point_model_directory(model_package_cache_dir)
        return model

    summary_of_errors = "\n".join(
        f"\t* model_package_id={model_package_id} error={error} error_type={error.__class__.__name__}"
        for model_package_id, error in failed_load_attempts
    )
    # Every candidate either returned or appended to `failed_load_attempts`,
    # so the list is non-empty here.
    last_error = failed_load_attempts[-1][1]
    raise ModelLoadingError(
        message=f"Could not load any of model package candidate for model {model_id}. This may "
        f"be caused by several issues. If you see this warning after manually specifying model "
        f"package to be loaded - make sure that all required dependencies are installed. If "
        f"that warning is displayed when the model package was auto-selected - there is most "
        f"likely a bug in `inference-models` and you should raise an issue providing full context of "
        f"the event. https://github.com/roboflow/inference/issues\n\n"
        f"Here is the summary of errors for specific model packages:\n{summary_of_errors}\n\n",
        help_url="https://todo",
    ) from last_error
|
|
757
|
+
|
|
758
|
+
|
|
759
|
+
def initialize_model(
    model_id: str,
    model_architecture: ModelArchitecture,
    task_type: Optional[TaskType],
    model_package: ModelPackageMetadata,
    model_init_kwargs: dict,
    auto_resolution_cache: AutoResolutionCache,
    auto_negotiation_hash: str,
    model_dependencies: Optional[List[ModelDependency]],
    model_dependencies_instances: Dict[str, AnyModel],
    model_dependencies_directories: Dict[str, str],
    model_download_file_lock_acquire_timeout: int = 10,
    verify_hash_while_download: bool = True,
    download_files_without_hash: bool = False,
    on_file_created: Optional[Callable[[str], None]] = None,
    on_file_renamed: Optional[Callable[[str, str], None]] = None,
    on_symlink_created: Optional[Callable[[str, str], None]] = None,
    on_symlink_deleted: Optional[Callable[[str], None]] = None,
    use_auto_resolution_cache: bool = True,
) -> Tuple[AnyModel, str]:
    """Materialise a model package in the local cache and instantiate the model.

    Steps, in order:
      1. resolve the model class for (architecture, task, package backend);
      2. reject packages with an artefact named ``MODEL_CONFIG_FILE_NAME``;
      3. download artefacts that carry an MD5 hash into the shared blobs dir
         (named after the hash), and the others into the package cache dir;
      4. symlink the shared blobs into the package cache dir;
      5. write the offline-loading config file;
      6. link dependency model directories into the package cache dir;
      7. instantiate the model with ``MODEL_DEPENDENCIES_KEY`` set in
         ``model_init_kwargs`` (the dict is modified in place);
      8. register all touched files in the auto-resolution cache.

    Returns:
        Tuple of (loaded model, model package cache directory).

    Raises:
        CorruptedModelPackageError: If an artefact collides with the config file name.
    """
    model_class = resolve_model_class(
        model_architecture=model_architecture,
        task_type=task_type,
        backend=model_package.backend,
    )
    artefacts = model_package.package_artefacts
    # The config file name is reserved for the file generated further below.
    collides_with_config = any(
        artefact.file_handle == MODEL_CONFIG_FILE_NAME for artefact in artefacts
    )
    if collides_with_config:
        raise CorruptedModelPackageError(
            message=f"For model with id=`{model_id}` and package={model_package.package_id} discovered "
            f"artefact named `{MODEL_CONFIG_FILE_NAME}` which collides with the config file that "
            f"inference is supposed to create for a model in order for compatibility with offline "
            f"loaders. This problem indicate a violation of model package contract and requires change in "
            f"model package structure. If you experience this issue using hosted Roboflow solution, contact "
            f"us to solve the problem.",
            help_url="https://todo",
        )
    # (file_handle, download_url, md5_hash) specs, split by presence of the hash.
    hashed_specs = []
    unhashed_specs = []
    for artefact in artefacts:
        spec = (artefact.file_handle, artefact.download_url, artefact.md5_hash)
        if spec[2] is None:
            unhashed_specs.append(spec)
        else:
            hashed_specs.append(spec)
    blobs_dir = generate_shared_blobs_path()
    package_dir = generate_model_package_cache_path(
        model_id=model_id,
        package_id=model_package.package_id,
    )
    os.makedirs(package_dir, exist_ok=True)
    blob_paths = download_files_to_directory(
        target_dir=blobs_dir,
        files_specs=hashed_specs,
        file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
        verify_hash_while_download=verify_hash_while_download,
        download_files_without_hash=download_files_without_hash,
        name_after="md5_hash",
        on_file_created=on_file_created,
        on_file_renamed=on_file_renamed,
    )
    package_file_paths = download_files_to_directory(
        target_dir=package_dir,
        files_specs=unhashed_specs,
        file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
        verify_hash_while_download=verify_hash_while_download,
        download_files_without_hash=download_files_without_hash,
        on_file_created=on_file_created,
        on_file_renamed=on_file_renamed,
    )
    link_paths = create_symlinks_to_shared_blobs(
        model_dir=package_dir,
        shared_files_mapping=blob_paths,
        model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
        on_symlink_created=on_symlink_created,
        on_symlink_deleted=on_symlink_deleted,
    )
    config_path = os.path.join(package_dir, MODEL_CONFIG_FILE_NAME)
    dump_model_config_for_offline_use(
        config_path=config_path,
        model_architecture=model_architecture,
        task_type=task_type,
        backend_type=model_package.backend,
        file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
        on_file_created=on_file_created,
    )
    dependency_files = handle_dependencies_directories_creation(
        model_package_cache_dir=package_dir,
        model_dependencies_directories=model_dependencies_directories,
        model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
        on_symlink_created=on_symlink_created,
        on_symlink_deleted=on_symlink_deleted,
    )
    resolved_files = {
        *blob_paths.values(),
        *package_file_paths.values(),
        *link_paths.values(),
        config_path,
        *dependency_files,
    }
    model_init_kwargs[MODEL_DEPENDENCIES_KEY] = model_dependencies_instances
    model = model_class.from_pretrained(package_dir, **model_init_kwargs)
    dump_auto_resolution_cache(
        use_auto_resolution_cache=use_auto_resolution_cache,
        auto_resolution_cache=auto_resolution_cache,
        auto_negotiation_hash=auto_negotiation_hash,
        model_id=model_id,
        model_package_id=model_package.package_id,
        model_architecture=model_architecture,
        task_type=task_type,
        backend_type=model_package.backend,
        resolved_files=resolved_files,
        model_dependencies=model_dependencies,
    )
    return model, package_dir
|
|
869
|
+
|
|
870
|
+
|
|
871
|
+
def create_symlinks_to_shared_blobs(
    model_dir: str,
    shared_files_mapping: Dict[FileHandle, str],
    model_download_file_lock_acquire_timeout: int = 10,
    on_symlink_created: Optional[Callable[[str, str], None]] = None,
    on_symlink_deleted: Optional[Callable[[str], None]] = None,
) -> Dict[str, str]:
    """Expose shared blob files inside ``model_dir`` under their file handles.

    For every ``file_handle -> blob path`` entry, ``<model_dir>/<file_handle>``
    is made a symlink to the blob. Existing regular files are never overridden.
    An existing symlink is left alone when it already resolves to the blob, and
    is recreated otherwise (including dangling links).

    Returns:
        Mapping of file handle -> path inside ``model_dir`` (returned for every
        handle, whether or not a link had to be created).
    """
    os.makedirs(model_dir, exist_ok=True)
    result = {}
    for file_handle, target_path in shared_files_mapping.items():
        link_name = os.path.join(model_dir, file_handle)
        result[file_handle] = link_name
        # Compare fully-resolved paths on both sides - a correct link must not be
        # treated as stale just because `target_path` is not in canonical form.
        if os.path.exists(link_name) and (
            not os.path.islink(link_name)
            or os.path.realpath(link_name) == os.path.realpath(target_path)
        ):
            continue
        handle_symlink_creation(
            target_path=target_path,
            link_name=link_name,
            model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
            on_symlink_created=on_symlink_created,
            on_symlink_deleted=on_symlink_deleted,
        )
    return result
|
|
897
|
+
|
|
898
|
+
|
|
899
|
+
def handle_symlink_creation(
    target_path: str,
    link_name: str,
    model_download_file_lock_acquire_timeout: int = 10,
    on_symlink_created: Optional[Callable[[str, str], None]] = None,
    on_symlink_deleted: Optional[Callable[[str], None]] = None,
) -> None:
    """(Re)create ``link_name`` as a symlink to ``target_path`` under a file lock.

    The parent directory is created if needed. The lock file is
    ``.<link file name>.lock`` next to the link. Any symlink already present at
    ``link_name`` is removed first (reported via ``on_symlink_deleted``), whether
    it is dangling or points at another target. The new link is reported via
    ``on_symlink_created(target_path, link_name)``.
    """
    absolute_link = os.path.abspath(link_name)
    parent_dir = os.path.dirname(absolute_link)
    os.makedirs(parent_dir, exist_ok=True)
    lock_file = os.path.join(parent_dir, "." + os.path.basename(absolute_link) + ".lock")
    with FileLock(lock_file, timeout=model_download_file_lock_acquire_timeout):
        stale_link_present = os.path.islink(link_name)
        if stale_link_present:
            os.remove(link_name)
            if on_symlink_deleted:
                on_symlink_deleted(link_name)
        os.symlink(target_path, link_name)
        if on_symlink_created:
            on_symlink_created(target_path, link_name)
|
|
918
|
+
|
|
919
|
+
|
|
920
|
+
def dump_model_config_for_offline_use(
    config_path: str,
    model_architecture: Optional[ModelArchitecture],
    task_type: Optional[TaskType],
    backend_type: Optional[BackendType],
    file_lock_acquire_timeout: int,
    on_file_created: Optional[Callable[[str], None]] = None,
) -> None:
    """Write the JSON config used for offline loading, unless it already exists.

    The file holds ``model_architecture``, ``task_type`` and ``backend_type``.
    It is written under the lock ``.<config file name>.lock`` in the same
    directory. ``on_file_created(config_path)`` is called only when this call
    actually wrote the file.
    """
    if os.path.exists(config_path):
        # We kinda trust that what we did previously is right - in case when the
        # file gets corrupted we may end up in problem - to be verified empirically.
        return None
    target_file_dir, target_file_name = os.path.split(config_path)
    lock_path = os.path.join(target_file_dir, f".{target_file_name}.lock")
    with FileLock(lock_path, timeout=file_lock_acquire_timeout):
        # Re-check under the lock: a concurrent writer may have created the file
        # while we were waiting, in which case there is nothing left to do.
        if os.path.exists(config_path):
            return None
        dump_json(
            path=config_path,
            content={
                "model_architecture": model_architecture,
                "task_type": task_type,
                "backend_type": backend_type,
            },
        )
        if on_file_created:
            on_file_created(config_path)
    return None
|
|
945
|
+
|
|
946
|
+
|
|
947
|
+
def handle_dependencies_directories_creation(
    model_package_cache_dir: str,
    model_dependencies_directories: Dict[str, str],
    model_download_file_lock_acquire_timeout: int = 10,
    on_symlink_created: Optional[Callable[[str, str], None]] = None,
    on_symlink_deleted: Optional[Callable[[str], None]] = None,
) -> Set[str]:
    """Link dependency model directories into the package cache dir.

    For every ``dependency_name -> dependency_directory`` entry, the path
    ``<model_package_cache_dir>/<MODEL_DEPENDENCIES_SUB_DIR>/<dependency_name>``
    is made a symlink to ``dependency_directory``, under a per-dependency lock.
    Any symlink already there is replaced, including a dangling one. A regular
    file or directory already there is left untouched.

    Returns:
        Files found in the dependency directories (see
        ``scan_dependency_directory_for_resolved_files``), so they can be recorded
        in the auto-resolution cache. Empty set when there are no dependencies.
    """
    resolved_files = set()
    if not model_dependencies_directories:
        return resolved_files
    # Loop-invariant: every dependency lives in the same sub-directory.
    dependencies_sub_dir = os.path.join(
        model_package_cache_dir, MODEL_DEPENDENCIES_SUB_DIR
    )
    os.makedirs(dependencies_sub_dir, exist_ok=True)
    for dependency_name, dependency_directory in model_dependencies_directories.items():
        dependency_files = scan_dependency_directory_for_resolved_files(
            dependency_directory=dependency_directory
        )
        resolved_files.update(dependency_files)
        target_dependency_dir = os.path.join(dependencies_sub_dir, dependency_name)
        dependency_lock_path = os.path.join(
            dependencies_sub_dir, f".{dependency_name}.lock"
        )
        with FileLock(
            dependency_lock_path, timeout=model_download_file_lock_acquire_timeout
        ):
            # `os.path.exists()` is False for a dangling symlink. Checking only
            # `islink()` ensures such a link is removed too; otherwise
            # `os.symlink` below would fail with FileExistsError.
            if os.path.islink(target_dependency_dir):
                os.remove(target_dependency_dir)
                if on_symlink_deleted:
                    on_symlink_deleted(target_dependency_dir)
            if not os.path.exists(target_dependency_dir):
                # NOTE(review): if an actual file / dir occupies this path, it is
                # left as-is - to be confirmed whether that is desired.
                os.symlink(dependency_directory, target_dependency_dir)
                if on_symlink_created:
                    on_symlink_created(dependency_directory, target_dependency_dir)
    return resolved_files
|
|
986
|
+
|
|
987
|
+
|
|
988
|
+
def scan_dependency_directory_for_resolved_files(
    dependency_directory: str,
) -> List[str]:
    """List the files of a dependency directory to be recorded in the resolution cache.

    Lock files (``.<name>.lock``) are skipped. For every file that is a symlink,
    both the link path and its target are returned. The target is given as an
    absolute path: relative link targets are resolved against the link's own
    directory, not the current working directory.

    Directory symlinks are not followed (``os.walk`` default). The assumption is
    that only one level of nesting is supported for packages, so a dependency
    model must not have dependencies itself. Hence there are no symlinked
    directories worth following.
    """
    results = []
    for current_dir, _, file_names in os.walk(dependency_directory):
        for file_name in file_names:
            if file_name.startswith(".") and file_name.endswith(".lock"):
                continue
            full_path = os.path.abspath(os.path.join(current_dir, file_name))
            results.append(full_path)
            if os.path.islink(full_path):
                link_target = os.readlink(full_path)
                # os.path.join keeps `link_target` unchanged when it is absolute.
                results.append(
                    os.path.abspath(
                        os.path.join(os.path.dirname(full_path), link_target)
                    )
                )
    return results
|
|
1004
|
+
|
|
1005
|
+
|
|
1006
|
+
def dump_auto_resolution_cache(
    use_auto_resolution_cache: bool,
    auto_resolution_cache: AutoResolutionCache,
    auto_negotiation_hash: str,
    model_id: str,
    model_package_id: str,
    model_architecture: Optional[ModelArchitecture],
    task_type: TaskType,
    backend_type: Optional[BackendType],
    resolved_files: Set[str],
    model_dependencies: Optional[List[ModelDependency]],
) -> None:
    """Register an ``AutoResolutionCacheEntry`` under ``auto_negotiation_hash``.

    Does nothing when ``use_auto_resolution_cache`` is False. The entry's
    ``created_at`` is taken from ``datetime.now()``.
    """
    if use_auto_resolution_cache:
        entry = AutoResolutionCacheEntry(
            model_id=model_id,
            model_package_id=model_package_id,
            resolved_files=resolved_files,
            model_architecture=model_architecture,
            task_type=task_type,
            backend_type=backend_type,
            created_at=datetime.now(),
            model_dependencies=model_dependencies,
        )
        auto_resolution_cache.register(
            auto_negotiation_hash=auto_negotiation_hash,
            cache_entry=entry,
        )
    return None
|
|
1033
|
+
|
|
1034
|
+
|
|
1035
|
+
def generate_shared_blobs_path() -> str:
    """Return ``<INFERENCE_HOME>/shared-blobs``.

    Artefacts that come with an MD5 hash are downloaded there so that model
    packages can share them.
    """
    blobs_dir_name = "shared-blobs"
    return os.path.join(INFERENCE_HOME, blobs_dir_name)
|
|
1037
|
+
|
|
1038
|
+
|
|
1039
|
+
def generate_model_package_cache_path(model_id: str, package_id: str) -> str:
    """Return ``<INFERENCE_HOME>/models-cache/<model id slug>/<package_id>``.

    ``package_id`` is validated first (see ``ensure_package_id_is_os_safe``).
    ``model_id`` is turned into a filesystem-safe slug.
    """
    ensure_package_id_is_os_safe(model_id=model_id, package_id=package_id)
    return os.path.join(
        INFERENCE_HOME,
        "models-cache",
        slugify_model_id_to_os_safe_format(model_id=model_id),
        package_id,
    )
|
|
1043
|
+
|
|
1044
|
+
|
|
1045
|
+
def ensure_package_id_is_os_safe(model_id: str, package_id: str) -> None:
    """Validate that ``package_id`` can be safely used as a local directory name.

    The package ID becomes a path component of the local models cache (see
    ``generate_model_package_cache_path``). It must therefore be a non-empty string
    made only of ASCII letters and digits. This rules out path separators, ``..``
    segments, whitespace and non-ASCII characters.

    Args:
        model_id: Model identifier, used only to make the error message informative.
        package_id: Model package identifier to validate.

    Raises:
        InsecureModelIdentifierError: If ``package_id`` is empty or contains any
            character outside ``[A-Za-z0-9]``.
    """
    # The empty string must be rejected too: os.path.join(..., "") would collapse the
    # package cache path onto the model-level directory shared by all packages.
    if not re.fullmatch(r"[A-Za-z0-9]+", package_id):
        raise InsecureModelIdentifierError(
            message=f"Attempted to load model: {model_id} using package ID: {package_id} which "
            f"has invalid format. ID is expected to be non-empty and contain only ASCII letters and digits to "
            f"ensure safety of local cache. If you see this error running your model on Roboflow platform, "
            f"raise the issue: https://github.com/roboflow/inference/issues. If you are running `inference` "
            f"outside of the platform, verify that your weights provider keeps the model packages identifiers "
            f"in the expected format.",
            help_url="https://TODO",
        )
|
|
1056
|
+
|
|
1057
|
+
|
|
1058
|
+
def slugify_model_id_to_os_safe_format(model_id: str) -> str:
    """Turn an arbitrary model ID into a short, filesystem-safe directory name.

    Every run of characters outside ``[A-Za-z0-9_-]`` becomes a single ``-``. Runs of
    two or more ``_``/``-`` are then collapsed into one ``-``. The result is truncated
    to 48 characters, falling back to ``special-char-only-model-id`` when empty.
    Finally, an 8-hex-digit blake2s digest of the *original* ID is appended, so IDs
    that slugify identically still map to different names.
    """
    ascii_only = re.sub(r"[^A-Za-z0-9_-]+", "-", model_id)
    slug = re.sub(r"[_-]{2,}", "-", ascii_only) or "special-char-only-model-id"
    slug = slug[:48]
    suffix = hashlib.blake2s(model_id.encode("utf-8"), digest_size=4).hexdigest()
    return f"{slug}-{suffix}"
|
|
1069
|
+
|
|
1070
|
+
|
|
1071
|
+
def attempt_loading_model_from_local_storage(
    model_dir_or_weights_path: str,
    allow_local_code_packages: bool,
    model_init_kwargs: dict,
    model_type: Optional[str] = None,
    task_type: Optional[str] = None,
    backend_type: Optional[
        Union[str, BackendType, List[Union[str, BackendType]]]
    ] = None,
) -> AnyModel:
    """Load a model from a path on the local file system.

    Dispatch:

    * A regular file is treated as a checkpoint. ``model_type`` / ``task_type`` /
      ``backend_type`` select the implementation.
    * Anything else is treated as a model package directory, and its
      ``MODEL_CONFIG_FILE_NAME`` config is parsed.
      - Library models are loaded with the library implementation.
      - Packages shipping their own code are loaded only when
        ``allow_local_code_packages`` is truthy.

    Raises:
        ModelLoadingError: If the package needs arbitrary code and
            ``allow_local_code_packages`` is falsy. Errors from the parsing and
            loading helpers propagate unchanged.
    """
    if os.path.isfile(model_dir_or_weights_path):
        return attempt_loading_model_from_checkpoint(
            checkpoint_path=model_dir_or_weights_path,
            model_init_kwargs=model_init_kwargs,
            model_type=model_type,
            task_type=task_type,
            backend_type=backend_type,
        )
    model_config = parse_model_config(
        config_path=os.path.join(model_dir_or_weights_path, MODEL_CONFIG_FILE_NAME)
    )
    if model_config.is_library_model():
        loader = load_library_model_from_local_dir
    elif allow_local_code_packages:
        loader = load_model_from_local_package_with_arbitrary_code
    else:
        raise ModelLoadingError(
            message=f"Attempted to load model from local package with arbitrary code. This is not allowed in "
            f"this environment. To let inference loading such models, use `allow_local_code_packages=True` "
            f"parameter of `AutoModel.from_pretrained(...)`. If you see this error while using one of Roboflow "
            f"hosted solution - contact us to solve the problem.",
            help_url="https://todo",
        )
    return loader(
        model_dir=model_dir_or_weights_path,
        model_config=model_config,
        model_init_kwargs=model_init_kwargs,
    )
|
|
1110
|
+
|
|
1111
|
+
|
|
1112
|
+
def attempt_loading_model_from_checkpoint(
    checkpoint_path: str,
    model_init_kwargs: dict,
    model_type: Optional[str] = None,
    task_type: Optional[str] = None,
    backend_type: Optional[
        Union[str, BackendType, List[Union[str, BackendType]]]
    ] = None,
) -> AnyModel:
    """Load a model directly from a single checkpoint file.

    ``resolve_models_registry_entry`` validates ``model_type`` (plus the optional
    ``task_type`` / ``backend_type``) and resolves them to an architecture, task and
    backend. The matching model class is then initialised with
    ``from_pretrained(checkpoint_path, **kwargs)``, where ``kwargs`` is
    ``model_init_kwargs`` extended with ``model_type``.

    Raises:
        ModelLoadingError: Raised by ``resolve_models_registry_entry`` when the
            requested model type, task or backend cannot be loaded from a checkpoint.
    """
    model_architecture, task_type, backend_type = resolve_models_registry_entry(
        model_type=model_type,
        task_type=task_type,
        backend_type=backend_type,
    )
    # Build a new mapping instead of writing into ``model_init_kwargs`` - the
    # caller's dict must not be modified as a side effect of loading a model.
    init_kwargs = {**model_init_kwargs, "model_type": model_type}
    model_class = resolve_model_class(
        model_architecture=model_architecture,
        task_type=task_type,
        backend=backend_type,
    )
    return model_class.from_pretrained(checkpoint_path, **init_kwargs)
|
|
1133
|
+
|
|
1134
|
+
|
|
1135
|
+
def resolve_models_registry_entry(
    model_type: Optional[str],
    task_type: Optional[str] = None,
    backend_type: Optional[
        Union[str, BackendType, List[Union[str, BackendType]]]
    ] = None,
) -> Tuple[str, str, BackendType]:
    """Resolve (architecture, task, backend) for a model loaded from a checkpoint file.

    Args:
        model_type: One of ``MODEL_TYPES_TO_LOAD_FROM_CHECKPOINT`` (e.g. ``rfdetr-nano``).
        task_type: Optional task.
            - Defaults to instance segmentation for ``rfdetr-seg-preview`` and to
              object detection otherwise.
            - Only those two tasks are accepted.
        backend_type: Optional backend.
            - Accepts a ``BackendType``, its string name, or a single-element list of
              either.
            - Defaults to ``BackendType.TORCH``, the only backend accepted.

    Returns:
        Tuple ``(model_architecture, task_type, backend_type)``. The architecture is
        currently always ``"rfdetr"``.

    Raises:
        ModelLoadingError: If any of the following holds:
            - ``model_type`` is missing or unsupported;
            - the task is not supported;
            - a list of backends does not contain exactly one element;
            - the backend is not TORCH.
    """
    # TODO: in the future this check will grow in size
    if not model_type:
        raise ModelLoadingError(
            message="When loading model directly from checkpoint path, `model_type` parameter must be specified. "
            "Use one of the supported value, for example `rfdetr-nano` in case you refer checkpoint of "
            "RFDetr Nano model.",
            help_url="https://todo",
        )
    if model_type not in MODEL_TYPES_TO_LOAD_FROM_CHECKPOINT:
        raise ModelLoadingError(
            message="When loading model directly from checkpoint path, `model_type` parameter must define "
            "one of the type of model that support loading directly from the checkpoints. "
            f"Models supported in current version: {MODEL_TYPES_TO_LOAD_FROM_CHECKPOINT}",
            help_url="https://todo",
        )
    # a bit of hard coding here, over time we must maintain
    model_architecture = "rfdetr"
    if task_type is None:
        if model_type == "rfdetr-seg-preview":
            task_type = INSTANCE_SEGMENTATION_TASK
        else:
            task_type = OBJECT_DETECTION_TASK
    if task_type not in {OBJECT_DETECTION_TASK, INSTANCE_SEGMENTATION_TASK}:
        raise ModelLoadingError(
            message=f"When loading model directly from checkpoint path, set `model_type` as {model_type} and "
            f"`task_type` as {task_type}, whereas selected model do only support `{OBJECT_DETECTION_TASK}` "
            f"and `{INSTANCE_SEGMENTATION_TASK}` tasks while loading from checkpoint file.",
            help_url="https://todo",
        )
    if backend_type is None:
        backend_type = BackendType.TORCH
    if isinstance(backend_type, list):
        # Exactly one backend may be requested when loading from a checkpoint. A
        # one-element list (e.g. ["torch"]) is unwrapped instead of being passed on
        # as a list, which would always fail the TORCH check below.
        if len(backend_type) != 1:
            raise ModelLoadingError(
                message=f"When loading model directly from checkpoint path, set `backend` parameter to be {backend_type}, "
                f"whereas it is only supported to pass a single value.",
                help_url="https://todo",
            )
        backend_type = backend_type[0]
    if isinstance(backend_type, str):
        backend_type = parse_backend_type(value=backend_type)
    if backend_type is not BackendType.TORCH:
        raise ModelLoadingError(
            message=f"When loading model directly from checkpoint path, selected the following backend {backend_type}, "
            f"but the backend supported for model {model_type} is {BackendType.TORCH}",
            help_url="https://todo",
        )
    return model_architecture, task_type, backend_type
|
|
1190
|
+
|
|
1191
|
+
|
|
1192
|
+
def parse_model_config(config_path: str) -> InferenceModelConfig:
    """Read and validate the model package config file at ``config_path``.

    The file is decoded with ``read_json`` and must hold a dictionary.
    - The optional ``backend_type`` entry is converted into ``BackendType``.
    - ``model_architecture``, ``task_type``, ``model_module`` and ``model_class`` are
      passed through as-is, or as ``None`` when absent.

    Raises:
        ModelLoadingError: If ``config_path`` is not an existing file.
        CorruptedModelPackageError: If any of the following holds:
            - the file cannot be decoded;
            - it does not hold a dictionary;
            - it declares a ``backend_type`` unknown to ``BackendType``.
    """
    if not os.path.isfile(config_path):
        raise ModelLoadingError(
            message=f"Could not find model config while attempting to load model from "
            f"local directory. This error may be caused by misconfiguration of model package (lack of config "
            f"file), as well as by clash between model_id or model alias and contents of local disc drive which "
            f"is possible when you have local directory in current dir which has the name colliding with the "
            f"model you attempt to load. If your intent was to load model from remote backend (not local "
            f"storage) - verify the contents of $PWD. If you see this problem while using one of Roboflow "
            f"hosted solutions - contact us to get help.",
            help_url="https://todo",
        )
    try:
        raw_config = read_json(path=config_path)
    except ValueError as error:
        raise CorruptedModelPackageError(
            message=f"Could not decode model config while attempting to load model from "
            f"local directory. This error may be caused by corrupted config file. Validate the content of your "
            f"model package and check in documentation the required format of model config file. "
            f"If you see this problem while using one of Roboflow hosted solutions - contact us to get help.",
            help_url="https://todo",
        ) from error
    if not isinstance(raw_config, dict):
        raise CorruptedModelPackageError(
            message=f"While loading the model from local directory encountered corrupted model config file - config is "
            f"supposed to be a dictionary, instead decoded object of type: "
            f"{type(raw_config)}. If you see this problem while using one of Roboflow hosted solutions - "
            f"contact us to get help. Otherwise - verify the content of your model config.",
            help_url="https://todo",
        )
    backend_type = None
    if "backend_type" in raw_config:
        raw_backend_type = raw_config["backend_type"]
        try:
            backend_type = BackendType(raw_backend_type)
        except ValueError as e:
            # The second line must be an f-string; without the prefix the offending
            # value would be reported literally as "{raw_backend_type}".
            raise CorruptedModelPackageError(
                message=f"While loading the model from local directory encountered corrupted model config "
                f"- declared `backend_type` ({raw_backend_type}) is not supported by inference. "
                f"Supported values: {list(t.value for t in BackendType)}. If you see this problem while using "
                f"one of Roboflow hosted solutions - contact us to get help. Otherwise - verify the content "
                f"of your model config.",
                help_url="https://todo",
            ) from e
    return InferenceModelConfig(
        model_architecture=raw_config.get("model_architecture"),
        task_type=raw_config.get("task_type"),
        backend_type=backend_type,
        model_module=raw_config.get("model_module"),
        model_class=raw_config.get("model_class"),
    )
|
|
1243
|
+
|
|
1244
|
+
|
|
1245
|
+
def load_library_model_from_local_dir(
    model_dir: str,
    model_config: InferenceModelConfig,
    model_init_kwargs: dict,
) -> AnyModel:
    """Instantiate a model whose implementation ships with the library.

    ``resolve_model_class`` picks the implementation from the architecture, task and
    backend declared in ``model_config``. That class is initialised from
    ``model_dir``, with ``model_init_kwargs`` forwarded to ``from_pretrained``.
    """
    implementation = resolve_model_class(
        model_architecture=model_config.model_architecture,
        task_type=model_config.task_type,
        backend=model_config.backend_type,
    )
    return implementation.from_pretrained(model_dir, **model_init_kwargs)
|
|
1256
|
+
|
|
1257
|
+
|
|
1258
|
+
def load_model_from_local_package_with_arbitrary_code(
    model_dir: str,
    model_config: InferenceModelConfig,
    model_init_kwargs: dict,
) -> AnyModel:
    """Load a model whose implementation is Python code shipped inside the package.

    - ``model_config.model_module`` is the path, joined onto ``model_dir``, of the
      Python file to import.
    - ``model_config.model_class`` names the class inside that file.
    - The class is initialised with ``from_pretrained(model_dir, **model_init_kwargs)``.

    This executes code from the package. In this file it is reached from
    ``attempt_loading_model_from_local_storage`` only when
    ``allow_local_code_packages`` is truthy.

    Raises:
        CorruptedModelPackageError: If either config field is missing or the module
            file does not exist. Errors from ``load_class_from_path`` propagate unchanged.
    """
    module_in_config = model_config.model_module
    class_in_config = model_config.model_class
    if module_in_config is None or class_in_config is None:
        raise CorruptedModelPackageError(
            message=f"While loading the model from local directory encountered corrupted model config file. "
            f"Config does not specify neither `model_module` name nor `model_class`, which are both "
            f"required to load models provided with arbitrary code. If you see this problem while using "
            f"one of Roboflow hosted solutions - contact us to get help. Otherwise - verify the content "
            f"of your model config.",
            help_url="https://todo",
        )
    # NOTE(review): os.path.join discards model_dir if model_module is absolute, and
    # ".." segments are not rejected, so the module may live outside the package.
    module_file = os.path.join(model_dir, module_in_config)
    if not os.path.isfile(module_file):
        raise CorruptedModelPackageError(
            message=f"While loading the model from local directory encountered corrupted model config file. "
            f"Config pointed module {model_config.model_module}, but there is no file under "
            f"{module_file}. If you see this problem while using "
            f"one of Roboflow hosted solutions - contact us to get help. Otherwise - verify the content "
            f"of your model config.",
            help_url="https://todo",
        )
    implementation = load_class_from_path(
        module_path=module_file, class_name=class_in_config
    )
    return implementation.from_pretrained(model_dir, **model_init_kwargs)
|
|
1286
|
+
|
|
1287
|
+
|
|
1288
|
+
def load_class_from_path(module_path: str, class_name: str) -> AnyModel:
    """Import the Python file at ``module_path`` and return its ``class_name`` attribute.

    - The module is named after the file's basename without extension.
    - It is built with ``importlib.util.spec_from_file_location`` and executed with
      its loader's ``exec_module``. This runs arbitrary code from that file.
    - The value returned is the attribute itself (expected to be the model class,
      not an instance).

    NOTE(review): the module is not inserted into ``sys.modules`` before execution.
    Code relying on that lookup (e.g. some dataclass / pickling machinery) may
    misbehave; confirm whether this matters for supported packages.

    Raises:
        CorruptedModelPackageError: In any of these cases:
            - the path does not exist;
            - no spec or usable loader can be built;
            - executing the module raises (chained to the original error);
            - the module has no ``class_name`` attribute.
    """
    if not os.path.exists(module_path):
        raise CorruptedModelPackageError(
            message=f"When loading local model with arbitrary code, encountered issue with loading the module. "
            f"Could not find the module under the path specified in model config: {module_path}. If you see this problem "
            f"while using one of Roboflow hosted solutions - contact us to get help. Otherwise - verify your "
            f"model package checking if you can load the module with model implementation within your "
            f"python environment.",
            help_url="https://todo",
        )
    module_name = os.path.splitext(os.path.basename(module_path))[0]
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None:
        raise CorruptedModelPackageError(
            message=f"When loading local model with arbitrary code, encountered issue with loading the module. "
            "Could not build module specification. If you see this problem while using "
            f"one of Roboflow hosted solutions - contact us to get help. Otherwise - verify your "
            f"model package checking if you can load the module with model implementation within your "
            f"python environment.",
            help_url="https://todo",
        )
    module = importlib.util.module_from_spec(spec)
    loader = spec.loader
    if loader is None or not hasattr(loader, "exec_module"):
        raise CorruptedModelPackageError(
            message=f"When loading local model with arbitrary code, encountered issue with loading the module. "
            "Could not execute module loader. If you see this problem while using "
            f"one of Roboflow hosted solutions - contact us to get help. Otherwise - verify your "
            f"model package checking if you can load the module with model implementation within your "
            f"python environment.",
            help_url="https://todo",
        )
    try:
        loader.exec_module(module)
    except Exception as error:
        # Chain explicitly so the traceback of the failing model code is preserved
        # as the direct cause.
        raise CorruptedModelPackageError(
            message=f"When loading local model with arbitrary code, encountered issue executing the module code "
            f"to retrieve model class. Details of the error: {error}. If you see this problem while using "
            f"one of Roboflow hosted solutions - contact us to get help. Otherwise - verify your "
            f"model package checking if you can load the module with model implementation within your "
            f"python environment.",
            help_url="https://todo",
        ) from error
    if not hasattr(module, class_name):
        raise CorruptedModelPackageError(
            message=f"When loading local model with arbitrary code, encountered issue with loading the module. "
            f"Module `{module_name}` has no class `{class_name}`. If you see this problem while using "
            f"one of Roboflow hosted solutions - contact us to get help. Otherwise - verify your "
            f"model package checking if you can load the module with model implementation within your "
            f"python environment. It may also be the case that configuration file of the model points "
            f"to invalid class name.",
            help_url="https://todo",
        )
    return getattr(module, class_name)
|