inference-models 0.18.3 (inference_models-0.18.3-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/utils/download.py
@@ -0,0 +1,608 @@
+ import hashlib
+ import math
+ import os
+ from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait
+ from threading import Lock
+ from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union
+ from uuid import uuid4
+
+ import backoff
+ import requests
+ from filelock import FileLock
+ from requests import Response, Timeout
+ from rich.progress import (
+     BarColumn,
+     DownloadColumn,
+     Progress,
+     TimeRemainingColumn,
+     TransferSpeedColumn,
+ )
+
+ from inference_models.configuration import (
+     API_CALLS_MAX_TRIES,
+     API_CALLS_TIMEOUT,
+     DISABLE_INTERACTIVE_PROGRESS_BARS,
+     IDEMPOTENT_API_REQUEST_CODES_TO_RETRY,
+ )
+ from inference_models.errors import (
+     FileHashSumMissmatch,
+     InvalidParameterError,
+     RetryError,
+     UntrustedFileError,
+ )
+ from inference_models.logger import LOGGER
+ from inference_models.utils.file_system import (
+     ensure_parent_dir_exists,
+     pre_allocate_file,
+     remove_file_if_exists,
+     stream_file_bytes,
+ )
+
+ FileHandle = str
+ DownloadUrl = str
+ MD5Hash = Optional[str]
+
+ MIN_SIZE_FOR_THREADED_DOWNLOAD = 32 * 1024 * 1024  # 32MB
+ MIN_THREAD_CHUNK_SIZE = 16 * 1024 * 1024  # 16MB
+ DEFAULT_STREAM_DOWNLOAD_CHUNK = 1 * 1024 * 1024  # 1MB
+
+
+ class HashNullObject:
+
+     def update(self, *args, **kwargs) -> None:
+         pass
+
+     def hexdigest(self) -> None:
+         return None
+
+
+ def download_files_to_directory(
+     target_dir: str,
+     files_specs: List[Tuple[FileHandle, DownloadUrl, MD5Hash]],
+     verbose: bool = True,
+     response_codes_to_retry: Optional[Set[int]] = None,
+     request_timeout: Optional[int] = None,
+     max_parallel_downloads: int = 8,
+     max_threads_per_download: int = 8,
+     file_lock_acquire_timeout: int = 10,
+     verify_hash_while_download: bool = True,
+     download_files_without_hash: bool = False,
+     name_after: Literal["file_handle", "md5_hash"] = "file_handle",
+     on_file_created: Optional[Callable[[str], None]] = None,
+     on_file_renamed: Optional[Callable[[str, str], None]] = None,
+ ) -> Dict[str, str]:
+     if name_after not in {"file_handle", "md5_hash"}:
+         raise InvalidParameterError(
+             message="Function download_files_to_directory(...) was called with "
+             f"an invalid value of parameter `name_after` - received value `{name_after}`. "
+             f"This is a bug in `inference-models` - submit a new issue at "
+             f"https://github.com/roboflow/inference/issues/",
+             help_url="https://todo",
+         )
+     if DISABLE_INTERACTIVE_PROGRESS_BARS:
+         verbose = False
+     files_mapping = construct_files_path_mapping(
+         target_dir=target_dir,
+         files_specs=files_specs,
+         name_after=name_after,
+     )
+     files_specs = exclude_existing_files(
+         files_specs=files_specs,
+         files_mapping=files_mapping,
+     )
+     if not files_specs:
+         return files_mapping
+     if response_codes_to_retry is None:
+         response_codes_to_retry = IDEMPOTENT_API_REQUEST_CODES_TO_RETRY
+     if request_timeout is None:
+         request_timeout = API_CALLS_TIMEOUT
+     if not download_files_without_hash:
+         untrusted_files = [f[1] for f in files_specs if f[2] is None]
+         if len(untrusted_files) > 0:
+             raise UntrustedFileError(
+                 message=f"Detected {len(untrusted_files)} untrusted file(s): {untrusted_files} "
+                 f"without an MD5 hash sum to verify the downloaded content. The download method was used with "
+                 f"`download_files_without_hash=False`, which prevents such files from being downloaded. If you see "
+                 f"this error while using the hosted Roboflow serving option, contact us for support.",
+                 help_url="https://todo",
+             )
+     os.makedirs(target_dir, exist_ok=True)
+     progress = Progress(
+         "[progress.description]{task.description}",
+         BarColumn(),
+         DownloadColumn(),
+         TransferSpeedColumn(),
+         TimeRemainingColumn(),
+         disable=not verbose,
+     )
+     download_id = str(uuid4())
+     with progress:
+         with ThreadPoolExecutor(max_workers=max_parallel_downloads) as executor:
+             futures = []
+             for file_handle, download_url, md5_hash in files_specs:
+                 future = executor.submit(
+                     safe_download_file,
+                     target_file_path=files_mapping[file_handle],
+                     download_url=download_url,
+                     md5_hash=md5_hash,
+                     verify_hash_while_download=verify_hash_while_download,
+                     download_id=download_id,
+                     progress=progress,
+                     response_codes_to_retry=response_codes_to_retry,
+                     request_timeout=request_timeout,
+                     max_threads_per_download=max_threads_per_download,
+                     file_lock_acquire_timeout=file_lock_acquire_timeout,
+                     on_file_created=on_file_created,
+                     on_file_renamed=on_file_renamed,
+                 )
+                 futures.append(future)
+             done_futures, pending_futures = wait(futures, return_when=FIRST_EXCEPTION)
+             for pending_future in pending_futures:
+                 pending_future.cancel()
+             _ = wait(pending_futures)
+             for future in done_futures:
+                 future_exception = future.exception()
+                 if future_exception:
+                     raise future_exception
+     return files_mapping
+
+
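A minimal usage sketch of the API above (the directory, file handle, URL, and MD5 value are hypothetical placeholders; only the `download_files_to_directory` signature comes from the code in this diff):

    from inference_models.utils.download import download_files_to_directory

    # each spec is (file_handle, download_url, md5_hash), per the type aliases above;
    # md5_hash may be None only when download_files_without_hash=True
    files_mapping = download_files_to_directory(
        target_dir="/tmp/model-cache",
        files_specs=[
            ("weights.onnx", "https://example.com/weights.onnx", "0123456789abcdef0123456789abcdef"),
        ],
        max_parallel_downloads=4,
    )
    print(files_mapping)  # {'weights.onnx': '/tmp/model-cache/weights.onnx'}

Files already present under their target paths are skipped via `exclude_existing_files`, so the call is idempotent across restarts.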
+ def construct_files_path_mapping(
+     target_dir: str,
+     files_specs: List[Tuple[FileHandle, DownloadUrl, MD5Hash]],
+     name_after: Literal["file_handle", "md5_hash"] = "file_handle",
+ ) -> Dict[FileHandle, str]:
+     result = {}
+     for file_handle, download_url, content_hash in files_specs:
+         if name_after == "md5_hash" and content_hash is None:
+             raise UntrustedFileError(
+                 message="Attempted to download a file without a declared hash sum when "
+                 "`name_after='md5_hash'` - this indicates either a misconfiguration "
+                 "of the download procedure in `inference-models` or a bug in the codebase. "
+                 "If you see this error using the hosted Roboflow solution, contact us for "
+                 "help. If running locally, verify the download code and raise an issue if you find "
+                 "a bug: https://github.com/roboflow/inference/issues/",
+                 help_url="https://todo",
+             )
+         if name_after == "md5_hash":
+             target_path = os.path.join(target_dir, content_hash)
+         else:
+             target_path = os.path.join(target_dir, file_handle)
+         result[file_handle] = target_path
+     return result
+
+
+ def exclude_existing_files(
+     files_specs: List[Tuple[FileHandle, DownloadUrl, MD5Hash]],
+     files_mapping: Dict[FileHandle, str],
+ ) -> List[Tuple[FileHandle, DownloadUrl, MD5Hash]]:
+     result = []
+     for file_specs in files_specs:
+         target_path = files_mapping[file_specs[0]]
+         if not os.path.isfile(target_path):
+             result.append(file_specs)
+     return result
+
+
+ def safe_download_file(
+     target_file_path: str,
+     download_url: str,
+     download_id: str,
+     md5_hash: MD5Hash,
+     verify_hash_while_download: bool,
+     progress: Progress,
+     response_codes_to_retry: Set[int],
+     request_timeout: int,
+     max_threads_per_download: int,
+     file_lock_acquire_timeout: int,
+     on_file_created: Optional[Callable[[str], None]] = None,
+     on_file_renamed: Optional[Callable[[str, str], None]] = None,
+ ) -> None:
+     ensure_parent_dir_exists(path=target_file_path)
+     target_file_dir, target_file_name = os.path.split(target_file_path)
+     lock_path = os.path.join(target_file_dir, f".{target_file_name}.lock")
+     tmp_download_file = os.path.abspath(
+         os.path.join(target_file_dir, f"{target_file_name}.{download_id}")
+     )
+     try:
+         with FileLock(lock_path, timeout=file_lock_acquire_timeout):
+             safe_execute_download(
+                 download_url=download_url,
+                 tmp_download_file=tmp_download_file,
+                 target_file_path=target_file_path,
+                 md5_hash=md5_hash,
+                 verify_hash_while_download=verify_hash_while_download,
+                 progress=progress,
+                 response_codes_to_retry=response_codes_to_retry,
+                 request_timeout=request_timeout,
+                 max_threads_per_download=max_threads_per_download,
+                 original_file_name=target_file_name,
+                 on_file_created=on_file_created,
+                 on_file_renamed=on_file_renamed,
+             )
+     finally:
+         remove_file_if_exists(path=tmp_download_file)
+
+
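The temporary-file-plus-`FileLock` dance above is what makes concurrent callers safe: only one process downloads a given target at a time, and a partially written file never sits at the final path. A minimal sketch of the locking behavior, assuming only the `filelock` package already imported at the top of this module (the lock path is a hypothetical example):

    from filelock import FileLock, Timeout

    lock = FileLock("/tmp/model-cache/.weights.onnx.lock", timeout=10)
    try:
        with lock:
            pass  # the holder downloads here; other processes block for up to 10s
    except Timeout:
        print("another process held the lock for more than 10 seconds")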
+ def safe_execute_download(
+     download_url: str,
+     tmp_download_file: str,
+     target_file_path: str,
+     md5_hash: MD5Hash,
+     verify_hash_while_download: bool,
+     progress: Progress,
+     response_codes_to_retry: Set[int],
+     request_timeout: int,
+     max_threads_per_download: int,
+     original_file_name: str,
+     on_file_created: Optional[Callable[[str], None]] = None,
+     on_file_renamed: Optional[Callable[[str, str], None]] = None,
+ ) -> None:
+     expected_file_size = safe_check_range_download_option(
+         url=download_url,
+         timeout=request_timeout,
+         response_codes_to_retry=response_codes_to_retry,
+     )
+     download_task = progress.add_task(
+         description=f"{original_file_name}: Download",
+         total=expected_file_size,
+         start=True,
+         visible=True,
+     )
+     # a list used as a mutable closure cell, so the hash-verification task can be
+     # registered lazily at runtime by the callback below
+     hash_calculation_task = []
+
+     progress_task_lock = Lock()
+
+     def on_chunk_downloaded(bytes_num: int) -> None:
+         with progress_task_lock:
+             progress.advance(download_task, bytes_num)
+
+     def on_hash_calculation_started() -> None:
+         if len(hash_calculation_task) > 0:
+             return None
+         progress.remove_task(download_task)
+         new_hash_calculation_task = progress.add_task(
+             description=f"{original_file_name}: Verify hash",
+             total=expected_file_size,
+             start=True,
+             visible=True,
+         )
+         hash_calculation_task.append(new_hash_calculation_task)
+
+     def on_hash_chunk_calculated(bytes_num: int) -> None:
+         if len(hash_calculation_task) != 1:
+             return None
+         progress.advance(hash_calculation_task[0], bytes_num)
+
+     if (
+         expected_file_size is None
+         or expected_file_size < MIN_SIZE_FOR_THREADED_DOWNLOAD
+         or max_threads_per_download <= 1
+     ):
+         stream_download(
+             url=download_url,
+             target_path=tmp_download_file,
+             timeout=request_timeout,
+             md5_hash=md5_hash,
+             verify_hash_while_download=verify_hash_while_download,
+             response_codes_to_retry=response_codes_to_retry,
+             on_chunk_downloaded=on_chunk_downloaded,
+             on_file_created=on_file_created,
+         )
+     else:
+         threaded_download_file(
+             url=download_url,
+             target_path=tmp_download_file,
+             file_size=expected_file_size,
+             response_codes_to_retry=response_codes_to_retry,
+             request_timeout=request_timeout,
+             md5_hash=md5_hash,
+             verify_hash_while_download=verify_hash_while_download,
+             max_threads_per_download=max_threads_per_download,
+             on_chunk_downloaded=on_chunk_downloaded,
+             on_file_created=on_file_created,
+             on_hash_calculation_started=on_hash_calculation_started,
+             on_hash_chunk_calculated=on_hash_chunk_calculated,
+         )
+     os.rename(tmp_download_file, target_file_path)
+     if on_file_renamed:
+         on_file_renamed(tmp_download_file, target_file_path)
+
+
+ def safe_check_range_download_option(
+     url: str, timeout: int, response_codes_to_retry: Set[int]
+ ) -> Optional[int]:
+     try:
+         return check_range_download_option(
+             url=url, timeout=timeout, response_codes_to_retry=response_codes_to_retry
+         )
+     except Exception:
+         LOGGER.warning(f"Cannot use range requests for {url}")
+         return None
+
+
+ @backoff.on_exception(
+     backoff.constant,
+     exception=RetryError,
+     max_tries=API_CALLS_MAX_TRIES,
+     interval=1,
+ )
+ def check_range_download_option(
+     url: str, timeout: int, response_codes_to_retry: Set[int]
+ ) -> Optional[int]:
+     try:
+         response = requests.head(url, timeout=timeout)
+     except (OSError, Timeout, requests.exceptions.ConnectionError):
+         raise RetryError(
+             message=f"Connectivity error for URL: {url}", help_url="https://todo"
+         )
+     if response.status_code in response_codes_to_retry:
+         raise RetryError(
+             message=f"Remote server returned response code {response.status_code} for URL {url}",
+             help_url="https://todo",
+         )
+     response.raise_for_status()
+     accept_ranges = response.headers.get("accept-ranges", "none")
+     content_length = response.headers.get("content-length")
+     if "bytes" not in accept_ranges.lower():
+         return None
+     if not content_length:
+         return None
+     return int(content_length)
+
+
+ @backoff.on_exception(
+     backoff.constant,
+     exception=RetryError,
+     max_tries=API_CALLS_MAX_TRIES,
+     interval=1,
+ )
+ def get_content_length(
+     url: str,
+     timeout: Optional[int] = None,
+     response_codes_to_retry: Optional[Set[int]] = None,
+ ) -> Optional[int]:
+     if response_codes_to_retry is None:
+         response_codes_to_retry = IDEMPOTENT_API_REQUEST_CODES_TO_RETRY
+     if timeout is None:
+         timeout = API_CALLS_TIMEOUT
+     try:
+         response = requests.head(url, timeout=timeout)
+     except (OSError, Timeout, requests.exceptions.ConnectionError):
+         raise RetryError(
+             message=f"Connectivity error for URL: {url}", help_url="https://todo"
+         )
+     if response.status_code in response_codes_to_retry:
+         raise RetryError(
+             message=f"Remote server returned response code {response.status_code} for URL {url}",
+             help_url="https://todo",
+         )
+     response.raise_for_status()
+     content_length = response.headers.get("content-length")
+     if content_length is None:
+         return None
+     return int(content_length)
+
+
+ def threaded_download_file(
+     url: str,
+     target_path: str,
+     file_size: int,
+     response_codes_to_retry: Set[int],
+     request_timeout: int,
+     max_threads_per_download: int,
+     md5_hash: MD5Hash,
+     verify_hash_while_download: bool,
+     on_chunk_downloaded: Optional[Callable[[int], None]] = None,
+     on_file_created: Optional[Callable[[str], None]] = None,
+     on_hash_calculation_started: Optional[Callable[[], None]] = None,
+     on_hash_chunk_calculated: Optional[Callable[[int], None]] = None,
+ ) -> None:
+     chunks_boundaries = generate_chunks_boundaries(
+         file_size=file_size,
+         max_threads=max_threads_per_download,
+         min_chunk_size=MIN_THREAD_CHUNK_SIZE,
+     )
+     pre_allocate_file(
+         path=target_path, file_size=file_size, on_file_created=on_file_created
+     )
+     futures = []
+     max_workers = min(len(chunks_boundaries), max_threads_per_download)
+     with ThreadPoolExecutor(max_workers=max_workers) as executor:
+         for start, end in chunks_boundaries:
+             future = executor.submit(
+                 download_chunk,
+                 url=url,
+                 start=start,
+                 end=end,
+                 target_path=target_path,
+                 timeout=request_timeout,
+                 response_codes_to_retry=response_codes_to_retry,
+                 on_chunk_downloaded=on_chunk_downloaded,
+             )
+             futures.append(future)
+         done_futures, pending_futures = wait(futures, return_when=FIRST_EXCEPTION)
+         for pending_future in pending_futures:
+             pending_future.cancel()
+         _ = wait(pending_futures)
+         for future in done_futures:
+             future_exception = future.exception()
+             if future_exception:
+                 raise future_exception
+     if not verify_hash_while_download:
+         return None
+     if on_hash_calculation_started:
+         on_hash_calculation_started()
+     verify_hash_sum_of_local_file(
+         url=url,
+         file_path=target_path,
+         expected_md5_hash=md5_hash,
+         on_hash_chunk_calculated=on_hash_chunk_calculated,
+     )
+
+
+ def verify_hash_sum_of_local_file(
+     url: str,
+     file_path: str,
+     expected_md5_hash: MD5Hash,
+     on_hash_chunk_calculated: Optional[Callable[[int], None]] = None,
+ ) -> None:
+     computed_hash = hashlib.md5()
+     for file_chunk in stream_file_bytes(
+         path=file_path, chunk_size=MIN_THREAD_CHUNK_SIZE
+     ):
+         computed_hash.update(file_chunk)
+         if on_hash_chunk_calculated:
+             on_hash_chunk_calculated(len(file_chunk))
+     if computed_hash.hexdigest() != expected_md5_hash:
+         raise FileHashSumMissmatch(
+             f"Could not confirm the validity of file content for url: {url}. "
+             f"Expected MD5: {expected_md5_hash}, calculated hash: {computed_hash.hexdigest()}",
+             help_url="https://todo",
+         )
+
+
+ def generate_chunks_boundaries(
+     file_size: int,
+     max_threads: int,
+     min_chunk_size: int,
+ ) -> List[Tuple[int, int]]:
+     if file_size <= 0:
+         return []
+     chunk_size = math.ceil(file_size / max_threads)
+     if chunk_size < min_chunk_size:
+         chunk_size = min_chunk_size
+     ranges = []
+     accumulated_size = 0
+     while accumulated_size < file_size:
+         ranges.append((accumulated_size, accumulated_size + chunk_size - 1))
+         accumulated_size += chunk_size
+     ranges[-1] = (ranges[-1][0], file_size - 1)
+     return ranges
+
+
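A quick worked example of the boundary math above (deterministic, no assumptions beyond the function itself): a 100 MB file split across 8 threads gives a raw chunk of ceil(104857600 / 8) = 13107200 bytes, below the 16 MB floor, so the chunk size is bumped to 16777216 bytes and 7 inclusive byte ranges come out, with the final range clipped to the file end:

    >>> generate_chunks_boundaries(
    ...     file_size=100 * 1024 * 1024, max_threads=8, min_chunk_size=16 * 1024 * 1024
    ... )
    [(0, 16777215), (16777216, 33554431), (33554432, 50331647),
     (50331648, 67108863), (67108864, 83886079), (83886080, 100663295),
     (100663296, 104857599)]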
+ @backoff.on_exception(
+     backoff.constant,
+     exception=RetryError,
+     max_tries=API_CALLS_MAX_TRIES,
+     interval=1,
+ )
+ def download_chunk(
+     url: str,
+     start: int,
+     end: int,
+     target_path: str,
+     timeout: int,
+     response_codes_to_retry: Set[int],
+     file_chunk: int = DEFAULT_STREAM_DOWNLOAD_CHUNK,
+     on_chunk_downloaded: Optional[Callable[[int], None]] = None,
+ ) -> None:
+     headers = {"Range": f"bytes={start}-{end}"}
+     try:
+         with requests.get(
+             url, headers=headers, stream=True, timeout=timeout
+         ) as response:
+             if response.status_code in response_codes_to_retry:
+                 raise RetryError(
+                     message=f"File hosting returned {response.status_code}",
+                     help_url="https://todo",
+                 )
+             response.raise_for_status()
+             if response.status_code != 206:
+                 raise RetryError(
+                     message=f"Server does not support range requests (returned {response.status_code} instead of 206)",
+                     help_url="https://todo",
+                 )
+             _handle_stream_download(
+                 response=response,
+                 target_path=target_path,
+                 file_chunk=file_chunk,
+                 on_chunk_downloaded=on_chunk_downloaded,
+                 file_open_mode="r+b",
+                 offset=start,
+             )
+     except (ConnectionError, Timeout, requests.exceptions.ConnectionError):
+         raise RetryError(
+             message="Connectivity error",
+             help_url="https://todo",
+         )
+
+
+ @backoff.on_exception(
+     backoff.constant,
+     exception=RetryError,
+     max_tries=API_CALLS_MAX_TRIES,
+     interval=1,
+ )
+ def stream_download(
+     url: str,
+     target_path: str,
+     timeout: int,
+     response_codes_to_retry: Set[int],
+     md5_hash: MD5Hash,
+     verify_hash_while_download: bool,
+     file_chunk: int = DEFAULT_STREAM_DOWNLOAD_CHUNK,
+     on_chunk_downloaded: Optional[Callable[[int], None]] = None,
+     on_file_created: Optional[Callable[[str], None]] = None,
+ ) -> None:
+     ensure_parent_dir_exists(path=target_path)
+     computed_hash = (
+         HashNullObject()
+         if md5_hash is None or not verify_hash_while_download
+         else hashlib.md5()
+     )
+     try:
+         with requests.get(url, stream=True, timeout=timeout) as response:
+             if response.status_code in response_codes_to_retry:
+                 raise RetryError(
+                     message=f"File hosting returned {response.status_code}",
+                     help_url="https://todo",
+                 )
+             response.raise_for_status()
+             _handle_stream_download(
+                 response=response,
+                 target_path=target_path,
+                 file_chunk=file_chunk,
+                 on_chunk_downloaded=on_chunk_downloaded,
+                 content_storage=computed_hash,
+                 on_file_created=on_file_created,
+             )
+     except (ConnectionError, Timeout, requests.exceptions.ConnectionError):
+         raise RetryError(
+             message="Connectivity error",
+             help_url="https://todo",
+         )
+     if not verify_hash_while_download:
+         return None
+     if computed_hash.hexdigest() != md5_hash:
+         raise FileHashSumMissmatch(
+             f"Could not confirm the validity of file content for url: {url}. Expected MD5: {md5_hash}, "
+             f"calculated hash: {computed_hash.hexdigest()}",
+             help_url="https://todo",
+         )
+     return None
+
+
+ def _handle_stream_download(
+     response: Response,
+     target_path: str,
+     file_chunk: int,
+     on_chunk_downloaded: Optional[Callable[[int], None]] = None,
+     file_open_mode: str = "wb",
+     offset: Optional[int] = None,
+     content_storage: Optional[Union[hashlib.md5, HashNullObject]] = None,
+     on_file_created: Optional[Callable[[str], None]] = None,
+ ) -> None:
+     with open(target_path, file_open_mode) as file:
+         if on_file_created:
+             on_file_created(target_path)
+         if offset is not None:
+             file.seek(offset)
+         for chunk in response.iter_content(file_chunk):
+             file.write(chunk)
+             if content_storage is not None:
+                 content_storage.update(chunk)
+             if on_chunk_downloaded:
+                 on_chunk_downloaded(len(chunk))
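The chunked hashing used throughout this module leans on a property of `hashlib` worth seeing once: feeding data through `update()` in pieces yields the same digest as hashing everything in one call. A self-contained check (standard library only; the file is a throwaway temp file):

    import hashlib
    import tempfile

    with tempfile.NamedTemporaryFile(delete=False) as f:
        f.write(b"0123456789" * 1_000)

    with open(f.name, "rb") as fh:
        whole = hashlib.md5(fh.read()).hexdigest()

    chunked = hashlib.md5()
    with open(f.name, "rb") as fh:
        while block := fh.read(4096):  # same pattern as stream_file_bytes(...)
            chunked.update(block)

    assert whole == chunked.hexdigest()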
inference_models/utils/environment.py
@@ -0,0 +1,28 @@
+ from typing import Any, List
+
+ from inference_models.errors import InvalidEnvVariable
+
+
+ def parse_comma_separated_values(values: str) -> List[str]:
+     if not values:
+         return []
+     return [v.strip() for v in values.split(",") if v.strip()]
+
+
+ def str2bool(value: Any) -> bool:
+     if isinstance(value, bool):
+         return value
+     if not isinstance(value, str):
+         raise InvalidEnvVariable(
+             message=f"Expected a boolean environment variable (true or false) but got '{value}'",
+             help_url="https://todo",
+         )
+     if value.lower() == "true":
+         return True
+     elif value.lower() == "false":
+         return False
+     else:
+         raise InvalidEnvVariable(
+             message=f"Expected a boolean environment variable (true or false) but got '{value}'",
+             help_url="https://todo",
+         )
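A usage sketch for these helpers (the environment variable names are hypothetical; any variable read through `os.getenv` works the same way):

    import os

    models = parse_comma_separated_values(os.getenv("ENABLED_MODELS", "yolov8, rfdetr ,"))
    # -> ["yolov8", "rfdetr"]  (whitespace is trimmed, empty segments are dropped)

    verbose = str2bool(os.getenv("VERBOSE", "false"))
    # -> False; "TRUE" / "False" parse too, anything else raises InvalidEnvVariable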
inference_models/utils/file_system.py
@@ -0,0 +1,51 @@
+ import json
+ import os
+ from typing import Callable, Generator, Optional, Union
+
+
+ def stream_file_lines(path: str) -> Generator[str, None, None]:
+     with open(path, "r") as f:
+         for line in f:
+             stripped_line = line.strip()
+             if stripped_line:
+                 yield stripped_line
+
+
+ def stream_file_bytes(path: str, chunk_size: int) -> Generator[bytes, None, None]:
+     chunk_size = max(chunk_size, 1)
+     with open(path, "rb") as f:
+         chunk = f.read(chunk_size)
+         while chunk:
+             yield chunk
+             chunk = f.read(chunk_size)
+
+
+ def read_json(path: str) -> Optional[Union[dict, list]]:
+     with open(path) as f:
+         return json.load(f)
+
+
+ def dump_json(path: str, content: Union[dict, list]) -> None:
+     ensure_parent_dir_exists(path=path)
+     with open(path, "w") as f:
+         json.dump(content, f)
+
+
+ def pre_allocate_file(
+     path: str, file_size: int, on_file_created: Optional[Callable[[str], None]] = None
+ ) -> None:
+     ensure_parent_dir_exists(path=path)
+     with open(path, "wb") as f:
+         if on_file_created:
+             on_file_created(path)
+         f.truncate(file_size)
+
+
+ def ensure_parent_dir_exists(path: str) -> None:
+     parent_dir = os.path.dirname(os.path.abspath(path))
+     os.makedirs(parent_dir, exist_ok=True)
+
+
+ def remove_file_if_exists(path: str) -> None:
+     if os.path.isfile(path):
+         os.remove(path)
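A round-trip sketch tying `pre_allocate_file` and `stream_file_bytes` together (temp directory chosen at runtime; the 10 MB size is arbitrary):

    import os
    import tempfile

    path = os.path.join(tempfile.mkdtemp(), "weights.bin")
    pre_allocate_file(path=path, file_size=10 * 1024 * 1024)
    assert os.path.getsize(path) == 10 * 1024 * 1024  # zero-filled, sparse where the OS allows

    total = sum(len(chunk) for chunk in stream_file_bytes(path=path, chunk_size=1024 * 1024))
    assert total == 10 * 1024 * 1024

This is exactly the pattern `threaded_download_file` relies on: the file is sized up front so each worker can seek to its own byte offset and write independently.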
inference_models/utils/hashing.py
@@ -0,0 +1,7 @@
+ import hashlib
+ import json
+
+
+ def hash_dict_content(content: dict) -> str:
+     content_string = json.dumps(content, sort_keys=True)
+     return hashlib.sha256(content_string.encode()).hexdigest()
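Because `json.dumps(..., sort_keys=True)` canonicalizes key order, the digest is stable across dicts that differ only in insertion order - a one-line check:

    assert hash_dict_content({"a": 1, "b": 2}) == hash_dict_content({"b": 2, "a": 1})

The flip side: values must be JSON-serializable, so passing e.g. a set raises TypeError.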