inference-models 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1341 @@
1
+ import hashlib
2
+ import importlib
3
+ import importlib.util
4
+ import os.path
5
+ import re
6
+ from datetime import datetime
7
+ from functools import partial
8
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
9
+
10
+ import torch
11
+ from argon2 import PasswordHasher
12
+ from filelock import FileLock
13
+ from rich.console import Console
14
+ from rich.text import Text
15
+
16
+ from inference_models.configuration import DEFAULT_DEVICE, INFERENCE_HOME
17
+ from inference_models.errors import (
18
+ CorruptedModelPackageError,
19
+ DirectLocalStorageAccessError,
20
+ InsecureModelIdentifierError,
21
+ ModelLoadingError,
22
+ NoModelPackagesAvailableError,
23
+ UnauthorizedModelAccessError,
24
+ )
25
+ from inference_models.logger import LOGGER, verbose_info
26
+ from inference_models.models.auto_loaders.access_manager import (
27
+ AccessIdentifiers,
28
+ LiberalModelAccessManager,
29
+ ModelAccessManager,
30
+ )
31
+ from inference_models.models.auto_loaders.auto_negotiation import (
32
+ negotiate_model_packages,
33
+ parse_backend_type,
34
+ )
35
+ from inference_models.models.auto_loaders.auto_resolution_cache import (
36
+ AutoResolutionCache,
37
+ AutoResolutionCacheEntry,
38
+ BaseAutoLoadMetadataCache,
39
+ )
40
+ from inference_models.models.auto_loaders.constants import (
41
+ MODEL_DEPENDENCIES_KEY,
42
+ MODEL_DEPENDENCIES_SUB_DIR,
43
+ )
44
+ from inference_models.models.auto_loaders.dependency_models import (
45
+ prepare_dependency_model_parameters,
46
+ )
47
+ from inference_models.models.auto_loaders.entities import (
48
+ MODEL_CONFIG_FILE_NAME,
49
+ AnyModel,
50
+ BackendType,
51
+ InferenceModelConfig,
52
+ ModelArchitecture,
53
+ TaskType,
54
+ )
55
+ from inference_models.models.auto_loaders.models_registry import (
56
+ INSTANCE_SEGMENTATION_TASK,
57
+ OBJECT_DETECTION_TASK,
58
+ resolve_model_class,
59
+ )
60
+ from inference_models.models.auto_loaders.presentation_utils import (
61
+ calculate_artefacts_size,
62
+ calculate_size_of_all_model_packages_artefacts,
63
+ render_model_package_details_table,
64
+ render_runtime_x_ray,
65
+ render_table_with_model_overview,
66
+ render_table_with_model_packages,
67
+ )
68
+ from inference_models.runtime_introspection.core import x_ray_runtime_environment
69
+ from inference_models.utils.download import FileHandle, download_files_to_directory
70
+ from inference_models.utils.file_system import dump_json, read_json
71
+ from inference_models.utils.hashing import hash_dict_content
72
+ from inference_models.weights_providers.core import get_model_from_provider
73
+ from inference_models.weights_providers.entities import (
74
+ ModelDependency,
75
+ ModelPackageMetadata,
76
+ Quantization,
77
+ )
78
+
# Model types (all of them RF-DETR variants) that are meant to be loaded from a
# checkpoint. Kept as a set so membership tests are cheap; element order is irrelevant.
MODEL_TYPES_TO_LOAD_FROM_CHECKPOINT = {
    "rfdetr-base",
    "rfdetr-large",
    "rfdetr-medium",
    "rfdetr-nano",
    "rfdetr-seg-preview",
    "rfdetr-small",
}

# Keyword-argument names that AutoModel.from_pretrained(...) forwards to dependency
# models when the caller does not supply `forwarded_kwargs`. Only names present in the
# caller's **kwargs are forwarded, and a forwarded value never overrides one that is
# already set in the dependency's own parameters.
DEFAULT_KWARGS_PARAMS_TO_BE_FORWARDED_TO_DEPENDENT_MODELS = [
    "owlv2_enforce_model_compilation",
    "owlv2_class_embeddings_cache",
    "owlv2_images_embeddings_cache",
]
93
+
94
+
95
+ class AutoModel:
96
+
@classmethod
def describe_model(
    cls,
    model_id: str,
    weights_provider: str = "roboflow",
    api_key: Optional[str] = None,
    pull_artefacts_size: bool = False,
) -> None:
    """Print an overview of a model and its registered packages to the console.

    Model metadata is fetched from ``weights_provider``. Two ``rich`` tables are
    printed, the model overview and the list of model packages, followed by usage hints.

    Args:
        model_id: ID (or alias) of the model to describe.
        weights_provider: Name of the weights provider to query.
        api_key: Optional API key handed to the weights provider.
        pull_artefacts_size: When True, the artefact sizes of all packages are
            collected and shown in the packages table, which makes the call slower.
    """
    metadata = get_model_from_provider(
        provider=weights_provider,
        model_id=model_id,
        api_key=api_key,
    )
    packages = metadata.model_packages
    sizes = (
        calculate_size_of_all_model_packages_artefacts(model_packages=packages)
        if pull_artefacts_size
        else None
    )
    console = Console()
    console.print(
        render_table_with_model_overview(
            model_id=metadata.model_id,
            requested_model_id=model_id,
            model_architecture=metadata.model_architecture,
            model_variant=metadata.model_variant,
            task_type=metadata.task_type,
            weights_provider=weights_provider,
            registered_packages=len(packages),
            model_dependencies=metadata.model_dependencies,
        )
    )
    console.print("\n")
    console.print(
        render_table_with_model_packages(
            model_packages=packages,
            model_packages_size=sizes,
        )
    )
    console.print(
        Text.assemble(
            ("\nWant to check more details about specific package?", "bold"),
            "\nUse AutoModel.describe_model_package('model_id', 'package_id').",
        )
    )
    if pull_artefacts_size:
        # Sizes were already shown, so there is no need to advertise the flag.
        return
    console.print(
        Text.assemble(
            ("\nWant to verify the size of model package?", "bold"),
            "\nUse AutoModel.describe_model('model_id', pull_artefacts_size=True) - the execution will be "
            "slightly longer, as we must collect the size of all elements of model package.",
        )
    )
145
+
@classmethod
def describe_model_package(
    cls,
    model_id: str,
    package_id: str,
    weights_provider: str = "roboflow",
    api_key: Optional[str] = None,
    pull_artefacts_size: bool = True,
) -> None:
    """Print a details table for a single package of a model.

    Args:
        model_id: ID (or alias) of the model that owns the package.
        package_id: ID of the package to describe.
        weights_provider: Name of the weights provider to query.
        api_key: Optional API key handed to the weights provider.
        pull_artefacts_size: When True (default), the size of the package artefacts
            is collected and included in the table.

    Raises:
        NoModelPackagesAvailableError: If the model has no package with ``package_id``.
    """
    model_metadata = get_model_from_provider(
        provider=weights_provider,
        model_id=model_id,
        api_key=api_key,
    )
    # Take the first package with the requested ID; there is no need to keep
    # scanning the remaining packages once it is found.
    selected_package = next(
        (
            package
            for package in model_metadata.model_packages
            if package.package_id == package_id
        ),
        None,
    )
    if selected_package is None:
        raise NoModelPackagesAvailableError(
            message=f"Selected model package {package_id} does not exist for model {model_id}. Make sure provided "
            f"value is valid.",
            help_url="https://todo",
        )
    artefacts_size = None
    if pull_artefacts_size:
        artefacts_size = calculate_artefacts_size(
            package_artefacts=selected_package.package_artefacts
        )
    table = render_model_package_details_table(
        model_id=model_metadata.model_id,
        requested_model_id=model_id,
        artefacts_size=artefacts_size,
        model_package=selected_package,
    )
    console = Console()
    console.print(table)
    if not pull_artefacts_size:
        # The trailing space after "True)" keeps the two literals from fusing into "True)- the".
        text = Text.assemble(
            ("\nWant to verify the size of model package?", "bold"),
            "\nUse AutoModel.describe_model_package('model_id', 'package_id', pull_artefacts_size=True) "
            "- the execution will be slightly longer, as we must collect the size of all elements of model package.",
        )
        console.print(text)
190
+
@classmethod
def describe_compute_environment(cls) -> None:
    """Print a table describing the current runtime environment.

    The environment is inspected with ``x_ray_runtime_environment()``, and the
    result is rendered with ``render_runtime_x_ray`` before being printed via ``rich``.
    """
    x_ray_table = render_runtime_x_ray(runtime_x_ray=x_ray_runtime_environment())
    Console().print(x_ray_table)
197
+
198
+ @classmethod
199
+ def from_pretrained(
200
+ cls,
201
+ model_id_or_path: str,
202
+ weights_provider: str = "roboflow",
203
+ api_key: Optional[str] = None,
204
+ model_package_id: Optional[str] = None,
205
+ backend: Optional[
206
+ Union[str, BackendType, List[Union[str, BackendType]]]
207
+ ] = None,
208
+ batch_size: Optional[Union[int, Tuple[int, int]]] = None,
209
+ quantization: Optional[
210
+ Union[str, Quantization, List[Union[str, Quantization]]]
211
+ ] = None,
212
+ onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
213
+ device: torch.device = DEFAULT_DEVICE,
214
+ default_onnx_trt_options: bool = True,
215
+ max_package_loading_attempts: Optional[int] = None,
216
+ verbose: bool = False,
217
+ model_download_file_lock_acquire_timeout: int = 10,
218
+ allow_untrusted_packages: bool = False,
219
+ trt_engine_host_code_allowed: bool = True,
220
+ allow_local_code_packages: bool = True,
221
+ verify_hash_while_download: bool = True,
222
+ download_files_without_hash: bool = False,
223
+ use_auto_resolution_cache: bool = True,
224
+ auto_resolution_cache: Optional[AutoResolutionCache] = None,
225
+ allow_direct_local_storage_loading: bool = True,
226
+ model_access_manager: Optional[ModelAccessManager] = None,
227
+ nms_fusion_preferences: Optional[Union[bool, dict]] = None,
228
+ model_type: Optional[str] = None,
229
+ task_type: Optional[str] = None,
230
+ allow_loading_dependency_models: bool = True,
231
+ dependency_models_params: Optional[dict] = None,
232
+ point_model_directory: Optional[Callable[[str], None]] = None,
233
+ forwarded_kwargs: Optional[List[str]] = None,
234
+ **kwargs,
235
+ ) -> AnyModel:
236
+ if model_access_manager is None:
237
+ model_access_manager = LiberalModelAccessManager()
238
+ if model_access_manager.is_model_access_forbidden(
239
+ model_id=model_id_or_path, api_key=api_key
240
+ ):
241
+ raise UnauthorizedModelAccessError(
242
+ message=f"Unauthorized not access model with ID: {model_package_id}. Are you sure you use valid "
243
+ f"API key? The default weights provider is Roboflow - see Roboflow authentication details: "
244
+ f"https://docs.roboflow.com/api-reference/authentication "
245
+ f"and export key to `ROBOFLOW_API_KEY` environment variable. If you use custom weights "
246
+ f"provider - verify access constraints relevant for the provider.",
247
+ help_url="https://todo",
248
+ )
249
+ if auto_resolution_cache is None:
250
+
251
+ def register_file_created_for_model_package(
252
+ file_path: str, model_id: str, package_id: str
253
+ ) -> None:
254
+ access_identifiers = AccessIdentifiers(
255
+ model_id=model_id,
256
+ package_id=package_id,
257
+ api_key=api_key,
258
+ )
259
+ model_access_manager.on_file_created(
260
+ file_path=file_path,
261
+ access_identifiers=access_identifiers,
262
+ )
263
+
264
+ auto_resolution_cache = BaseAutoLoadMetadataCache(
265
+ file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
266
+ verbose=verbose,
267
+ on_file_created=register_file_created_for_model_package,
268
+ on_file_deleted=model_access_manager.on_file_deleted,
269
+ )
270
+ model_init_kwargs = {
271
+ "onnx_execution_providers": onnx_execution_providers,
272
+ "device": device,
273
+ "default_onnx_trt_options": default_onnx_trt_options,
274
+ "engine_host_code_allowed": trt_engine_host_code_allowed,
275
+ }
276
+ model_init_kwargs.update(kwargs)
277
+ if not os.path.exists(model_id_or_path):
278
+ # QUESTION: is it enough to assume presence of local dir as the intent to load
279
+ # model from disc drive? What if we have clash of model id / model alias with
280
+ # contents of someone's local drive - shall we then try to load from both sources?
281
+ # that still may end up with ambiguous behaviour - probably the solution would be
282
+ # to require prefix like file://... to denote the intent of loading model from local
283
+ # drive?
284
+ if api_key is not None:
285
+ password_hashed = PasswordHasher()
286
+ api_key_hash = password_hashed.hash(api_key)
287
+ else:
288
+ api_key_hash = api_key
289
+ auto_negotiation_hash = hash_dict_content(
290
+ content={
291
+ "provider": weights_provider,
292
+ "model_id": model_id_or_path,
293
+ "api_key": api_key_hash,
294
+ "requested_model_package_id": model_package_id,
295
+ "requested_backends": backend,
296
+ "requested_batch_size": batch_size,
297
+ "requested_quantization": quantization,
298
+ "device": str(device),
299
+ "onnx_execution_providers": onnx_execution_providers,
300
+ "allow_untrusted_packages": allow_untrusted_packages,
301
+ "trt_engine_host_code_allowed": trt_engine_host_code_allowed,
302
+ "nms_fusion_preferences": nms_fusion_preferences,
303
+ }
304
+ )
305
+ model_from_access_manager = model_access_manager.retrieve_model_instance(
306
+ model_id=model_id_or_path,
307
+ package_id=model_package_id,
308
+ api_key=api_key,
309
+ loading_parameter_digest=auto_negotiation_hash,
310
+ )
311
+ if model_from_access_manager:
312
+ return model_from_access_manager
313
+ if forwarded_kwargs is None:
314
+ forwarded_kwargs = (
315
+ DEFAULT_KWARGS_PARAMS_TO_BE_FORWARDED_TO_DEPENDENT_MODELS
316
+ )
317
+ forwarded_kwargs_values = {
318
+ name: kwargs[name] for name in forwarded_kwargs if name in kwargs
319
+ }
320
+ model_from_cache = attempt_loading_model_with_auto_load_cache(
321
+ use_auto_resolution_cache=use_auto_resolution_cache,
322
+ auto_resolution_cache=auto_resolution_cache,
323
+ auto_negotiation_hash=auto_negotiation_hash,
324
+ model_access_manager=model_access_manager,
325
+ model_name_or_path=model_id_or_path,
326
+ model_init_kwargs=model_init_kwargs,
327
+ api_key=api_key,
328
+ allow_loading_dependency_models=allow_loading_dependency_models,
329
+ forwarded_kwargs_values=forwarded_kwargs_values,
330
+ verbose=verbose,
331
+ weights_provider=weights_provider,
332
+ max_package_loading_attempts=max_package_loading_attempts,
333
+ model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
334
+ allow_untrusted_packages=allow_untrusted_packages,
335
+ trt_engine_host_code_allowed=trt_engine_host_code_allowed,
336
+ allow_local_code_packages=allow_local_code_packages,
337
+ verify_hash_while_download=verify_hash_while_download,
338
+ download_files_without_hash=download_files_without_hash,
339
+ allow_direct_local_storage_loading=allow_direct_local_storage_loading,
340
+ dependency_models_params=dependency_models_params,
341
+ )
342
+ if model_from_cache:
343
+ return model_from_cache
344
+ try:
345
+ model_metadata = get_model_from_provider(
346
+ provider=weights_provider,
347
+ model_id=model_id_or_path,
348
+ api_key=api_key,
349
+ )
350
+ if (
351
+ model_metadata.model_dependencies
352
+ and not allow_loading_dependency_models
353
+ ):
354
+ raise CorruptedModelPackageError(
355
+ message=f"Could not load model {model_id_or_path} as it defines another models which are "
356
+ f"it's dependency, but the auto-loader prevents loading dependencies at certain "
357
+ f"nesting depth to avoid excessive resolution procedure. This is a limitation of "
358
+ f"current implementation. Provide us the context of your use-case to get help.",
359
+ help_url="https://todo",
360
+ )
361
+ if model_metadata.model_id != model_id_or_path:
362
+ model_access_manager.on_model_alias_discovered(
363
+ alias=model_id_or_path,
364
+ model_id=model_metadata.model_id,
365
+ )
366
+ model_dependencies = model_metadata.model_dependencies or []
367
+ for model_dependency in model_dependencies:
368
+ model_access_manager.on_model_dependency_discovered(
369
+ base_model_id=model_dependency.model_id,
370
+ base_model_package_id=model_dependency.model_package_id,
371
+ dependent_model_id=model_metadata.model_id,
372
+ )
373
+ for model_package in model_metadata.model_packages:
374
+ package_access_identifiers = AccessIdentifiers(
375
+ model_id=model_metadata.model_id,
376
+ package_id=model_package.package_id,
377
+ api_key=api_key,
378
+ )
379
+ model_access_manager.on_model_package_access_granted(
380
+ package_access_identifiers
381
+ )
382
+ except UnauthorizedModelAccessError as error:
383
+ model_access_manager.on_model_access_forbidden(
384
+ model_id=model_id_or_path, api_key=api_key
385
+ )
386
+ raise error
387
+ # here we verify if de-aliasing or access confirmation from auth master changed something
388
+ model_from_access_manager = model_access_manager.retrieve_model_instance(
389
+ model_id=model_id_or_path,
390
+ package_id=model_package_id,
391
+ api_key=api_key,
392
+ loading_parameter_digest=auto_negotiation_hash,
393
+ )
394
+ if model_from_access_manager:
395
+ return model_from_access_manager
396
+ matching_model_packages = negotiate_model_packages(
397
+ model_architecture=model_metadata.model_architecture,
398
+ task_type=model_metadata.task_type,
399
+ model_packages=model_metadata.model_packages,
400
+ requested_model_package_id=model_package_id,
401
+ requested_backends=backend,
402
+ requested_batch_size=batch_size,
403
+ requested_quantization=quantization,
404
+ device=device,
405
+ onnx_execution_providers=onnx_execution_providers,
406
+ allow_untrusted_packages=allow_untrusted_packages,
407
+ trt_engine_host_code_allowed=trt_engine_host_code_allowed,
408
+ nms_fusion_preferences=nms_fusion_preferences,
409
+ verbose=verbose,
410
+ )
411
+ model_dependencies_instances = {}
412
+ model_dependencies_directories = {}
413
+ dependency_models_params = dependency_models_params or {}
414
+ for model_dependency in model_dependencies:
415
+ dependency_params = dependency_models_params.get(
416
+ model_dependency.name, {}
417
+ )
418
+ dependency_params["model_id_or_path"] = model_dependency.model_id
419
+ dependency_params["model_package_id"] = (
420
+ model_dependency.model_package_id
421
+ )
422
+ resolved_model_parameters = prepare_dependency_model_parameters(
423
+ model_parameters=dependency_params
424
+ )
425
+ verbose_info(
426
+ message=f"Initialising dependent model: {model_dependency.model_id}",
427
+ verbose_requested=verbose,
428
+ )
429
+
430
+ def model_directory_pointer(model_dir: str) -> None:
431
+ model_dependencies_directories[model_dependency.name] = model_dir
432
+
433
+ for name, value in forwarded_kwargs_values.items():
434
+ if name not in resolved_model_parameters.model_extra:
435
+ resolved_model_parameters.model_extra[name] = value
436
+
437
+ dependency_instance = AutoModel.from_pretrained(
438
+ model_id_or_path=resolved_model_parameters.model_id_or_path,
439
+ weights_provider=weights_provider,
440
+ api_key=api_key,
441
+ model_package_id=resolved_model_parameters.model_package_id,
442
+ backend=resolved_model_parameters.backend,
443
+ batch_size=resolved_model_parameters.batch_size,
444
+ quantization=resolved_model_parameters.quantization,
445
+ onnx_execution_providers=resolved_model_parameters.onnx_execution_providers,
446
+ device=resolved_model_parameters.device,
447
+ default_onnx_trt_options=resolved_model_parameters.default_onnx_trt_options,
448
+ max_package_loading_attempts=max_package_loading_attempts,
449
+ verbose=verbose,
450
+ model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
451
+ allow_untrusted_packages=allow_untrusted_packages,
452
+ trt_engine_host_code_allowed=trt_engine_host_code_allowed,
453
+ allow_local_code_packages=allow_local_code_packages,
454
+ verify_hash_while_download=verify_hash_while_download,
455
+ download_files_without_hash=download_files_without_hash,
456
+ use_auto_resolution_cache=use_auto_resolution_cache,
457
+ auto_resolution_cache=auto_resolution_cache,
458
+ allow_direct_local_storage_loading=allow_direct_local_storage_loading,
459
+ model_access_manager=model_access_manager,
460
+ nms_fusion_preferences=resolved_model_parameters.nms_fusion_preferences,
461
+ model_type=resolved_model_parameters.model_type,
462
+ task_type=resolved_model_parameters.task_type,
463
+ allow_loading_dependency_models=False,
464
+ dependency_models_params=None,
465
+ point_model_directory=model_directory_pointer,
466
+ **resolved_model_parameters.kwargs,
467
+ )
468
+ model_dependencies_instances[model_dependency.name] = (
469
+ dependency_instance
470
+ )
471
+
472
+ return attempt_loading_matching_model_packages(
473
+ model_id=model_id_or_path,
474
+ model_architecture=model_metadata.model_architecture,
475
+ task_type=model_metadata.task_type,
476
+ matching_model_packages=matching_model_packages,
477
+ model_init_kwargs=model_init_kwargs,
478
+ model_access_manager=model_access_manager,
479
+ auto_negotiation_hash=auto_negotiation_hash,
480
+ api_key=api_key,
481
+ model_dependencies=model_metadata.model_dependencies,
482
+ model_dependencies_instances=model_dependencies_instances,
483
+ model_dependencies_directories=model_dependencies_directories,
484
+ max_package_loading_attempts=max_package_loading_attempts,
485
+ model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
486
+ verify_hash_while_download=verify_hash_while_download,
487
+ download_files_without_hash=download_files_without_hash,
488
+ auto_resolution_cache=auto_resolution_cache,
489
+ use_auto_resolution_cache=use_auto_resolution_cache,
490
+ point_model_directory=point_model_directory,
491
+ verbose=verbose,
492
+ )
493
+ if not allow_direct_local_storage_loading:
494
+ raise DirectLocalStorageAccessError(
495
+ message="Attempted to load model directly pointing local path, rather than model ID. This "
496
+ "operation is forbidden as AutoModel.from_pretrained(...) was used with "
497
+ "`allow_direct_local_storage_loading=False`. If you are running `inference-models` outside Roboflow "
498
+ "hosted solutions - verify your setup. If you see this error on Roboflow platform - this "
499
+ "feature was disabled for security reason. In rare cases when you use valid model ID, the "
500
+ "clash of ID with local path may cause this error - we ask you to report the issue here: "
501
+ "https://github.com/roboflow/inference/issues.",
502
+ help_url="https://todo",
503
+ )
504
+ return attempt_loading_model_from_local_storage(
505
+ model_dir_or_weights_path=model_id_or_path,
506
+ allow_local_code_packages=allow_local_code_packages,
507
+ model_init_kwargs=model_init_kwargs,
508
+ model_type=model_type,
509
+ task_type=task_type,
510
+ backend_type=backend,
511
+ )
512
+
513
+
514
def attempt_loading_model_with_auto_load_cache(
    use_auto_resolution_cache: bool,
    auto_resolution_cache: AutoResolutionCache,
    auto_negotiation_hash: str,
    model_access_manager: ModelAccessManager,
    model_name_or_path: str,
    model_init_kwargs: dict,
    api_key: Optional[str],
    allow_loading_dependency_models: bool,
    forwarded_kwargs_values: Dict[str, Any],
    verbose: bool = False,
    weights_provider: str = "roboflow",
    max_package_loading_attempts: Optional[int] = None,
    model_download_file_lock_acquire_timeout: int = 10,
    allow_untrusted_packages: bool = False,
    trt_engine_host_code_allowed: bool = True,
    allow_local_code_packages: bool = True,
    verify_hash_while_download: bool = True,
    download_files_without_hash: bool = False,
    allow_direct_local_storage_loading: bool = True,
    dependency_models_params: Optional[dict] = None,
) -> Optional[AnyModel]:
    """Try to load a model directly from an auto-resolution cache entry.

    The entry stored under ``auto_negotiation_hash`` names a concrete model package
    (and its dependencies). Reusing it skips the model package negotiation.

    Args:
        use_auto_resolution_cache: When ``False`` the cache is not consulted at all.
        auto_resolution_cache: Cache queried with ``auto_negotiation_hash``. The entry
            is invalidated when loading from it raises.
        auto_negotiation_hash: Key of the cache entry.
        model_access_manager: Decides whether ``api_key`` may use the cached package.
        model_name_or_path: Model identifier, used only in log messages.
        model_init_kwargs: Keyword arguments for the model class ``from_pretrained``.
            NOTE: this dict is mutated - dependency instances are stored in it under
            ``MODEL_DEPENDENCIES_KEY``.
        api_key: Forwarded to access checks and to dependency loading.
        allow_loading_dependency_models: When ``False`` and the cached entry declares
            dependencies, loading from cache fails (and the entry is invalidated).
        forwarded_kwargs_values: Extra keyword arguments applied to every dependency
            model unless its own parameters already define them.
        dependency_models_params: Optional per-dependency parameters keyed by dependency
            name. Neither this mapping nor its values are modified.
        Remaining parameters are forwarded to ``AutoModel.from_pretrained(...)`` when
        loading dependency models.

    Returns:
        The loaded model, or ``None`` in any of these cases:
        - the cache is disabled or holds no entry;
        - access to the cached package is not granted;
        - a file recorded in the entry is missing;
        - loading from the entry raised (the error is logged, not propagated).
    """
    if not use_auto_resolution_cache:
        return None
    verbose_info(
        message=f"Attempt to load model {model_name_or_path} using auto-load cache.",
        verbose_requested=verbose,
    )
    cache_entry = auto_resolution_cache.retrieve(
        auto_negotiation_hash=auto_negotiation_hash
    )
    if cache_entry is None:
        verbose_info(
            message=f"Could not find auto-load cache for model {model_name_or_path}.",
            verbose_requested=verbose,
        )
        return None
    if not model_access_manager.is_model_package_access_granted(
        model_id=cache_entry.model_id,
        package_id=cache_entry.model_package_id,
        api_key=api_key,
    ):
        return None
    if not all_files_exist(files=cache_entry.resolved_files):
        verbose_info(
            message=f"Could not find all required files denoted in auto-load cache for model {model_name_or_path}.",
            verbose_requested=verbose,
        )
        return None
    try:
        model_dependencies = cache_entry.model_dependencies or []
        if model_dependencies and not allow_loading_dependency_models:
            raise CorruptedModelPackageError(
                message=f"Could not load model {cache_entry.model_id} as it defines another models which are "
                f"it's dependency, but the auto-loader prevents loading dependencies at certain "
                f"nesting depth to avoid excessive resolution procedure. This is a limitation of "
                f"current implementation. Provide us the context of your use-case to get help.",
                help_url="https://todo",
            )
        model_dependencies_instances = {}
        dependency_models_params = dependency_models_params or {}
        for model_dependency in model_dependencies:
            # Work on a copy: the per-dependency params belong to the caller, who may
            # reuse them (e.g. for a fallback resolution), so they must not be mutated.
            dependency_params = dict(
                dependency_models_params.get(model_dependency.name) or {}
            )
            # The cache entry pins the exact dependency model and package to use.
            dependency_params["model_id_or_path"] = model_dependency.model_id
            dependency_params["model_package_id"] = model_dependency.model_package_id
            resolved_model_parameters = prepare_dependency_model_parameters(
                model_parameters=dependency_params
            )

            # Forwarded kwargs only fill gaps - explicit per-dependency values win.
            for name, value in forwarded_kwargs_values.items():
                if name not in resolved_model_parameters.model_extra:
                    resolved_model_parameters.model_extra[name] = value
            verbose_info(
                message=f"Initialising dependent model: {model_dependency.model_id}",
                verbose_requested=verbose,
            )
            dependency_instance = AutoModel.from_pretrained(
                model_id_or_path=resolved_model_parameters.model_id_or_path,
                weights_provider=weights_provider,
                api_key=api_key,
                model_package_id=resolved_model_parameters.model_package_id,
                backend=resolved_model_parameters.backend,
                batch_size=resolved_model_parameters.batch_size,
                quantization=resolved_model_parameters.quantization,
                onnx_execution_providers=resolved_model_parameters.onnx_execution_providers,
                device=resolved_model_parameters.device,
                default_onnx_trt_options=resolved_model_parameters.default_onnx_trt_options,
                max_package_loading_attempts=max_package_loading_attempts,
                verbose=verbose,
                model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
                allow_untrusted_packages=allow_untrusted_packages,
                trt_engine_host_code_allowed=trt_engine_host_code_allowed,
                allow_local_code_packages=allow_local_code_packages,
                verify_hash_while_download=verify_hash_while_download,
                download_files_without_hash=download_files_without_hash,
                use_auto_resolution_cache=use_auto_resolution_cache,
                auto_resolution_cache=auto_resolution_cache,
                allow_direct_local_storage_loading=allow_direct_local_storage_loading,
                model_access_manager=model_access_manager,
                nms_fusion_preferences=resolved_model_parameters.nms_fusion_preferences,
                model_type=resolved_model_parameters.model_type,
                task_type=resolved_model_parameters.task_type,
                # dependencies of dependencies are not supported
                allow_loading_dependency_models=False,
                dependency_models_params=None,
                **resolved_model_parameters.kwargs,
            )
            model_dependencies_instances[model_dependency.name] = dependency_instance
        model_class = resolve_model_class(
            model_architecture=cache_entry.model_architecture,
            task_type=cache_entry.task_type,
            backend=cache_entry.backend_type,
        )
        model_package_cache_dir = generate_model_package_cache_path(
            model_id=cache_entry.model_id,
            package_id=cache_entry.model_package_id,
        )
        model_init_kwargs[MODEL_DEPENDENCIES_KEY] = model_dependencies_instances
        model = model_class.from_pretrained(
            model_package_cache_dir, **model_init_kwargs
        )
        verbose_info(
            message=f"Successfully loaded model {model_name_or_path} using auto-loading cache.",
            verbose_requested=verbose,
        )
        return model
    except Exception as error:
        # Best-effort path: any failure drops the (possibly stale) entry and lets the
        # caller fall back to full resolution.
        LOGGER.warning(
            f"Encountered error {error} of type {type(error)} when attempted to load model using "
            f"auto-load cache. This may indicate corrupted cache of inference bug. Contact Roboflow submitting "
            f"issue under: https://github.com/roboflow/inference/issues/"
        )
        auto_resolution_cache.invalidate(auto_negotiation_hash=auto_negotiation_hash)
        return None
647
+
648
+
649
def all_files_exist(files: List[str]) -> bool:
    """Tell whether every path in ``files`` exists on disk.

    Stops at the first missing path. An empty collection counts as "all present".
    """
    for path in files:
        if not os.path.exists(path):
            return False
    return True
651
+
652
+
653
def attempt_loading_matching_model_packages(
    model_id: str,
    model_architecture: ModelArchitecture,
    task_type: Optional[TaskType],
    matching_model_packages: List[ModelPackageMetadata],
    model_init_kwargs: dict,
    model_access_manager: ModelAccessManager,
    auto_resolution_cache: AutoResolutionCache,
    auto_negotiation_hash: str,
    api_key: Optional[str],
    model_dependencies: Optional[List[ModelDependency]],
    model_dependencies_instances: Dict[str, AnyModel],
    model_dependencies_directories: Dict[str, str],
    max_package_loading_attempts: Optional[int] = None,
    model_download_file_lock_acquire_timeout: int = 10,
    verbose: bool = True,
    verify_hash_while_download: bool = True,
    download_files_without_hash: bool = False,
    use_auto_resolution_cache: bool = True,
    point_model_directory: Optional[Callable[[str], None]] = None,
) -> AnyModel:
    """Load the first candidate model package, in the given order, that initialises.

    Each candidate is downloaded and instantiated with ``initialize_model``. File
    system events are reported to ``model_access_manager``, tagged with that
    candidate's access identifiers. A failing candidate is logged and the next one
    is tried.

    Args:
        matching_model_packages: Candidates in priority order.
        max_package_loading_attempts: When set, only this many leading candidates
            are tried.
        point_model_directory: Optional callback that receives the cache directory
            of the package that was loaded.
        Remaining arguments are passed through to ``initialize_model`` or used for
        access bookkeeping.

    Returns:
        The first successfully loaded model.

    Raises:
        ModelLoadingError: There is no candidate to try, or every candidate failed.
            In the latter case the message lists each package ID with its error.
    """
    candidates = matching_model_packages
    if max_package_loading_attempts is not None:
        candidates = candidates[:max_package_loading_attempts]
    if not candidates:
        raise ModelLoadingError(
            message=f"Cannot load model {model_id} - no matching model package candidates for given model "
            f"running in this environment.",
            help_url="https://todo",
        )
    failures: List[Tuple[str, Exception]] = []
    for model_package in candidates:
        identifiers = AccessIdentifiers(
            model_id=model_id,
            package_id=model_package.package_id,
            api_key=api_key,
        )
        verbose_info(
            message=f"Attempt to load model package: {model_package.get_summary()}",
            verbose_requested=verbose,
        )
        try:
            # bind the access identifiers of this candidate to every file-system hook
            file_created_hook = partial(
                model_access_manager.on_file_created,
                access_identifiers=identifiers,
            )
            file_renamed_hook = partial(
                model_access_manager.on_file_renamed,
                access_identifiers=identifiers,
            )
            symlink_created_hook = partial(
                model_access_manager.on_symlink_created,
                access_identifiers=identifiers,
            )
            model, package_cache_dir = initialize_model(
                model_id=model_id,
                model_architecture=model_architecture,
                task_type=task_type,
                model_package=model_package,
                model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
                model_init_kwargs=model_init_kwargs,
                auto_resolution_cache=auto_resolution_cache,
                auto_negotiation_hash=auto_negotiation_hash,
                model_dependencies=model_dependencies,
                model_dependencies_instances=model_dependencies_instances,
                model_dependencies_directories=model_dependencies_directories,
                verify_hash_while_download=verify_hash_while_download,
                download_files_without_hash=download_files_without_hash,
                on_file_created=file_created_hook,
                on_file_renamed=file_renamed_hook,
                on_symlink_created=symlink_created_hook,
                on_symlink_deleted=model_access_manager.on_symlink_deleted,
                use_auto_resolution_cache=use_auto_resolution_cache,
            )
            model_access_manager.on_model_loaded(
                model=model,
                access_identifiers=identifiers,
                model_storage_path=package_cache_dir,
            )
            if point_model_directory:
                point_model_directory(package_cache_dir)
            return model
        except Exception as error:
            LOGGER.warning(
                f"Model package with id {model_package.package_id} that was selected to be loaded "
                f"failed to load with error: {error} of type {error.__class__.__name__}. This may "
                f"be caused several issues. If you see this warning after manually specifying model "
                f"package to be loaded - make sure that all required dependencies are installed. If "
                f"that warning is displayed when the model package was auto-selected - there is most "
                f"likely a bug in `inference-models` and you should raise an issue providing full context of "
                f"the event. https://github.com/roboflow/inference/issues"
            )
            failures.append((model_package.package_id, error))

    summary_of_errors = "\n".join(
        f"\t* model_package_id={model_package_id} error={error} error_type={error.__class__.__name__}"
        for model_package_id, error in failures
    )
    raise ModelLoadingError(
        message=f"Could not load any of model package candidate for model {model_id}. This may "
        f"be caused several issues. If you see this warning after manually specifying model "
        f"package to be loaded - make sure that all required dependencies are installed. If "
        f"that warning is displayed when the model package was auto-selected - there is most "
        f"likely a bug in `inference-models` and you should raise an issue providing full context of "
        f"the event. https://github.com/roboflow/inference/issues\n\n"
        f"Here is the summary of errors for specific model packages:\n{summary_of_errors}\n\n",
        help_url="https://todo",
    )
757
+
758
+
759
def initialize_model(
    model_id: str,
    model_architecture: ModelArchitecture,
    task_type: Optional[TaskType],
    model_package: ModelPackageMetadata,
    model_init_kwargs: dict,
    auto_resolution_cache: AutoResolutionCache,
    auto_negotiation_hash: str,
    model_dependencies: Optional[List[ModelDependency]],
    model_dependencies_instances: Dict[str, AnyModel],
    model_dependencies_directories: Dict[str, str],
    model_download_file_lock_acquire_timeout: int = 10,
    verify_hash_while_download: bool = True,
    download_files_without_hash: bool = False,
    on_file_created: Optional[Callable[[str], None]] = None,
    on_file_renamed: Optional[Callable[[str, str], None]] = None,
    on_symlink_created: Optional[Callable[[str, str], None]] = None,
    on_symlink_deleted: Optional[Callable[[str], None]] = None,
    use_auto_resolution_cache: bool = True,
) -> Tuple[AnyModel, str]:
    """Materialise one model package in the local cache and instantiate the model.

    Steps performed, in order:
    1. Resolve the model class for (architecture, task type, package backend).
    2. Reject packages shipping an artefact named ``MODEL_CONFIG_FILE_NAME``.
    3. Download artefacts that carry an MD5 hash into the shared blobs directory
       (named after the hash). Download the other artefacts into the package cache
       directory.
    4. Symlink the shared blobs into the package cache directory.
    5. Write the offline-loading config file.
    6. Symlink dependency model directories into the package directory.
    7. Call ``model_class.from_pretrained``. ``model_init_kwargs`` is mutated to
       carry ``model_dependencies_instances`` under ``MODEL_DEPENDENCIES_KEY``.
    8. Register an auto-resolution cache entry listing every resolved file.

    Returns:
        Tuple of the loaded model and the package cache directory it was loaded from.

    Raises:
        CorruptedModelPackageError: The package contains an artefact that collides
            with the config file this function writes.
        Errors from downloading, symlinking or ``from_pretrained`` are not caught
        here and propagate to the caller.
    """
    model_class = resolve_model_class(
        model_architecture=model_architecture,
        task_type=task_type,
        backend=model_package.backend,
    )
    # The config file name is reserved for the file written in step 5 below.
    for artefact in model_package.package_artefacts:
        if artefact.file_handle == MODEL_CONFIG_FILE_NAME:
            raise CorruptedModelPackageError(
                message=f"For model with id=`{model_id}` and package={model_package.package_id} discovered "
                f"artefact named `{MODEL_CONFIG_FILE_NAME}` which collides with the config file that "
                f"inference is supposed to create for a model in order for compatibility with offline "
                f"loaders. This problem indicate a violation of model package contract and requires change in "
                f"model package structure. If you experience this issue using hosted Roboflow solution, contact "
                f"us to solve the problem.",
                help_url="https://todo",
            )
    # (file_handle, download_url, md5_hash) triples
    files_specs = [
        (a.file_handle, a.download_url, a.md5_hash)
        for a in model_package.package_artefacts
    ]
    # Hashed artefacts go to the shared blobs dir (presumably so identical blobs are
    # stored once across packages); unhashed ones stay package-local.
    file_specs_with_hash = [f for f in files_specs if f[2] is not None]
    file_specs_without_hash = [f for f in files_specs if f[2] is None]
    shared_blobs_dir = generate_shared_blobs_path()
    model_package_cache_dir = generate_model_package_cache_path(
        model_id=model_id,
        package_id=model_package.package_id,
    )
    os.makedirs(model_package_cache_dir, exist_ok=True)
    shared_files_mapping = download_files_to_directory(
        target_dir=shared_blobs_dir,
        files_specs=file_specs_with_hash,
        file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
        verify_hash_while_download=verify_hash_while_download,
        download_files_without_hash=download_files_without_hash,
        name_after="md5_hash",
        on_file_created=on_file_created,
        on_file_renamed=on_file_renamed,
    )
    model_specific_files_mapping = download_files_to_directory(
        target_dir=model_package_cache_dir,
        files_specs=file_specs_without_hash,
        file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
        verify_hash_while_download=verify_hash_while_download,
        download_files_without_hash=download_files_without_hash,
        on_file_created=on_file_created,
        on_file_renamed=on_file_renamed,
    )
    # Expose shared blobs under their original file handles inside the package dir.
    symlinks_mapping = create_symlinks_to_shared_blobs(
        model_dir=model_package_cache_dir,
        shared_files_mapping=shared_files_mapping,
        model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
        on_symlink_created=on_symlink_created,
        on_symlink_deleted=on_symlink_deleted,
    )
    config_path = os.path.join(model_package_cache_dir, MODEL_CONFIG_FILE_NAME)
    dump_model_config_for_offline_use(
        config_path=config_path,
        model_architecture=model_architecture,
        task_type=task_type,
        backend_type=model_package.backend,
        file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
        on_file_created=on_file_created,
    )
    # Every path the package relies on is recorded, so that a later cache hit can
    # verify that they are all still present.
    resolved_files = set(shared_files_mapping.values())
    resolved_files.update(model_specific_files_mapping.values())
    resolved_files.update(symlinks_mapping.values())
    resolved_files.add(config_path)
    dependencies_resolved_files = handle_dependencies_directories_creation(
        model_package_cache_dir=model_package_cache_dir,
        model_dependencies_directories=model_dependencies_directories,
        model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
        on_symlink_created=on_symlink_created,
        on_symlink_deleted=on_symlink_deleted,
    )
    resolved_files.update(dependencies_resolved_files)
    # NOTE: mutates the caller's model_init_kwargs.
    model_init_kwargs[MODEL_DEPENDENCIES_KEY] = model_dependencies_instances
    model = model_class.from_pretrained(model_package_cache_dir, **model_init_kwargs)
    # Only reached after a successful load, so failed packages are never cached.
    dump_auto_resolution_cache(
        use_auto_resolution_cache=use_auto_resolution_cache,
        auto_resolution_cache=auto_resolution_cache,
        auto_negotiation_hash=auto_negotiation_hash,
        model_id=model_id,
        model_package_id=model_package.package_id,
        model_architecture=model_architecture,
        task_type=task_type,
        backend_type=model_package.backend,
        resolved_files=resolved_files,
        model_dependencies=model_dependencies,
    )
    return model, model_package_cache_dir
869
+
870
+
871
def create_symlinks_to_shared_blobs(
    model_dir: str,
    shared_files_mapping: Dict[FileHandle, str],
    model_download_file_lock_acquire_timeout: int = 10,
    on_symlink_created: Optional[Callable[[str, str], None]] = None,
    on_symlink_deleted: Optional[Callable[[str], None]] = None,
) -> Dict[str, str]:
    """Expose shared blob files inside ``model_dir`` through symlinks.

    Each symlink is named after the file handle of its blob.

    An existing entry at the link location is kept when it is either:
    - a regular file or directory (this function never overrides real files), or
    - a symlink that already resolves to the expected blob.

    Otherwise the link is (re)created via ``handle_symlink_creation``.

    Args:
        model_dir: Package directory that receives the links (created if missing).
        shared_files_mapping: File handle -> path of the blob in the shared store.
        model_download_file_lock_acquire_timeout: Timeout for the per-link file lock.
        on_symlink_created: Optional callback ``(target_path, link_name)``.
        on_symlink_deleted: Optional callback ``(link_name)``.

    Returns:
        File handle -> path of the link inside ``model_dir``. Every handle is
        included, even when the existing entry was kept.
    """
    # this function will not override existing files
    os.makedirs(model_dir, exist_ok=True)
    result = {}
    for file_handle, target_path in shared_files_mapping.items():
        link_name = os.path.join(model_dir, file_handle)
        result[file_handle] = link_name
        # Compare canonical paths on both sides. realpath(link) is fully resolved,
        # while target_path may be relative or pass through symlinked directories.
        # Comparing it verbatim would make a correct link look stale and force a
        # delete + recreate on every load.
        if os.path.exists(link_name) and (
            not os.path.islink(link_name)
            or os.path.realpath(link_name) == os.path.realpath(target_path)
        ):
            continue
        handle_symlink_creation(
            target_path=target_path,
            link_name=link_name,
            model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
            on_symlink_created=on_symlink_created,
            on_symlink_deleted=on_symlink_deleted,
        )
    return result
897
+
898
+
899
def handle_symlink_creation(
    target_path: str,
    link_name: str,
    model_download_file_lock_acquire_timeout: int = 10,
    on_symlink_created: Optional[Callable[[str, str], None]] = None,
    on_symlink_deleted: Optional[Callable[[str], None]] = None,
) -> None:
    """Create the symlink ``link_name -> target_path`` while holding a per-link lock.

    The lock file ``.<link file name>.lock`` lives next to the link. Any symlink
    already at ``link_name`` is removed first, whether dangling or pointing
    elsewhere. A non-symlink entry at that path is left alone, and ``os.symlink``
    then raises ``FileExistsError``.

    Args:
        target_path: Path the new link should point to.
        link_name: Location of the link. Its parent directory is created if missing.
        model_download_file_lock_acquire_timeout: Timeout passed to ``FileLock``.
        on_symlink_created: Optional callback ``(target_path, link_name)``, invoked
            after creation.
        on_symlink_deleted: Optional callback ``(link_name)``, invoked after a
            previous symlink was removed.
    """
    absolute_link = os.path.abspath(link_name)
    parent_dir = os.path.dirname(absolute_link)
    link_basename = os.path.basename(absolute_link)
    os.makedirs(parent_dir, exist_ok=True)
    lock_file = os.path.join(parent_dir, f".{link_basename}.lock")
    with FileLock(lock_file, timeout=model_download_file_lock_acquire_timeout):
        # A symlink here is either dangling or aims at a different target -
        # purge it so that os.symlink below does not fail.
        if os.path.islink(link_name):
            os.remove(link_name)
            if on_symlink_deleted:
                on_symlink_deleted(link_name)
        os.symlink(target_path, link_name)
        if on_symlink_created:
            on_symlink_created(target_path, link_name)
918
+
919
+
920
def dump_model_config_for_offline_use(
    config_path: str,
    model_architecture: Optional[ModelArchitecture],
    task_type: TaskType,
    backend_type: Optional[BackendType],
    file_lock_acquire_timeout: int,
    on_file_created: Optional[Callable[[str], None]] = None,
) -> None:
    """Write the model config JSON that offline loaders read, unless it already exists.

    The file holds ``model_architecture``, ``task_type`` and ``backend_type``. It is
    written under a ``.<config file name>.lock`` file lock placed next to it.

    Args:
        config_path: Destination of the config file.
        model_architecture: Value stored under ``"model_architecture"``.
        task_type: Value stored under ``"task_type"``.
        backend_type: Value stored under ``"backend_type"``.
        file_lock_acquire_timeout: Timeout passed to ``FileLock``.
        on_file_created: Optional callback ``(config_path)``, invoked after writing.
    """
    # An existing config is trusted as-is. If it ever gets corrupted this may cause
    # problems - to be verified empirically.
    if not os.path.exists(config_path):
        config_dir, config_file_name = os.path.split(config_path)
        config_lock_path = os.path.join(config_dir, f".{config_file_name}.lock")
        config_content = {
            "model_architecture": model_architecture,
            "task_type": task_type,
            "backend_type": backend_type,
        }
        with FileLock(config_lock_path, timeout=file_lock_acquire_timeout):
            dump_json(path=config_path, content=config_content)
            if on_file_created:
                on_file_created(config_path)
    return None
945
+
946
+
947
def handle_dependencies_directories_creation(
    model_package_cache_dir: str,
    model_dependencies_directories: Dict[str, str],
    model_download_file_lock_acquire_timeout: int = 10,
    on_symlink_created: Optional[Callable[[str, str], None]] = None,
    on_symlink_deleted: Optional[Callable[[str], None]] = None,
) -> Set[str]:
    """Link each dependency model directory into the package's dependencies sub-dir.

    For every ``name -> directory`` pair, a symlink
    ``<model_package_cache_dir>/<MODEL_DEPENDENCIES_SUB_DIR>/<name>`` is created.
    It points at ``directory`` and is created under a per-dependency file lock.

    Handling of whatever already exists at that location:
    - any existing symlink (valid or dangling) is replaced;
    - a real file or directory is left untouched.

    Args:
        model_package_cache_dir: Cache directory of the package being loaded.
        model_dependencies_directories: Dependency name -> directory of the loaded
            dependency package. Empty or ``None`` means nothing to do.
        model_download_file_lock_acquire_timeout: Timeout passed to ``FileLock``.
        on_symlink_created: Optional callback ``(dependency_directory, link_path)``.
        on_symlink_deleted: Optional callback ``(link_path)``.

    Returns:
        Files found in the dependency directories (as reported by
        ``scan_dependency_directory_for_resolved_files``). They are meant to be
        recorded as resolved files of the parent package.
    """
    resolved_files: Set[str] = set()
    if not model_dependencies_directories:
        return resolved_files
    # Loop-invariant: every dependency link lives in the same sub-directory.
    dependencies_sub_dir = os.path.join(
        model_package_cache_dir, MODEL_DEPENDENCIES_SUB_DIR
    )
    os.makedirs(dependencies_sub_dir, exist_ok=True)
    for dependency_name, dependency_directory in model_dependencies_directories.items():
        resolved_files.update(
            scan_dependency_directory_for_resolved_files(
                dependency_directory=dependency_directory
            )
        )
        target_dependency_dir = os.path.join(dependencies_sub_dir, dependency_name)
        dependency_lock_path = os.path.join(
            dependencies_sub_dir, f".{dependency_name}.lock"
        )
        with FileLock(
            dependency_lock_path, timeout=model_download_file_lock_acquire_timeout
        ):
            # os.path.islink() is True for dangling links too, whereas os.path.exists()
            # is not. Checking exists() first would leave a stale link in place and
            # make os.symlink below fail with FileExistsError.
            if os.path.islink(target_dependency_dir):
                os.remove(target_dependency_dir)
                if on_symlink_deleted:
                    on_symlink_deleted(target_dependency_dir)
            if not os.path.lexists(target_dependency_dir):
                # Question: is it ok to only try to remove symlink and avoid doing anything else
                # if we encounter actual file / dir there?
                os.symlink(dependency_directory, target_dependency_dir)
                if on_symlink_created:
                    on_symlink_created(dependency_directory, target_dependency_dir)
    return resolved_files
986
+
987
+
988
def scan_dependency_directory_for_resolved_files(
    dependency_directory: str,
) -> List[str]:
    """List the files a dependency package directory relies on, as absolute paths.

    What is included:
    - every file under ``dependency_directory``, except ``.<name>.lock`` lock files;
    - for each file that is a symlink, additionally its target. The target is
      resolved relative to the link's directory, so the result is usable from any
      working directory.

    A non-existent directory yields an empty list. Duplicates are not removed.
    """
    # We do not follow symlinked directories here: only one level of package nesting
    # is supported, so when we have a dependency, that model must not have
    # dependencies of its own, and there are no symlinked directories to follow.
    results = []
    for current_dir, _, files in os.walk(dependency_directory):
        for file in files:
            if file.startswith(".") and file.endswith(".lock"):
                continue
            full_path = os.path.abspath(os.path.join(current_dir, file))
            results.append(full_path)
            if os.path.islink(full_path):
                # os.readlink may return a path relative to the link's own directory.
                # Recording it verbatim would make later existence checks depend on
                # the CWD. os.path.join leaves absolute targets unchanged.
                link_target = os.readlink(full_path)
                results.append(
                    os.path.abspath(
                        os.path.join(os.path.dirname(full_path), link_target)
                    )
                )
    return results
1004
+
1005
+
1006
def dump_auto_resolution_cache(
    use_auto_resolution_cache: bool,
    auto_resolution_cache: AutoResolutionCache,
    auto_negotiation_hash: str,
    model_id: str,
    model_package_id: str,
    model_architecture: Optional[ModelArchitecture],
    task_type: TaskType,
    backend_type: Optional[BackendType],
    resolved_files: Set[str],
    model_dependencies: Optional[List[ModelDependency]],
) -> None:
    """Record the package that was just loaded under ``auto_negotiation_hash``.

    The entry is built from the given identifiers, the resolved files and the
    dependencies, with ``created_at`` set to the current local time. Nothing is
    done when ``use_auto_resolution_cache`` is ``False``.
    """
    if use_auto_resolution_cache:
        auto_resolution_cache.register(
            auto_negotiation_hash=auto_negotiation_hash,
            cache_entry=AutoResolutionCacheEntry(
                model_id=model_id,
                model_package_id=model_package_id,
                resolved_files=resolved_files,
                model_architecture=model_architecture,
                task_type=task_type,
                backend_type=backend_type,
                created_at=datetime.now(),
                model_dependencies=model_dependencies,
            ),
        )
    return None
1033
+
1034
+
1035
def generate_shared_blobs_path() -> str:
    """Return the ``shared-blobs`` directory under ``INFERENCE_HOME``.

    Model package files that carry a hash are stored here.
    """
    shared_blobs_dir_name = "shared-blobs"
    return os.path.join(INFERENCE_HOME, shared_blobs_dir_name)
1037
+
1038
+
1039
def generate_model_package_cache_path(model_id: str, package_id: str) -> str:
    """Return the local cache directory for a given model package.

    Layout: ``INFERENCE_HOME/models-cache/<slugified model id>/<package id>``.
    The package ID is validated first via ``ensure_package_id_is_os_safe``, which
    raises for unsafe IDs.
    """
    ensure_package_id_is_os_safe(model_id=model_id, package_id=package_id)
    models_cache_root = os.path.join(INFERENCE_HOME, "models-cache")
    return os.path.join(
        models_cache_root,
        slugify_model_id_to_os_safe_format(model_id=model_id),
        package_id,
    )
1043
+
1044
+
1045
def ensure_package_id_is_os_safe(model_id: str, package_id: str) -> None:
    """Validate that ``package_id`` can be used verbatim as a cache directory name.

    The ID becomes a path component of the local model cache, so only ASCII letters
    and digits are accepted. This rules out path separators, ``..`` and other
    characters that could escape or confuse the cache layout.

    An empty ID is rejected too: joined into the cache path it would collapse the
    package directory onto its parent model directory.

    Raises:
        InsecureModelIdentifierError: ``package_id`` is empty or contains any
            character other than ASCII letters and digits.
    """
    # fullmatch with "+" also rejects the empty string, which re.search on the
    # negated class would have let through.
    if re.fullmatch(r"[A-Za-z0-9]+", package_id) is None:
        raise InsecureModelIdentifierError(
            message=f"Attempted to load model: {model_id} using package ID: {package_id} which "
            f"has invalid format. ID is expected to contain only ASCII characters and numbers to "
            f"ensure safety of local cache. If you see this error running your model on Roboflow platform, "
            f"raise the issue: https://github.com/roboflow/inference/issues. If you are running `inference` "
            f"outside of the platform, verify that your weights provider keeps the model packages identifiers "
            f"in the expected format.",
            help_url="https://TODO",
        )
1056
+
1057
+
1058
def slugify_model_id_to_os_safe_format(model_id: str) -> str:
    """Turn ``model_id`` into a filesystem-friendly directory name.

    Transformation, in order:
    1. Runs of characters other than ASCII letters, digits, ``_`` and ``-`` become
       a single ``-``.
    2. Runs of two or more ``_``/``-`` are collapsed to ``-``.
    3. The slug is truncated to 48 characters. An empty slug (only possible for an
       empty ID) is replaced by a fixed placeholder.
    4. ``-`` plus an 8-hex-digit blake2s digest of the *original* ID is appended.

    The digest presumably keeps different IDs that slugify to the same text apart.
    """
    max_slug_length = 48
    non_safe_replaced = re.sub(r"[^A-Za-z0-9_-]+", "-", model_id)
    collapsed = re.sub(r"[_-]{2,}", "-", non_safe_replaced)
    # The placeholder is shorter than max_slug_length, so truncating first is
    # equivalent to truncating after the substitution.
    slug = collapsed[:max_slug_length] or "special-char-only-model-id"
    id_digest = hashlib.blake2s(model_id.encode("utf-8"), digest_size=4).hexdigest()
    return f"{slug}-{id_digest}"
1069
+
1070
+
1071
def attempt_loading_model_from_local_storage(
    model_dir_or_weights_path: str,
    allow_local_code_packages: bool,
    model_init_kwargs: dict,
    model_type: Optional[str] = None,
    task_type: Optional[str] = None,
    backend_type: Optional[
        Union[str, BackendType, List[Union[str, BackendType]]]
    ] = None,
) -> AnyModel:
    """Load a model from a local checkpoint file or a local model package directory.

    Dispatch rules:
    - A path to a regular file is treated as a checkpoint and handled by
      ``attempt_loading_model_from_checkpoint``.
    - Otherwise the path is treated as a package directory and its
      ``MODEL_CONFIG_FILE_NAME`` config is parsed:
      - library models are loaded via ``load_library_model_from_local_dir``;
      - all other packages contain arbitrary code and are loaded only when
        ``allow_local_code_packages`` is ``True``.

    Raises:
        ModelLoadingError: The package requires arbitrary code but
            ``allow_local_code_packages`` is ``False``. Errors from config parsing or
            the loaders propagate unchanged.
    """
    if os.path.isfile(model_dir_or_weights_path):
        return attempt_loading_model_from_checkpoint(
            checkpoint_path=model_dir_or_weights_path,
            model_init_kwargs=model_init_kwargs,
            model_type=model_type,
            task_type=task_type,
            backend_type=backend_type,
        )
    model_config = parse_model_config(
        config_path=os.path.join(model_dir_or_weights_path, MODEL_CONFIG_FILE_NAME)
    )
    if model_config.is_library_model():
        return load_library_model_from_local_dir(
            model_dir=model_dir_or_weights_path,
            model_config=model_config,
            model_init_kwargs=model_init_kwargs,
        )
    if allow_local_code_packages:
        return load_model_from_local_package_with_arbitrary_code(
            model_dir=model_dir_or_weights_path,
            model_config=model_config,
            model_init_kwargs=model_init_kwargs,
        )
    raise ModelLoadingError(
        message=f"Attempted to load model from local package with arbitrary code. This is not allowed in "
        f"this environment. To let inference loading such models, use `allow_local_code_packages=True` "
        f"parameter of `AutoModel.from_pretrained(...)`. If you see this error while using one of Roboflow "
        f"hosted solution - contact us to solve the problem.",
        help_url="https://todo",
    )
1110
+
1111
+
1112
def attempt_loading_model_from_checkpoint(
    checkpoint_path: str,
    model_init_kwargs: dict,
    model_type: Optional[str] = None,
    task_type: Optional[str] = None,
    backend_type: Optional[
        Union[str, BackendType, List[Union[str, BackendType]]]
    ] = None,
) -> AnyModel:
    """Instantiate a model straight from a checkpoint file.

    The registry entry (architecture, task type, backend) is resolved by
    ``resolve_models_registry_entry``, which raises for unsupported combinations.
    The matching model class then loads ``checkpoint_path``.

    NOTE: ``model_init_kwargs`` is mutated. The raw ``model_type`` is stored in it
    under ``"model_type"`` before it is forwarded to ``from_pretrained``.
    """
    architecture, resolved_task_type, resolved_backend = resolve_models_registry_entry(
        model_type=model_type,
        task_type=task_type,
        backend_type=backend_type,
    )
    model_init_kwargs["model_type"] = model_type
    model_class = resolve_model_class(
        model_architecture=architecture,
        task_type=resolved_task_type,
        backend=resolved_backend,
    )
    return model_class.from_pretrained(checkpoint_path, **model_init_kwargs)
1133
+
1134
+
1135
def resolve_models_registry_entry(
    model_type: Optional[str],
    task_type: Optional[str] = None,
    backend_type: Optional[
        Union[str, BackendType, List[Union[str, BackendType]]]
    ] = None,
) -> Tuple[str, str, BackendType]:
    """Map checkpoint-loading parameters to ``(model_architecture, task_type, backend)``.

    Only model types in ``MODEL_TYPES_TO_LOAD_FROM_CHECKPOINT`` are accepted. All of
    them are currently mapped to the ``"rfdetr"`` architecture.

    Defaults when arguments are omitted:
    - ``task_type``: instance segmentation for ``"rfdetr-seg-preview"``, object
      detection otherwise;
    - ``backend_type``: ``BackendType.TORCH``.

    ``backend_type`` may be given as a string, a ``BackendType``, or a one-element
    list of either. Only the torch backend is supported.

    Raises:
        ModelLoadingError: One of the following holds:
            - ``model_type`` is missing or unsupported;
            - the task type is neither object detection nor instance segmentation;
            - more (or fewer) than one backend is listed;
            - the backend is not torch.
    """
    # TODO: in the future this check will grow in size
    if not model_type:
        raise ModelLoadingError(
            message="When loading model directly from checkpoint path, `model_type` parameter must be specified. "
            "Use one of the supported value, for example `rfdetr-nano` in case you refer checkpoint of "
            "RFDetr Nano model.",
            help_url="https://todo",
        )
    if model_type not in MODEL_TYPES_TO_LOAD_FROM_CHECKPOINT:
        raise ModelLoadingError(
            message="When loading model directly from checkpoint path, `model_type` parameter must define "
            "one of the type of model that support loading directly from the checkpoints. "
            f"Models supported in current version: {MODEL_TYPES_TO_LOAD_FROM_CHECKPOINT}",
            help_url="https://todo",
        )
    # a bit of hard coding here, over time we must maintain
    model_architecture = "rfdetr"
    if task_type is None:
        if model_type == "rfdetr-seg-preview":
            task_type = INSTANCE_SEGMENTATION_TASK
        else:
            task_type = OBJECT_DETECTION_TASK
    if task_type not in {OBJECT_DETECTION_TASK, INSTANCE_SEGMENTATION_TASK}:
        raise ModelLoadingError(
            message=f"When loading model directly from checkpoint path, set `model_type` as {model_type} and "
            f"`task_type` as {task_type}, whereas selected model do only support `{OBJECT_DETECTION_TASK}` "
            f"task while loading from checkpoint file.",
            help_url="https://todo",
        )
    if backend_type is None:
        backend_type = BackendType.TORCH
    if isinstance(backend_type, list):
        # Exactly one backend can be honoured for a checkpoint. Single-element lists
        # are unwrapped; previously the outer guard also required len != 1, which
        # made this unwrap unreachable and rejected e.g. ["torch"].
        if len(backend_type) != 1:
            raise ModelLoadingError(
                message=f"When loading model directly from checkpoint path, set `backend` parameter to be {backend_type}, "
                f"whereas it is only supported to pass a single value.",
                help_url="https://todo",
            )
        backend_type = backend_type[0]
    if isinstance(backend_type, str):
        backend_type = parse_backend_type(value=backend_type)
    if backend_type is not BackendType.TORCH:
        raise ModelLoadingError(
            message=f"When loading model directly from checkpoint path, selected the following backend {backend_type}, "
            f"but the backend supported for model {model_type} is {BackendType.TORCH}",
            help_url="https://todo",
        )
    return model_architecture, task_type, backend_type
1190
+
1191
+
1192
def parse_model_config(config_path: str) -> InferenceModelConfig:
    """Load and validate the JSON model config of a local model package.

    Args:
        config_path: Path to the config file on local disk.

    Returns:
        InferenceModelConfig built from the ``model_architecture``,
        ``task_type``, ``backend_type``, ``model_module`` and ``model_class``
        entries of the config. Entries that are absent become ``None``.

    Raises:
        ModelLoadingError: If ``config_path`` is not an existing file.
        CorruptedModelPackageError: If the file cannot be decoded
            (``read_json`` raises ``ValueError``), if the decoded content is
            not a dict, or if the declared ``backend_type`` is not a valid
            ``BackendType`` value.
    """
    if not os.path.isfile(config_path):
        raise ModelLoadingError(
            message=f"Could not find model config while attempting to load model from "
            f"local directory. This error may be caused by misconfiguration of model package (lack of config "
            f"file), as well as by clash between model_id or model alias and contents of local disc drive which "
            f"is possible when you have local directory in current dir which has the name colliding with the "
            f"model you attempt to load. If your intent was to load model from remote backend (not local "
            f"storage) - verify the contents of $PWD. If you see this problem while using one of Roboflow "
            f"hosted solutions - contact us to get help.",
            help_url="https://todo",
        )
    try:
        raw_config = read_json(path=config_path)
    except ValueError as error:
        raise CorruptedModelPackageError(
            message=f"Could not decode model config while attempting to load model from "
            f"local directory. This error may be caused by corrupted config file. Validate the content of your "
            f"model package and check in documentation the required format of model config file. "
            f"If you see this problem while using one of Roboflow hosted solutions - contact us to get help.",
            help_url="https://todo",
        ) from error
    if not isinstance(raw_config, dict):
        raise CorruptedModelPackageError(
            message=f"While loading the model from local directory encountered corrupted model config file - config is "
            f"supposed to be a dictionary, instead decoded object of type: "
            f"{type(raw_config)}. If you see this problem while using one of Roboflow hosted solutions - "
            f"contact us to get help. Otherwise - verify the content of your model config.",
            help_url="https://todo",
        )
    backend_type = None
    # `backend_type` is optional in the config; when the key is present its
    # value must be one of the BackendType enum values.
    if "backend_type" in raw_config:
        raw_backend_type = raw_config["backend_type"]
        try:
            backend_type = BackendType(raw_backend_type)
        except ValueError as e:
            raise CorruptedModelPackageError(
                message=f"While loading the model from local directory encountered corrupted model config "
                # This fragment is an f-string so the offending value is interpolated
                # instead of being shown literally as "{raw_backend_type}".
                f"- declared `backend_type` ({raw_backend_type}) is not supported by inference. "
                f"Supported values: {[t.value for t in BackendType]}. If you see this problem while using "
                f"one of Roboflow hosted solutions - contact us to get help. Otherwise - verify the content "
                f"of your model config.",
                help_url="https://todo",
            ) from e
    return InferenceModelConfig(
        model_architecture=raw_config.get("model_architecture"),
        task_type=raw_config.get("task_type"),
        backend_type=backend_type,
        model_module=raw_config.get("model_module"),
        model_class=raw_config.get("model_class"),
    )
1243
+
1244
+
1245
def load_library_model_from_local_dir(
    model_dir: str,
    model_config: InferenceModelConfig,
    model_init_kwargs: dict,
) -> AnyModel:
    """Instantiate a library-provided model from a local package directory.

    The implementation class is looked up with ``resolve_model_class`` using
    the architecture, task type and backend declared in ``model_config``. It is
    then built via its ``from_pretrained`` classmethod.

    Args:
        model_dir: Directory holding the model package.
        model_config: Parsed model config of the package.
        model_init_kwargs: Extra keyword arguments forwarded to
            ``from_pretrained``.

    Returns:
        The model instance returned by ``from_pretrained``.
    """
    architecture = model_config.model_architecture
    task = model_config.task_type
    backend = model_config.backend_type
    model_cls = resolve_model_class(
        model_architecture=architecture, task_type=task, backend=backend
    )
    return model_cls.from_pretrained(model_dir, **model_init_kwargs)
1256
+
1257
+
1258
def load_model_from_local_package_with_arbitrary_code(
    model_dir: str,
    model_config: InferenceModelConfig,
    model_init_kwargs: dict,
) -> AnyModel:
    """Instantiate a model whose implementation ships inside the package itself.

    The config must name a Python module file (``model_module``, relative to
    ``model_dir``) and a class within it (``model_class``). The class is
    loaded with ``load_class_from_path``, and the model is built with
    ``from_pretrained(model_dir, **model_init_kwargs)``.

    Args:
        model_dir: Directory holding the model package.
        model_config: Parsed model config of the package.
        model_init_kwargs: Extra keyword arguments forwarded to
            ``from_pretrained``.

    Returns:
        The model instance returned by ``from_pretrained``.

    Raises:
        CorruptedModelPackageError: If ``model_module`` or ``model_class`` is
            missing from the config, or the referenced module file does not exist.
    """
    if model_config.model_module is None or model_config.model_class is None:
        raise CorruptedModelPackageError(
            message=f"While loading the model from local directory encountered corrupted model config file. "
            f"Config does not specify `model_module` and/or `model_class` - both are "
            f"required to load models provided with arbitrary code. If you see this problem while using "
            f"one of Roboflow hosted solutions - contact us to get help. Otherwise - verify the content "
            f"of your model config.",
            help_url="https://todo",
        )
    # NOTE(review): the module named in the config is imported and executed, so
    # only trusted packages should be loaded this way. os.path.join also lets an
    # absolute `model_module` (or one containing "..") point outside model_dir.
    model_module_path = os.path.join(model_dir, model_config.model_module)
    if not os.path.isfile(model_module_path):
        raise CorruptedModelPackageError(
            message=f"While loading the model from local directory encountered corrupted model config file. "
            f"Config pointed module {model_config.model_module}, but there is no file under "
            f"{model_module_path}. If you see this problem while using "
            f"one of Roboflow hosted solutions - contact us to get help. Otherwise - verify the content "
            f"of your model config.",
            help_url="https://todo",
        )
    model_class = load_class_from_path(
        module_path=model_module_path, class_name=model_config.model_class
    )
    return model_class.from_pretrained(model_dir, **model_init_kwargs)
1286
+
1287
+
1288
+ def load_class_from_path(module_path: str, class_name: str) -> AnyModel:
1289
+ if not os.path.exists(module_path):
1290
+ raise CorruptedModelPackageError(
1291
+ message=f"When loading local model with arbitrary code, encountered issue with loading the module. "
1292
+ "Could find the module under the path specified in model config. If you see this problem "
1293
+ f"while using one of Roboflow hosted solutions - contact us to get help. Otherwise - verify your "
1294
+ f"model package checking if you can load the module with model implementation within your "
1295
+ f"python environment.",
1296
+ help_url="https://todo",
1297
+ )
1298
+ module_name = os.path.splitext(os.path.basename(module_path))[0]
1299
+ spec = importlib.util.spec_from_file_location(module_name, module_path)
1300
+ if spec is None:
1301
+ raise CorruptedModelPackageError(
1302
+ message=f"When loading local model with arbitrary code, encountered issue with loading the module. "
1303
+ "Could not build module specification. If you see this problem while using "
1304
+ f"one of Roboflow hosted solutions - contact us to get help. Otherwise - verify your "
1305
+ f"model package checking if you can load the module with model implementation within your "
1306
+ f"python environment.",
1307
+ help_url="https://todo",
1308
+ )
1309
+ module = importlib.util.module_from_spec(spec)
1310
+ loader = spec.loader
1311
+ if loader is None or not hasattr(loader, "exec_module"):
1312
+ raise CorruptedModelPackageError(
1313
+ message=f"When loading local model with arbitrary code, encountered issue with loading the module. "
1314
+ "Could not execute module loader. If you see this problem while using "
1315
+ f"one of Roboflow hosted solutions - contact us to get help. Otherwise - verify your "
1316
+ f"model package checking if you can load the module with model implementation within your "
1317
+ f"python environment.",
1318
+ help_url="https://todo",
1319
+ )
1320
+ try:
1321
+ loader.exec_module(module)
1322
+ except Exception as error:
1323
+ raise CorruptedModelPackageError(
1324
+ message=f"When loading local model with arbitrary code, encountered issue executing the module code "
1325
+ f"to retrieve model class. Details of the error: {error}. If you see this problem while using "
1326
+ f"one of Roboflow hosted solutions - contact us to get help. Otherwise - verify your "
1327
+ f"model package checking if you can load the module with model implementation within your "
1328
+ f"python environment.",
1329
+ help_url="https://todo",
1330
+ )
1331
+ if not hasattr(module, class_name):
1332
+ raise CorruptedModelPackageError(
1333
+ message=f"When loading local model with arbitrary code, encountered issue with loading the module. "
1334
+ f"Module `{module_name}` has no class `{class_name}`. If you see this problem while using "
1335
+ f"one of Roboflow hosted solutions - contact us to get help. Otherwise - verify your "
1336
+ f"model package checking if you can load the module with model implementation within your "
1337
+ f"python environment. It may also be the case that configuration file of the model points "
1338
+ f"to invalid class name.",
1339
+ help_url="https://todo",
1340
+ )
1341
+ return getattr(module, class_name)