inference-models 0.18.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,608 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import math
|
|
3
|
+
import os
|
|
4
|
+
from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait
|
|
5
|
+
from threading import Lock
|
|
6
|
+
from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union
|
|
7
|
+
from uuid import uuid4
|
|
8
|
+
|
|
9
|
+
import backoff
|
|
10
|
+
import requests
|
|
11
|
+
from filelock import FileLock
|
|
12
|
+
from requests import Response, Timeout
|
|
13
|
+
from rich.progress import (
|
|
14
|
+
BarColumn,
|
|
15
|
+
DownloadColumn,
|
|
16
|
+
Progress,
|
|
17
|
+
TimeRemainingColumn,
|
|
18
|
+
TransferSpeedColumn,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
from inference_models.configuration import (
|
|
22
|
+
API_CALLS_MAX_TRIES,
|
|
23
|
+
API_CALLS_TIMEOUT,
|
|
24
|
+
DISABLE_INTERACTIVE_PROGRESS_BARS,
|
|
25
|
+
IDEMPOTENT_API_REQUEST_CODES_TO_RETRY,
|
|
26
|
+
)
|
|
27
|
+
from inference_models.errors import (
|
|
28
|
+
FileHashSumMissmatch,
|
|
29
|
+
InvalidParameterError,
|
|
30
|
+
RetryError,
|
|
31
|
+
UntrustedFileError,
|
|
32
|
+
)
|
|
33
|
+
from inference_models.logger import LOGGER
|
|
34
|
+
from inference_models.utils.file_system import (
|
|
35
|
+
ensure_parent_dir_exists,
|
|
36
|
+
pre_allocate_file,
|
|
37
|
+
remove_file_if_exists,
|
|
38
|
+
stream_file_bytes,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
# Type aliases for the (file_handle, download_url, md5_hash) triples handled
# by the download helpers in this module.
FileHandle = str
DownloadUrl = str
MD5Hash = Optional[str]

# Files with a known size below this threshold (bytes) are always fetched
# with a single streamed request instead of parallel range requests.
MIN_SIZE_FOR_THREADED_DOWNLOAD = 32 * 2**20  # 32MB
# Lower bound for the byte range assigned to each download thread; also used
# as the read size when hashing a local file.
MIN_THREAD_CHUNK_SIZE = 16 * 2**20  # 16MB
# Default `file_chunk` size for `download_chunk`.
DEFAULT_STREAM_DOWNLOAD_CHUNK = 2**20  # 1MB
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class HashNullObject:
    """Stand-in for a ``hashlib`` hash object that computes nothing.

    ``update`` accepts any arguments and discards them; ``hexdigest`` always
    reports ``None``. This lets callers keep the same update/hexdigest calls
    when no real hash should be computed.
    """

    def update(self, *args, **kwargs) -> None:
        # Intentionally ignore whatever data is fed in.
        return None

    def hexdigest(self) -> None:
        # No digest exists for a null hash.
        return
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def download_files_to_directory(
    target_dir: str,
    files_specs: List[Tuple[FileHandle, DownloadUrl, MD5Hash]],
    verbose: bool = True,
    response_codes_to_retry: Optional[Set[int]] = None,
    request_timeout: Optional[int] = None,
    max_parallel_downloads: int = 8,
    max_threads_per_download: int = 8,
    file_lock_acquire_timeout: int = 10,
    verify_hash_while_download: bool = True,
    download_files_without_hash: bool = False,
    name_after: Literal["file_handle", "md5_hash"] = "file_handle",
    on_file_created: Optional[Callable[[str], None]] = None,
    on_file_renamed: Optional[Callable[[str, str], None]] = None,
) -> Dict[str, str]:
    """Download every file described by ``files_specs`` into ``target_dir``.

    Files whose destination path already exists as a regular file are skipped.
    The remaining ones are downloaded concurrently, one ``safe_download_file``
    task per file, while a shared ``rich`` progress display is shown.

    Args:
        target_dir: Directory the files are stored in (created if missing).
        files_specs: ``(file_handle, download_url, md5_hash)`` triples.
        verbose: Show progress bars. Forced off when
            ``DISABLE_INTERACTIVE_PROGRESS_BARS`` is set.
        response_codes_to_retry: HTTP status codes that trigger a retry.
            Defaults to ``IDEMPOTENT_API_REQUEST_CODES_TO_RETRY``.
        request_timeout: Per-request timeout. Defaults to ``API_CALLS_TIMEOUT``.
        max_parallel_downloads: Number of files downloaded at the same time.
        max_threads_per_download: Upper bound on range-request threads per file.
        file_lock_acquire_timeout: Timeout for acquiring a per-file lock.
        verify_hash_while_download: Whether downloaded content is checked
            against the declared MD5 hash.
        download_files_without_hash: Allow downloading files with no declared hash.
        name_after: Name local files after their handle or after their MD5 hash.
        on_file_created: Callback forwarded to each download task.
        on_file_renamed: Callback forwarded to each download task.

    Returns:
        Mapping of every file handle in ``files_specs`` (downloaded now or
        already present) to its local path.

    Raises:
        InvalidParameterError: ``name_after`` has an unsupported value.
        UntrustedFileError: A file that still needs downloading has no hash
            and ``download_files_without_hash`` is False.
        Exception: An error raised by one of the download tasks is re-raised
            after the not-yet-started tasks are cancelled.
    """
    supported_naming_schemes = {"file_handle", "md5_hash"}
    if name_after not in supported_naming_schemes:
        raise InvalidParameterError(
            message="Function download_files_to_directory(...) was called with "
            f"invalid value of parameter `name_after` - received value `{name_after}`. "
            f"This is a bug in `inference-models` - submit new issue under "
            f"https://github.com/roboflow/inference/issues/",
            help_url="https://todo",
        )
    # A global switch can silence progress bars regardless of the caller's wish.
    show_progress = verbose and not DISABLE_INTERACTIVE_PROGRESS_BARS
    files_mapping = construct_files_path_mapping(
        target_dir=target_dir,
        files_specs=files_specs,
        name_after=name_after,
    )
    pending_specs = exclude_existing_files(
        files_specs=files_specs,
        files_mapping=files_mapping,
    )
    if not pending_specs:
        return files_mapping
    retry_codes = (
        IDEMPOTENT_API_REQUEST_CODES_TO_RETRY
        if response_codes_to_retry is None
        else response_codes_to_retry
    )
    timeout = API_CALLS_TIMEOUT if request_timeout is None else request_timeout
    if not download_files_without_hash:
        # Only files that actually need downloading are checked for a hash.
        untrusted_files = [spec[1] for spec in pending_specs if spec[2] is None]
        if untrusted_files:
            raise UntrustedFileError(
                message=f"While downloading files detected {len(untrusted_files)} untrusted file(s): {untrusted_files} "
                f"without MD5 hash sum to verify the download content. The download method was used with "
                f"`download_files_without_hash=False` - which prevents from downloading such files. If you see "
                f"this error while using hosted Roboflow serving option - contact us to get support.",
                help_url="https://todo",
            )
    os.makedirs(target_dir, exist_ok=True)
    progress = Progress(
        "[progress.description]{task.description}",
        BarColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
        disable=not show_progress,
    )
    # Shared suffix for the temporary files created by this batch.
    download_id = str(uuid4())
    with progress, ThreadPoolExecutor(max_workers=max_parallel_downloads) as executor:
        futures = [
            executor.submit(
                safe_download_file,
                target_file_path=files_mapping[file_handle],
                download_url=download_url,
                md5_hash=md5_hash,
                verify_hash_while_download=verify_hash_while_download,
                download_id=download_id,
                progress=progress,
                response_codes_to_retry=retry_codes,
                request_timeout=timeout,
                max_threads_per_download=max_threads_per_download,
                file_lock_acquire_timeout=file_lock_acquire_timeout,
                on_file_created=on_file_created,
                on_file_renamed=on_file_renamed,
            )
            for file_handle, download_url, md5_hash in pending_specs
        ]
        finished, unfinished = wait(futures, return_when=FIRST_EXCEPTION)
        # Stop queued work as soon as any download fails.
        for future in unfinished:
            future.cancel()
        wait(unfinished)
        for future in finished:
            failure = future.exception()
            if failure:
                raise failure
    return files_mapping
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def construct_files_path_mapping(
    target_dir: str,
    files_specs: List[Tuple[FileHandle, DownloadUrl, MD5Hash]],
    name_after: Literal["file_handle", "md5_hash"] = "file_handle",
) -> Dict[FileHandle, str]:
    """Compute the local path each file handle will be stored under.

    Args:
        target_dir: Directory the files are placed in.
        files_specs: ``(file_handle, download_url, md5_hash)`` triples.
        name_after: ``"file_handle"`` names each local file after its handle;
            ``"md5_hash"`` names it after its declared MD5 hash.

    Returns:
        Mapping ``file_handle -> os.path.join(target_dir, <name>)``.

    Raises:
        UntrustedFileError: ``name_after="md5_hash"`` is used for a file with
            no declared hash, or the resulting path would not lie strictly
            inside ``target_dir`` (e.g. names like ``"../x"`` or absolute paths).
    """
    target_dir_abs = os.path.abspath(target_dir)
    result = {}
    for file_handle, _, content_hash in files_specs:
        if name_after == "md5_hash" and content_hash is None:
            raise UntrustedFileError(
                message="Attempted to download file without declared hash sum when "
                "`name_after='md5_hash'` - this problem is either misconfiguration "
                "of download procedure in `inference-models` or bug in the codebase. "
                "If you see this error using hosted Roboflow solution - contact us to get "
                "help. Running locally, verify the download code and raise an issue if you see "
                "a bug: https://github.com/roboflow/inference/issues/",
                help_url="https://todo",
            )
        file_name = content_hash if name_after == "md5_hash" else file_handle
        target_path = os.path.join(target_dir, file_name)
        # os.path.join happily lets "../x" or an absolute name escape target_dir,
        # so make sure every file lands inside the directory we were given.
        resolved_path = os.path.abspath(target_path)
        try:
            escapes = (
                os.path.commonpath([target_dir_abs, resolved_path]) != target_dir_abs
            )
        except ValueError:
            # commonpath raises for paths on different drives (Windows).
            escapes = True
        if escapes or resolved_path == target_dir_abs:
            raise UntrustedFileError(
                message=f"Refusing to download file `{file_handle}` - its local path "
                f"`{target_path}` would fall outside of the target directory `{target_dir}`.",
                help_url="https://todo",
            )
        result[file_handle] = target_path
    return result
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def exclude_existing_files(
    files_specs: List[Tuple[FileHandle, DownloadUrl, MD5Hash]],
    files_mapping: Dict[FileHandle, str],
) -> List[Tuple[FileHandle, DownloadUrl, MD5Hash]]:
    """Keep only the specs whose mapped local path is not an existing regular file.

    The relative order of ``files_specs`` is preserved. Every file handle must
    be a key of ``files_mapping``.
    """
    return [
        spec for spec in files_specs if not os.path.isfile(files_mapping[spec[0]])
    ]
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def safe_download_file(
    target_file_path: str,
    download_url: str,
    download_id: str,
    md5_hash: MD5Hash,
    verify_hash_while_download: bool,
    progress: Progress,
    response_codes_to_retry: Set[int],
    request_timeout: int,
    max_threads_per_download: int,
    file_lock_acquire_timeout: int,
    on_file_created: Optional[Callable[[str], None]] = None,
    on_file_renamed: Optional[Callable[[str, str], None]] = None,
) -> None:
    """Download one file to ``target_file_path`` while holding a per-file lock.

    The lock file ``.<name>.lock`` lives next to the target. Data is written to
    the temporary file ``<name>.<download_id>``, which ``safe_execute_download``
    moves onto the target once the download succeeds. The temporary file is
    removed afterwards in every case.

    If the target already exists once the lock is held (e.g. another worker
    finished the same download while we waited), nothing is downloaded.

    Raises:
        filelock.Timeout: The lock could not be acquired within
            ``file_lock_acquire_timeout`` (raised by ``filelock``).
        Any exception raised by ``safe_execute_download``.
    """
    ensure_parent_dir_exists(path=target_file_path)
    target_file_dir, target_file_name = os.path.split(target_file_path)
    lock_path = os.path.join(target_file_dir, f".{target_file_name}.lock")
    tmp_download_file = os.path.abspath(
        os.path.join(target_file_dir, f"{target_file_name}.{download_id}")
    )
    with FileLock(lock_path, timeout=file_lock_acquire_timeout):
        # Whoever held the lock before us may already have produced the file.
        if os.path.isfile(target_file_path):
            return None
        # Clean up only while holding the lock. Concurrent tasks for the same
        # target share `download_id` and thus the same tmp path, so cleanup
        # after a lock timeout could delete the lock holder's in-progress file.
        try:
            safe_execute_download(
                download_url=download_url,
                tmp_download_file=tmp_download_file,
                target_file_path=target_file_path,
                md5_hash=md5_hash,
                verify_hash_while_download=verify_hash_while_download,
                progress=progress,
                response_codes_to_retry=response_codes_to_retry,
                request_timeout=request_timeout,
                max_threads_per_download=max_threads_per_download,
                original_file_name=target_file_name,
                on_file_created=on_file_created,
                on_file_renamed=on_file_renamed,
            )
        finally:
            remove_file_if_exists(path=tmp_download_file)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def safe_execute_download(
    download_url: str,
    tmp_download_file: str,
    target_file_path: str,
    md5_hash: MD5Hash,
    verify_hash_while_download: bool,
    progress: Progress,
    response_codes_to_retry: Set[int],
    request_timeout: int,
    max_threads_per_download: int,
    original_file_name: str,
    on_file_created: Optional[Callable[[str], None]] = None,
    on_file_renamed: Optional[Callable[[str, str], None]] = None,
) -> None:
    """Download ``download_url`` into ``tmp_download_file``, then move it to ``target_file_path``.

    A HEAD probe (``safe_check_range_download_option``) picks the strategy.
    When it reports a size of at least ``MIN_SIZE_FOR_THREADED_DOWNLOAD`` and
    more than one thread is allowed, the file is fetched with parallel range
    requests. Otherwise it is streamed with ``stream_download``. Progress is
    reported on ``progress``. After the move, ``on_file_renamed`` is called
    with ``(tmp_download_file, target_file_path)``.
    """
    expected_file_size = safe_check_range_download_option(
        url=download_url,
        timeout=request_timeout,
        response_codes_to_retry=response_codes_to_retry,
    )
    download_task = progress.add_task(
        description=f"{original_file_name}: Download",
        total=expected_file_size,
        start=True,
        visible=True,
    )
    # Mutable holder so the nested callbacks can record the hash-verification
    # task id, which is only created at runtime.
    hash_calculation_task = []

    progress_task_lock = Lock()

    def on_chunk_downloaded(bytes_num: int) -> None:
        # Range-download workers may report chunks from several threads at once.
        with progress_task_lock:
            progress.advance(download_task, bytes_num)

    def on_hash_calculation_started() -> None:
        # Replace the download bar with a hash-verification bar, only once.
        if len(hash_calculation_task) > 0:
            return None
        progress.remove_task(download_task)
        new_hash_calculation_task = progress.add_task(
            description=f"{original_file_name}: Verify hash",
            total=expected_file_size,
            start=True,
            visible=True,
        )
        hash_calculation_task.append(new_hash_calculation_task)

    def on_hash_chunk_calculated(bytes_num: int) -> None:
        if len(hash_calculation_task) != 1:
            return None
        progress.advance(hash_calculation_task[0], bytes_num)

    if (
        expected_file_size is None
        or expected_file_size < MIN_SIZE_FOR_THREADED_DOWNLOAD
        or max_threads_per_download <= 1
    ):
        stream_download(
            url=download_url,
            target_path=tmp_download_file,
            timeout=request_timeout,
            md5_hash=md5_hash,
            verify_hash_while_download=verify_hash_while_download,
            response_codes_to_retry=response_codes_to_retry,
            on_chunk_downloaded=on_chunk_downloaded,
            on_file_created=on_file_created,
        )
    else:
        threaded_download_file(
            url=download_url,
            target_path=tmp_download_file,
            file_size=expected_file_size,
            response_codes_to_retry=response_codes_to_retry,
            request_timeout=request_timeout,
            md5_hash=md5_hash,
            verify_hash_while_download=verify_hash_while_download,
            max_threads_per_download=max_threads_per_download,
            on_chunk_downloaded=on_chunk_downloaded,
            on_file_created=on_file_created,
            on_hash_calculation_started=on_hash_calculation_started,
            on_hash_chunk_calculated=on_hash_chunk_calculated,
        )
    # os.replace overwrites an existing destination on every platform, whereas
    # os.rename raises FileExistsError on Windows in that case.
    os.replace(tmp_download_file, target_file_path)
    if on_file_renamed:
        on_file_renamed(tmp_download_file, target_file_path)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def safe_check_range_download_option(
    url: str, timeout: int, response_codes_to_retry: Set[int]
) -> Optional[int]:
    """Best-effort wrapper around ``check_range_download_option``.

    Returns the size reported for ``url`` when range requests look usable.
    Otherwise returns ``None``, including when the probe raises any exception,
    which is logged as a warning instead of being propagated.
    """
    file_size: Optional[int] = None
    try:
        file_size = check_range_download_option(
            url=url, timeout=timeout, response_codes_to_retry=response_codes_to_retry
        )
    except Exception:
        LOGGER.warning(f"Cannot use range requests for {url}")
    return file_size
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
@backoff.on_exception(
    backoff.constant,
    exception=RetryError,
    max_tries=API_CALLS_MAX_TRIES,
    interval=1,
)
def check_range_download_option(
    url: str, timeout: int, response_codes_to_retry: Set[int]
) -> Optional[int]:
    """Probe ``url`` with a HEAD request to see whether ranged download is possible.

    Returns:
        ``Content-Length`` as an int when the server advertises
        ``Accept-Ranges: bytes`` and reports a non-empty length, else ``None``.

    Raises:
        RetryError: On connectivity problems or a status code listed in
            ``response_codes_to_retry``. The decorator retries these up to
            ``API_CALLS_MAX_TRIES`` times, 1 second apart.
        requests.HTTPError: For other error status codes.
        ValueError: If ``Content-Length`` is not an integer.
    """
    try:
        # requests.head() does not follow redirects unless asked to. Without
        # this, a redirecting URL would be judged by the 3xx response's headers
        # instead of the file's.
        response = requests.head(url, timeout=timeout, allow_redirects=True)
    except (OSError, Timeout, requests.exceptions.ConnectionError) as error:
        raise RetryError(
            message=f"Connectivity error for URL: {url}", help_url="https://todo"
        ) from error
    if response.status_code in response_codes_to_retry:
        raise RetryError(
            message=f"Remote server returned response code {response.status_code} for URL {url}",
            help_url="https://todo",
        )
    response.raise_for_status()
    accept_ranges = response.headers.get("accept-ranges", "none")
    content_length = response.headers.get("content-length")
    if "bytes" not in accept_ranges.lower():
        return None
    if not content_length:
        return None
    return int(content_length)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
@backoff.on_exception(
    backoff.constant,
    exception=RetryError,
    max_tries=API_CALLS_MAX_TRIES,
    interval=1,
)
def get_content_length(
    url: str,
    timeout: Optional[int] = None,
    response_codes_to_retry: Optional[Set[int]] = None,
) -> Optional[int]:
    """Return the ``Content-Length`` reported by a HEAD request to ``url``.

    Args:
        url: Resource to probe.
        timeout: Request timeout. Defaults to ``API_CALLS_TIMEOUT``.
        response_codes_to_retry: Status codes that trigger a retry. Defaults
            to ``IDEMPOTENT_API_REQUEST_CODES_TO_RETRY``.

    Returns:
        The header value as an int, or ``None`` when it is absent or empty.

    Raises:
        RetryError: On connectivity problems or a retryable status code. The
            decorator retries these up to ``API_CALLS_MAX_TRIES`` times,
            1 second apart.
        requests.HTTPError: For other error status codes.
    """
    if response_codes_to_retry is None:
        response_codes_to_retry = IDEMPOTENT_API_REQUEST_CODES_TO_RETRY
    if timeout is None:
        timeout = API_CALLS_TIMEOUT
    try:
        # NOTE(review): requests.head() does not follow redirects by default,
        # so for a redirecting URL this reports the 3xx response's length.
        # Confirm whether callers ever pass such URLs before enabling redirects.
        response = requests.head(url, timeout=timeout)
    except (OSError, Timeout, requests.exceptions.ConnectionError) as error:
        raise RetryError(
            message=f"Connectivity error for URL: {url}", help_url="https://todo"
        ) from error
    if response.status_code in response_codes_to_retry:
        raise RetryError(
            message=f"Remote server returned response code {response.status_code} for URL {url}",
            help_url="https://todo",
        )
    response.raise_for_status()
    content_length = response.headers.get("content-length")
    # Treat an empty header like a missing one instead of failing in int("").
    if not content_length:
        return None
    return int(content_length)
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
def threaded_download_file(
    url: str,
    target_path: str,
    file_size: int,
    response_codes_to_retry: Set[int],
    request_timeout: int,
    max_threads_per_download: int,
    md5_hash: MD5Hash,
    verify_hash_while_download: bool,
    on_chunk_downloaded: Optional[Callable[[int], None]] = None,
    on_file_created: Optional[Callable[[str], None]] = None,
    on_hash_calculation_started: Optional[Callable[[], None]] = None,
    on_hash_chunk_calculated: Optional[Callable[[int], None]] = None,
) -> None:
    """Download ``url`` into ``target_path`` using parallel HTTP range requests.

    The file is split with ``generate_chunks_boundaries``, ``pre_allocate_file``
    is called for the full ``file_size`` (presumably so each chunk can be
    written at its own offset — TODO confirm), and each range is fetched by
    ``download_chunk`` in a thread pool. The first failing chunk cancels the
    not-yet-started ones and its exception is re-raised.

    When ``verify_hash_while_download`` is set and an expected ``md5_hash`` is
    known, the finished file is hashed and compared to it.

    Raises:
        FileHashSumMissmatch: The downloaded content does not match ``md5_hash``.
        Any exception raised by a chunk download.
    """
    chunks_boundaries = generate_chunks_boundaries(
        file_size=file_size,
        max_threads=max_threads_per_download,
        min_chunk_size=MIN_THREAD_CHUNK_SIZE,
    )
    pre_allocate_file(
        path=target_path, file_size=file_size, on_file_created=on_file_created
    )
    # ThreadPoolExecutor rejects max_workers <= 0, e.g. for an empty chunk list.
    max_workers = max(1, min(len(chunks_boundaries), max_threads_per_download))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                download_chunk,
                url=url,
                start=start,
                end=end,
                target_path=target_path,
                timeout=request_timeout,
                response_codes_to_retry=response_codes_to_retry,
                on_chunk_downloaded=on_chunk_downloaded,
            )
            for start, end in chunks_boundaries
        ]
        done_futures, pending_futures = wait(futures, return_when=FIRST_EXCEPTION)
        for pending_future in pending_futures:
            pending_future.cancel()
        _ = wait(pending_futures)
        for future in done_futures:
            future_exception = future.exception()
            if future_exception:
                raise future_exception
    # A computed hex digest never equals None, so verifying without an expected
    # hash (allowed when download_files_without_hash=True) would always fail.
    if not verify_hash_while_download or md5_hash is None:
        return None
    if on_hash_calculation_started:
        on_hash_calculation_started()
    verify_hash_sum_of_local_file(
        url=url,
        file_path=target_path,
        expected_md5_hash=md5_hash,
        on_hash_chunk_calculated=on_hash_chunk_calculated,
    )
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def verify_hash_sum_of_local_file(
    url: str,
    file_path: str,
    expected_md5_hash: MD5Hash,
    on_hash_chunk_calculated: Optional[Callable[[int], None]] = None,
) -> None:
    """Check that the MD5 of the file at ``file_path`` equals ``expected_md5_hash``.

    The file is read in ``MIN_THREAD_CHUNK_SIZE`` pieces. After each piece,
    ``on_hash_chunk_calculated`` (if given) receives that piece's length.
    ``url`` is used only in the error message.

    Raises:
        FileHashSumMissmatch: The computed hex digest differs from the expected one.
    """
    md5 = hashlib.md5()
    for chunk in stream_file_bytes(path=file_path, chunk_size=MIN_THREAD_CHUNK_SIZE):
        md5.update(chunk)
        if on_hash_chunk_calculated:
            on_hash_chunk_calculated(len(chunk))
    actual_md5_hash = md5.hexdigest()
    if actual_md5_hash == expected_md5_hash:
        return None
    raise FileHashSumMissmatch(
        f"Could not confirm the validity of file content for url: {url}. "
        f"Expected MD5: {expected_md5_hash}, calculated hash: {actual_md5_hash}",
        help_url="https://todo",
    )
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def generate_chunks_boundaries(
    file_size: int,
    max_threads: int,
    min_chunk_size: int,
) -> List[Tuple[int, int]]:
    """Split ``file_size`` bytes into inclusive ``(start, end)`` byte ranges.

    The file is divided into roughly ``max_threads`` equal chunks. A chunk is
    never smaller than ``min_chunk_size``; when that minimum wins, fewer
    chunks are produced. The last range is shortened so it ends exactly at
    ``file_size - 1``. Both bounds are inclusive, matching the HTTP ``Range``
    header format ``bytes=start-end``.

    Args:
        file_size: Total number of bytes to split. Non-positive values yield
            no ranges.
        max_threads: Desired maximum number of chunks. Values below 1 are
            treated as 1.
        min_chunk_size: Lower bound for the chunk length. Values below 1 are
            treated as 1.

    Returns:
        Contiguous, non-overlapping inclusive ranges covering
        ``[0, file_size - 1]``.
    """
    if file_size <= 0:
        return []
    # Guard against a zero/negative thread count. Otherwise it either
    # divides by zero or, with a non-positive min_chunk_size, gives a
    # non-positive step that never advances through the file.
    max_threads = max(max_threads, 1)
    # Integer ceiling division stays exact for arbitrarily large sizes,
    # unlike math.ceil(file_size / max_threads), which goes through float.
    chunk_size = max(-(-file_size // max_threads), min_chunk_size, 1)
    return [
        (chunk_start, min(chunk_start + chunk_size, file_size) - 1)
        for chunk_start in range(0, file_size, chunk_size)
    ]
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
@backoff.on_exception(
    backoff.constant,
    exception=RetryError,
    max_tries=API_CALLS_MAX_TRIES,
    interval=1,
)
def download_chunk(
    url: str,
    start: int,
    end: int,
    target_path: str,
    timeout: int,
    response_codes_to_retry: Set[int],
    file_chunk: int = DEFAULT_STREAM_DOWNLOAD_CHUNK,
    on_chunk_downloaded: Optional[Callable[[int], None]] = None,
) -> None:
    """Download the inclusive byte range ``[start, end]`` of ``url`` into ``target_path``.

    The bytes are written in place at offset ``start``. The target file is
    opened in ``"r+b"`` mode, so it must already exist. The whole call is
    retried by ``backoff`` (constant 1 s pause, up to ``API_CALLS_MAX_TRIES``
    attempts) whenever a ``RetryError`` is raised.

    Args:
        url: File URL to request with a ``Range`` header.
        start: First byte offset (inclusive).
        end: Last byte offset (inclusive).
        target_path: Existing local file the range is written into.
        timeout: Request timeout passed to ``requests.get``.
        response_codes_to_retry: HTTP status codes that trigger a retry.
        file_chunk: Chunk size used when iterating the response body.
        on_chunk_downloaded: Optional progress callback receiving the size of
            every chunk written.

    Raises:
        RetryError: On a retryable status code, a non-206 response, or a
            connectivity problem, once retries are exhausted.
        requests.HTTPError: From ``raise_for_status`` for other error codes.
    """
    headers = {"Range": f"bytes={start}-{end}"}
    try:
        with requests.get(
            url, headers=headers, stream=True, timeout=timeout
        ) as response:
            if response.status_code in response_codes_to_retry:
                raise RetryError(
                    message=f"File hosting returned {response.status_code}",
                    help_url="https://todo",
                )
            response.raise_for_status()
            if response.status_code != 206:
                # Anything but 206 Partial Content means the Range header was not
                # honoured (e.g. 200 with the full body). Writing that at `start`
                # would corrupt the target file.
                raise RetryError(
                    message=f"Server does not support range requests (returned {response.status_code} instead of 206)",
                    help_url="https://todo",
                )
            # NOTE(review): bytes already reported through on_chunk_downloaded are
            # not rolled back when a later failure triggers a retry, so progress
            # may overcount. Confirm whether callers compensate.
            _handle_stream_download(
                response=response,
                target_path=target_path,
                file_chunk=file_chunk,
                on_chunk_downloaded=on_chunk_downloaded,
                file_open_mode="r+b",
                offset=start,
            )
    except (
        ConnectionError,
        Timeout,
        requests.exceptions.ConnectionError,
        # A connection dropped while the body is streamed by iter_content
        # surfaces as ChunkedEncodingError. It is a connectivity failure too,
        # and re-writing the same range on retry is safe.
        requests.exceptions.ChunkedEncodingError,
    ) as error:
        raise RetryError(
            message="Connectivity error",
            help_url="https://todo",
        ) from error
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
@backoff.on_exception(
    backoff.constant,
    exception=RetryError,
    max_tries=API_CALLS_MAX_TRIES,
    interval=1,
)
def stream_download(
    url: str,
    target_path: str,
    timeout: int,
    response_codes_to_retry: Set[int],
    md5_hash: MD5Hash,
    verify_hash_while_download: bool,
    file_chunk: int = DEFAULT_STREAM_DOWNLOAD_CHUNK,
    on_chunk_downloaded: Optional[Callable[[int], None]] = None,
    on_file_created: Optional[Callable[[str], None]] = None,
) -> None:
    """Download ``url`` into ``target_path`` with one streamed GET request.

    Parent directories of ``target_path`` are created first. When
    ``verify_hash_while_download`` is set, the MD5 digest is computed on the
    fly while the body is written and compared with ``md5_hash`` at the end.
    The whole call is retried by ``backoff`` (constant 1 s pause, up to
    ``API_CALLS_MAX_TRIES`` attempts) whenever a ``RetryError`` is raised.

    Args:
        url: File URL to download.
        target_path: Destination path (created or truncated).
        timeout: Request timeout passed to ``requests.get``.
        response_codes_to_retry: HTTP status codes that trigger a retry.
        md5_hash: Expected hex MD5 digest of the content.
        verify_hash_while_download: Whether to compute and check the digest.
        file_chunk: Chunk size used when iterating the response body.
        on_chunk_downloaded: Optional progress callback receiving the size of
            every chunk written.
        on_file_created: Optional callback invoked with ``target_path`` once
            the file has been opened for writing.

    Raises:
        RetryError: On a retryable status code or connectivity problem, once
            retries are exhausted.
        requests.HTTPError: From ``raise_for_status`` for other error codes.
        FileHashSumMissmatch: If verification is enabled and the digests differ.
    """
    ensure_parent_dir_exists(path=target_path)
    # Only pay for MD5 when the result is actually checked below. The flag
    # is a bool, so it must be tested for falsiness, not compared with None.
    computed_hash = (
        HashNullObject()
        if md5_hash is None or not verify_hash_while_download
        else hashlib.md5()
    )
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code in response_codes_to_retry:
                raise RetryError(
                    message=f"File hosting returned {response.status_code}",
                    help_url="https://todo",
                )
            response.raise_for_status()
            # NOTE(review): bytes already reported through on_chunk_downloaded are
            # not rolled back when a later failure triggers a retry, so progress
            # may overcount. Confirm whether callers compensate.
            _handle_stream_download(
                response=response,
                target_path=target_path,
                file_chunk=file_chunk,
                on_chunk_downloaded=on_chunk_downloaded,
                content_storage=computed_hash,
                on_file_created=on_file_created,
            )
    except (
        ConnectionError,
        Timeout,
        requests.exceptions.ConnectionError,
        # A connection dropped while the body is streamed by iter_content
        # surfaces as ChunkedEncodingError. A retry re-creates the file from scratch.
        requests.exceptions.ChunkedEncodingError,
    ) as error:
        raise RetryError(
            message="Connectivity error",
            help_url="https://todo",
        ) from error
    if not verify_hash_while_download:
        return None
    # NOTE(review): with verification on but md5_hash None, this compares
    # HashNullObject.hexdigest() against None. Confirm that is intended to pass.
    calculated_hash = computed_hash.hexdigest()
    if calculated_hash != md5_hash:
        raise FileHashSumMissmatch(
            f"Could not confirm the validity of file content for url: {url}. Expected MD5: {md5_hash}, "
            f"calculated hash: {calculated_hash}",
            help_url="https://todo",
        )
    return None
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
def _handle_stream_download(
    response: Response,
    target_path: str,
    file_chunk: int,
    on_chunk_downloaded: Optional[Callable[[int], None]] = None,
    file_open_mode: str = "wb",
    offset: Optional[int] = None,
    content_storage: Optional[Union[hashlib.md5, HashNullObject]] = None,
    on_file_created: Optional[Callable[[str], None]] = None,
) -> None:
    """Write the body of a streamed HTTP response into a local file.

    Args:
        response: Streaming response whose ``iter_content`` supplies the data.
        target_path: File to open with ``file_open_mode``.
        file_chunk: Chunk size passed to ``response.iter_content``.
        on_chunk_downloaded: Optional callback receiving each chunk's length
            after it has been written.
        file_open_mode: Mode for ``open`` (``"wb"`` truncates; ``"r+b"``
            writes into an existing file).
        offset: Position to seek to before writing. A falsy value (None or
            0) leaves the file position where ``open`` put it.
        content_storage: Optional object whose ``update`` receives every
            chunk (e.g. a hash accumulator).
        on_file_created: Optional callback invoked with ``target_path`` right
            after the file has been opened.
    """
    hash_sink = content_storage
    with open(target_path, file_open_mode) as destination:
        if on_file_created:
            on_file_created(target_path)
        if offset:
            destination.seek(offset)
        for payload in response.iter_content(file_chunk):
            destination.write(payload)
            if hash_sink is not None:
                hash_sink.update(payload)
            if on_chunk_downloaded:
                on_chunk_downloaded(len(payload))
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from typing import Any, List
|
|
2
|
+
|
|
3
|
+
from inference_models.errors import InvalidEnvVariable
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def parse_comma_separated_values(values: str) -> List[str]:
    """Split a comma-separated string into its trimmed, non-empty items.

    Falsy input (``""`` or ``None``) yields an empty list. Items that are
    blank after whitespace stripping are dropped.
    """
    if not values:
        return []
    trimmed_items = (item.strip() for item in values.split(","))
    return list(filter(None, trimmed_items))
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def str2bool(value: Any) -> bool:
    """Interpret an environment-variable value as a boolean.

    ``bool`` values are returned unchanged. Strings are compared
    case-insensitively against ``"true"`` and ``"false"``.

    Raises:
        InvalidEnvVariable: For any other string or any non-string value.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        normalized = value.lower()
        if normalized in ("true", "false"):
            return normalized == "true"
    raise InvalidEnvVariable(
        message=f"Expected a boolean environment variable (true or false) but got '{value}'",
        help_url="https://todo",
    )
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
from typing import Callable, Generator, Optional, Union
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def stream_file_lines(path: str) -> Generator[str, None, None]:
    """Lazily yield the non-blank lines of a text file, stripped of whitespace.

    The file object is iterated line by line instead of being loaded with
    ``readlines()``. Memory use is therefore bounded by the longest line, not
    the whole file, which is the point of a streaming generator.

    Args:
        path: Path of the text file. It is opened in text mode with the
            platform's default encoding.

    Yields:
        Each line with leading/trailing whitespace removed. Lines that are
        empty after stripping are skipped.
    """
    with open(path, "r") as f:
        for line in f:
            stripped_line = line.strip()
            if stripped_line:
                yield stripped_line
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def stream_file_bytes(path: str, chunk_size: int) -> Generator[bytes, None, None]:
    """Lazily yield the binary content of ``path`` in pieces of at most ``chunk_size`` bytes.

    A ``chunk_size`` below 1 is clamped to 1 so the reads always make
    progress. Nothing is yielded for an empty file.
    """
    read_size = chunk_size if chunk_size > 1 else 1
    with open(path, "rb") as source:
        while True:
            block = source.read(read_size)
            if not block:
                return
            yield block
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def read_json(path: str) -> Optional[Union[dict, list]]:
    """Load and return the JSON document stored at ``path``.

    The file is read as bytes and ``json.load`` detects its UTF-8 / UTF-16 /
    UTF-32 encoding, including a UTF-8 byte-order mark. Opening it in text
    mode without an explicit encoding would instead decode with the locale
    encoding. That garbles or rejects non-ASCII JSON on systems whose locale
    is not UTF-8.

    Returns:
        The parsed document: typically a dict or list, though any top-level
        JSON value is returned as-is.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the content is not valid JSON or cannot be decoded
            (``json.JSONDecodeError`` / ``UnicodeDecodeError``).
    """
    with open(path, "rb") as f:
        return json.load(f)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def dump_json(path: str, content: Union[dict, list]) -> None:
    """Serialize ``content`` as JSON and write it to ``path``.

    The payload is fully serialized before the target is opened. Content
    that is not JSON-serializable therefore raises without truncating or
    partially overwriting an existing file. With ``json.dump`` straight into
    the opened file, the file would be left holding a half-written,
    unparseable prefix. Missing parent directories are created.

    Raises:
        TypeError / ValueError: If ``content`` cannot be serialized to JSON.
        OSError: If the file cannot be written.
    """
    serialized = json.dumps(content)
    ensure_parent_dir_exists(path=path)
    with open(path, "w") as f:
        f.write(serialized)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def pre_allocate_file(
    path: str, file_size: int, on_file_created: Optional[Callable[[str], None]] = None
) -> None:
    """Create (or truncate) ``path`` and resize it to exactly ``file_size`` bytes.

    Missing parent directories are created first. ``on_file_created``, when
    given, is called with ``path`` right after the file has been opened and
    before it is resized.
    """
    ensure_parent_dir_exists(path=path)
    notify_created = on_file_created
    with open(path, "wb") as handle:
        if notify_created:
            notify_created(path)
        handle.truncate(file_size)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def ensure_parent_dir_exists(path: str) -> None:
    """Create every missing directory above ``path``.

    ``path`` is resolved to an absolute path first. Already-existing
    directories are not an error. ``path`` itself is not created.
    """
    absolute_path = os.path.abspath(path)
    directory = os.path.dirname(absolute_path)
    os.makedirs(directory, exist_ok=True)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def remove_file_if_exists(path: str) -> None:
    """Delete ``path`` if it is a regular file; otherwise do nothing.

    Missing paths and directories are left untouched. If the file disappears
    between the check and the deletion (e.g. if another thread or process
    removes it concurrently), that counts as success instead of surfacing
    ``FileNotFoundError``.
    """
    if not os.path.isfile(path):
        return None
    try:
        os.remove(path)
    except FileNotFoundError:
        # Lost a race with a concurrent remover; the desired end state holds.
        pass
    return None
|