truefoundry 0.3.4rc1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truefoundry might be problematic. Click here for more details.
- truefoundry/__init__.py +2 -0
- truefoundry/autodeploy/agents/developer.py +1 -1
- truefoundry/autodeploy/agents/project_identifier.py +2 -2
- truefoundry/autodeploy/agents/tester.py +1 -1
- truefoundry/autodeploy/cli.py +1 -1
- truefoundry/autodeploy/tools/list_files.py +1 -1
- truefoundry/cli/__main__.py +3 -17
- truefoundry/common/__init__.py +0 -0
- truefoundry/{deploy/lib/auth → common}/auth_service_client.py +50 -40
- truefoundry/common/constants.py +12 -0
- truefoundry/{deploy/lib/auth → common}/credential_file_manager.py +7 -7
- truefoundry/{deploy/lib/auth → common}/credential_provider.py +10 -23
- truefoundry/common/entities.py +124 -0
- truefoundry/common/exceptions.py +12 -0
- truefoundry/common/request_utils.py +84 -0
- truefoundry/common/servicefoundry_client.py +91 -0
- truefoundry/common/utils.py +56 -0
- truefoundry/deploy/auto_gen/models.py +4 -6
- truefoundry/deploy/cli/cli.py +3 -1
- truefoundry/deploy/cli/commands/apply_command.py +1 -1
- truefoundry/deploy/cli/commands/build_command.py +1 -1
- truefoundry/deploy/cli/commands/deploy_command.py +1 -1
- truefoundry/deploy/cli/commands/login_command.py +2 -2
- truefoundry/deploy/cli/commands/patch_application_command.py +1 -1
- truefoundry/deploy/cli/commands/patch_command.py +1 -1
- truefoundry/deploy/cli/commands/terminate_comand.py +1 -1
- truefoundry/deploy/cli/util.py +1 -1
- truefoundry/deploy/function_service/remote/remote.py +1 -1
- truefoundry/deploy/lib/auth/servicefoundry_session.py +2 -2
- truefoundry/deploy/lib/clients/servicefoundry_client.py +120 -159
- truefoundry/deploy/lib/const.py +1 -35
- truefoundry/deploy/lib/exceptions.py +0 -16
- truefoundry/deploy/lib/model/entity.py +1 -112
- truefoundry/deploy/lib/session.py +14 -42
- truefoundry/deploy/lib/util.py +0 -37
- truefoundry/{python_deploy_codegen.py → deploy/python_deploy_codegen.py} +2 -2
- truefoundry/deploy/v2/lib/deploy.py +3 -3
- truefoundry/deploy/v2/lib/deployable_patched_models.py +1 -1
- truefoundry/langchain/truefoundry_chat.py +1 -1
- truefoundry/langchain/truefoundry_embeddings.py +1 -1
- truefoundry/langchain/truefoundry_llm.py +1 -1
- truefoundry/langchain/utils.py +0 -41
- truefoundry/ml/__init__.py +37 -6
- truefoundry/ml/artifact/__init__.py +0 -0
- truefoundry/ml/artifact/truefoundry_artifact_repo.py +1161 -0
- truefoundry/ml/autogen/__init__.py +0 -0
- truefoundry/ml/autogen/client/__init__.py +370 -0
- truefoundry/ml/autogen/client/api/__init__.py +16 -0
- truefoundry/ml/autogen/client/api/auth_api.py +184 -0
- truefoundry/ml/autogen/client/api/deprecated_api.py +605 -0
- truefoundry/ml/autogen/client/api/experiments_api.py +1944 -0
- truefoundry/ml/autogen/client/api/health_api.py +299 -0
- truefoundry/ml/autogen/client/api/metrics_api.py +371 -0
- truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +7213 -0
- truefoundry/ml/autogen/client/api/python_deployment_config_api.py +201 -0
- truefoundry/ml/autogen/client/api/run_artifacts_api.py +231 -0
- truefoundry/ml/autogen/client/api/runs_api.py +2919 -0
- truefoundry/ml/autogen/client/api_client.py +822 -0
- truefoundry/ml/autogen/client/api_response.py +30 -0
- truefoundry/ml/autogen/client/configuration.py +489 -0
- truefoundry/ml/autogen/client/exceptions.py +161 -0
- truefoundry/ml/autogen/client/models/__init__.py +341 -0
- truefoundry/ml/autogen/client/models/add_custom_metrics_to_model_version_request_dto.py +69 -0
- truefoundry/ml/autogen/client/models/add_features_to_model_version_request_dto.py +83 -0
- truefoundry/ml/autogen/client/models/agent.py +125 -0
- truefoundry/ml/autogen/client/models/agent_app.py +118 -0
- truefoundry/ml/autogen/client/models/agent_open_api_tool.py +143 -0
- truefoundry/ml/autogen/client/models/agent_open_api_tool_with_fqn.py +144 -0
- truefoundry/ml/autogen/client/models/agent_with_fqn.py +127 -0
- truefoundry/ml/autogen/client/models/artifact_dto.py +115 -0
- truefoundry/ml/autogen/client/models/artifact_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/artifact_type.py +39 -0
- truefoundry/ml/autogen/client/models/artifact_version_dto.py +141 -0
- truefoundry/ml/autogen/client/models/artifact_version_response_dto.py +77 -0
- truefoundry/ml/autogen/client/models/artifact_version_status.py +35 -0
- truefoundry/ml/autogen/client/models/assistant_message.py +89 -0
- truefoundry/ml/autogen/client/models/authorize_user_for_model_request_dto.py +69 -0
- truefoundry/ml/autogen/client/models/authorize_user_for_model_version_request_dto.py +69 -0
- truefoundry/ml/autogen/client/models/blob_storage_reference.py +93 -0
- truefoundry/ml/autogen/client/models/body_get_search_runs_get.py +72 -0
- truefoundry/ml/autogen/client/models/chat_prompt.py +156 -0
- truefoundry/ml/autogen/client/models/chat_prompt_messages_inner.py +171 -0
- truefoundry/ml/autogen/client/models/columns_dto.py +73 -0
- truefoundry/ml/autogen/client/models/content.py +153 -0
- truefoundry/ml/autogen/client/models/content1.py +153 -0
- truefoundry/ml/autogen/client/models/content2.py +174 -0
- truefoundry/ml/autogen/client/models/content2_any_of_inner.py +150 -0
- truefoundry/ml/autogen/client/models/create_artifact_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/create_artifact_response_dto.py +65 -0
- truefoundry/ml/autogen/client/models/create_artifact_version_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/create_artifact_version_response_dto.py +65 -0
- truefoundry/ml/autogen/client/models/create_dataset_request_dto.py +76 -0
- truefoundry/ml/autogen/client/models/create_experiment_request_dto.py +94 -0
- truefoundry/ml/autogen/client/models/create_experiment_response_dto.py +67 -0
- truefoundry/ml/autogen/client/models/create_model_version_request_dto.py +95 -0
- truefoundry/ml/autogen/client/models/create_multi_part_upload_for_dataset_request_dto.py +73 -0
- truefoundry/ml/autogen/client/models/create_multi_part_upload_for_dataset_response_dto.py +79 -0
- truefoundry/ml/autogen/client/models/create_multi_part_upload_request_dto.py +73 -0
- truefoundry/ml/autogen/client/models/create_python_deployment_config_request_dto.py +72 -0
- truefoundry/ml/autogen/client/models/create_python_deployment_config_response_dto.py +67 -0
- truefoundry/ml/autogen/client/models/create_run_request_dto.py +97 -0
- truefoundry/ml/autogen/client/models/create_run_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/dataset_dto.py +112 -0
- truefoundry/ml/autogen/client/models/dataset_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/delete_artifact_versions_request_dto.py +65 -0
- truefoundry/ml/autogen/client/models/delete_dataset_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/delete_model_version_request_dto.py +65 -0
- truefoundry/ml/autogen/client/models/delete_run_request.py +65 -0
- truefoundry/ml/autogen/client/models/delete_tag_request_dto.py +68 -0
- truefoundry/ml/autogen/client/models/experiment_dto.py +127 -0
- truefoundry/ml/autogen/client/models/experiment_id_request_dto.py +67 -0
- truefoundry/ml/autogen/client/models/experiment_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/experiment_tag_dto.py +69 -0
- truefoundry/ml/autogen/client/models/feature_dto.py +68 -0
- truefoundry/ml/autogen/client/models/feature_value_type.py +35 -0
- truefoundry/ml/autogen/client/models/file_info_dto.py +76 -0
- truefoundry/ml/autogen/client/models/finalize_artifact_version_request_dto.py +101 -0
- truefoundry/ml/autogen/client/models/get_experiment_response_dto.py +88 -0
- truefoundry/ml/autogen/client/models/get_latest_run_log_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/get_metric_history_response.py +79 -0
- truefoundry/ml/autogen/client/models/get_signed_url_for_dataset_write_request_dto.py +68 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_read_request_dto.py +68 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_read_response_dto.py +81 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_write_request_dto.py +69 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_write_response_dto.py +83 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_dataset_read_request_dto.py +68 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_dataset_read_response_dto.py +81 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_dataset_write_response_dto.py +81 -0
- truefoundry/ml/autogen/client/models/get_tenant_id_response_dto.py +73 -0
- truefoundry/ml/autogen/client/models/http_validation_error.py +82 -0
- truefoundry/ml/autogen/client/models/image_content_part.py +87 -0
- truefoundry/ml/autogen/client/models/image_url.py +75 -0
- truefoundry/ml/autogen/client/models/internal_metadata.py +180 -0
- truefoundry/ml/autogen/client/models/latest_run_log_dto.py +78 -0
- truefoundry/ml/autogen/client/models/list_artifact_versions_request_dto.py +107 -0
- truefoundry/ml/autogen/client/models/list_artifact_versions_response_dto.py +87 -0
- truefoundry/ml/autogen/client/models/list_artifacts_request_dto.py +96 -0
- truefoundry/ml/autogen/client/models/list_artifacts_response_dto.py +86 -0
- truefoundry/ml/autogen/client/models/list_colums_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/list_datasets_request_dto.py +78 -0
- truefoundry/ml/autogen/client/models/list_datasets_response_dto.py +86 -0
- truefoundry/ml/autogen/client/models/list_experiments_response_dto.py +86 -0
- truefoundry/ml/autogen/client/models/list_files_for_artifact_version_request_dto.py +76 -0
- truefoundry/ml/autogen/client/models/list_files_for_artifact_versions_response_dto.py +82 -0
- truefoundry/ml/autogen/client/models/list_files_for_dataset_request_dto.py +76 -0
- truefoundry/ml/autogen/client/models/list_files_for_dataset_response_dto.py +82 -0
- truefoundry/ml/autogen/client/models/list_latest_run_logs_response_dto.py +82 -0
- truefoundry/ml/autogen/client/models/list_metric_history_request_dto.py +69 -0
- truefoundry/ml/autogen/client/models/list_metric_history_response_dto.py +84 -0
- truefoundry/ml/autogen/client/models/list_model_version_response_dto.py +87 -0
- truefoundry/ml/autogen/client/models/list_model_versions_request_dto.py +93 -0
- truefoundry/ml/autogen/client/models/list_models_request_dto.py +89 -0
- truefoundry/ml/autogen/client/models/list_models_response_dto.py +84 -0
- truefoundry/ml/autogen/client/models/list_run_artifacts_response_dto.py +84 -0
- truefoundry/ml/autogen/client/models/list_run_logs_response_dto.py +82 -0
- truefoundry/ml/autogen/client/models/list_seed_experiments_response_dto.py +81 -0
- truefoundry/ml/autogen/client/models/log_batch_request_dto.py +106 -0
- truefoundry/ml/autogen/client/models/log_metric_request_dto.py +80 -0
- truefoundry/ml/autogen/client/models/log_param_request_dto.py +76 -0
- truefoundry/ml/autogen/client/models/method.py +37 -0
- truefoundry/ml/autogen/client/models/metric_collection_dto.py +82 -0
- truefoundry/ml/autogen/client/models/metric_dto.py +76 -0
- truefoundry/ml/autogen/client/models/mime_type.py +37 -0
- truefoundry/ml/autogen/client/models/model_configuration.py +103 -0
- truefoundry/ml/autogen/client/models/model_dto.py +122 -0
- truefoundry/ml/autogen/client/models/model_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/model_schema_dto.py +85 -0
- truefoundry/ml/autogen/client/models/model_version_dto.py +170 -0
- truefoundry/ml/autogen/client/models/model_version_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/multi_part_upload_dto.py +107 -0
- truefoundry/ml/autogen/client/models/multi_part_upload_response_dto.py +79 -0
- truefoundry/ml/autogen/client/models/multi_part_upload_storage_provider.py +34 -0
- truefoundry/ml/autogen/client/models/notify_artifact_version_failure_dto.py +65 -0
- truefoundry/ml/autogen/client/models/openapi_spec.py +152 -0
- truefoundry/ml/autogen/client/models/param_dto.py +66 -0
- truefoundry/ml/autogen/client/models/parameters.py +84 -0
- truefoundry/ml/autogen/client/models/prediction_type.py +34 -0
- truefoundry/ml/autogen/client/models/resolve_agent_app_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/restore_run_request_dto.py +65 -0
- truefoundry/ml/autogen/client/models/run_data_dto.py +104 -0
- truefoundry/ml/autogen/client/models/run_dto.py +84 -0
- truefoundry/ml/autogen/client/models/run_info_dto.py +105 -0
- truefoundry/ml/autogen/client/models/run_log_dto.py +90 -0
- truefoundry/ml/autogen/client/models/run_log_input_dto.py +80 -0
- truefoundry/ml/autogen/client/models/run_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/run_tag_dto.py +66 -0
- truefoundry/ml/autogen/client/models/search_runs_request_dto.py +94 -0
- truefoundry/ml/autogen/client/models/search_runs_response_dto.py +84 -0
- truefoundry/ml/autogen/client/models/set_experiment_tag_request_dto.py +73 -0
- truefoundry/ml/autogen/client/models/set_tag_request_dto.py +76 -0
- truefoundry/ml/autogen/client/models/signed_url_dto.py +69 -0
- truefoundry/ml/autogen/client/models/stop.py +152 -0
- truefoundry/ml/autogen/client/models/store_run_logs_request_dto.py +83 -0
- truefoundry/ml/autogen/client/models/system_message.py +89 -0
- truefoundry/ml/autogen/client/models/text.py +153 -0
- truefoundry/ml/autogen/client/models/text_content_part.py +84 -0
- truefoundry/ml/autogen/client/models/update_artifact_version_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/update_dataset_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/update_experiment_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/update_model_version_request_dto.py +93 -0
- truefoundry/ml/autogen/client/models/update_run_request_dto.py +78 -0
- truefoundry/ml/autogen/client/models/update_run_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/url.py +153 -0
- truefoundry/ml/autogen/client/models/user_message.py +89 -0
- truefoundry/ml/autogen/client/models/validation_error.py +87 -0
- truefoundry/ml/autogen/client/models/validation_error_loc_inner.py +154 -0
- truefoundry/ml/autogen/client/rest.py +426 -0
- truefoundry/ml/autogen/client_README.md +320 -0
- truefoundry/ml/cli/__init__.py +0 -0
- truefoundry/ml/cli/cli.py +18 -0
- truefoundry/ml/cli/commands/__init__.py +3 -0
- truefoundry/ml/cli/commands/download.py +87 -0
- truefoundry/ml/clients/__init__.py +0 -0
- truefoundry/ml/clients/entities.py +8 -0
- truefoundry/ml/clients/servicefoundry_client.py +45 -0
- truefoundry/ml/clients/utils.py +122 -0
- truefoundry/ml/constants.py +84 -0
- truefoundry/ml/entities.py +62 -0
- truefoundry/ml/enums.py +70 -0
- truefoundry/ml/env_vars.py +9 -0
- truefoundry/ml/exceptions.py +8 -0
- truefoundry/ml/git_info.py +60 -0
- truefoundry/ml/internal_namespace.py +52 -0
- truefoundry/ml/log_types/__init__.py +4 -0
- truefoundry/ml/log_types/artifacts/artifact.py +431 -0
- truefoundry/ml/log_types/artifacts/constants.py +33 -0
- truefoundry/ml/log_types/artifacts/dataset.py +384 -0
- truefoundry/ml/log_types/artifacts/general_artifact.py +110 -0
- truefoundry/ml/log_types/artifacts/model.py +611 -0
- truefoundry/ml/log_types/artifacts/model_extras.py +48 -0
- truefoundry/ml/log_types/artifacts/utils.py +161 -0
- truefoundry/ml/log_types/image/__init__.py +3 -0
- truefoundry/ml/log_types/image/constants.py +8 -0
- truefoundry/ml/log_types/image/image.py +357 -0
- truefoundry/ml/log_types/image/image_normalizer.py +102 -0
- truefoundry/ml/log_types/image/types.py +68 -0
- truefoundry/ml/log_types/plot.py +281 -0
- truefoundry/ml/log_types/pydantic_base.py +10 -0
- truefoundry/ml/log_types/utils.py +12 -0
- truefoundry/ml/logger.py +17 -0
- truefoundry/ml/mlfoundry_api.py +1575 -0
- truefoundry/ml/mlfoundry_run.py +1203 -0
- truefoundry/ml/run_utils.py +93 -0
- truefoundry/ml/session.py +168 -0
- truefoundry/ml/validation_utils.py +346 -0
- truefoundry/pydantic_v1.py +8 -1
- truefoundry/workflow/__init__.py +16 -1
- {truefoundry-0.3.4rc1.dist-info → truefoundry-0.4.0.dist-info}/METADATA +21 -14
- truefoundry-0.4.0.dist-info/RECORD +344 -0
- truefoundry/deploy/lib/clients/utils.py +0 -41
- truefoundry-0.3.4rc1.dist-info/RECORD +0 -136
- {truefoundry-0.3.4rc1.dist-info → truefoundry-0.4.0.dist-info}/WHEEL +0 -0
- {truefoundry-0.3.4rc1.dist-info → truefoundry-0.4.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,1161 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import mmap
|
|
3
|
+
import os
|
|
4
|
+
import posixpath
|
|
5
|
+
import sys
|
|
6
|
+
import tempfile
|
|
7
|
+
import uuid
|
|
8
|
+
from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait
|
|
9
|
+
from shutil import rmtree
|
|
10
|
+
from threading import Event
|
|
11
|
+
from typing import (
|
|
12
|
+
Any,
|
|
13
|
+
Callable,
|
|
14
|
+
Dict,
|
|
15
|
+
Iterator,
|
|
16
|
+
List,
|
|
17
|
+
NamedTuple,
|
|
18
|
+
Optional,
|
|
19
|
+
Tuple,
|
|
20
|
+
Union,
|
|
21
|
+
)
|
|
22
|
+
from urllib.parse import unquote
|
|
23
|
+
from urllib.request import pathname2url
|
|
24
|
+
|
|
25
|
+
import requests
|
|
26
|
+
from rich.console import _is_jupyter
|
|
27
|
+
from rich.progress import (
|
|
28
|
+
BarColumn,
|
|
29
|
+
DownloadColumn,
|
|
30
|
+
Progress,
|
|
31
|
+
TimeElapsedColumn,
|
|
32
|
+
TimeRemainingColumn,
|
|
33
|
+
TransferSpeedColumn,
|
|
34
|
+
)
|
|
35
|
+
from tqdm.utils import CallbackIOWrapper
|
|
36
|
+
|
|
37
|
+
from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
|
|
38
|
+
ApiClient,
|
|
39
|
+
CreateMultiPartUploadForDatasetRequestDto,
|
|
40
|
+
CreateMultiPartUploadRequestDto,
|
|
41
|
+
FileInfoDto,
|
|
42
|
+
GetSignedURLForDatasetWriteRequestDto,
|
|
43
|
+
GetSignedURLsForArtifactVersionReadRequestDto,
|
|
44
|
+
GetSignedURLsForArtifactVersionWriteRequestDto,
|
|
45
|
+
GetSignedURLsForDatasetReadRequestDto,
|
|
46
|
+
ListFilesForArtifactVersionRequestDto,
|
|
47
|
+
ListFilesForArtifactVersionsResponseDto,
|
|
48
|
+
ListFilesForDatasetRequestDto,
|
|
49
|
+
ListFilesForDatasetResponseDto,
|
|
50
|
+
MlfoundryArtifactsApi,
|
|
51
|
+
MultiPartUploadDto,
|
|
52
|
+
MultiPartUploadResponseDto,
|
|
53
|
+
MultiPartUploadStorageProvider,
|
|
54
|
+
RunArtifactsApi,
|
|
55
|
+
SignedURLDto,
|
|
56
|
+
)
|
|
57
|
+
from truefoundry.ml.clients.utils import (
|
|
58
|
+
augmented_raise_for_status,
|
|
59
|
+
cloud_storage_http_request,
|
|
60
|
+
)
|
|
61
|
+
from truefoundry.ml.env_vars import DISABLE_MULTIPART_UPLOAD
|
|
62
|
+
from truefoundry.ml.exceptions import MlFoundryException
|
|
63
|
+
from truefoundry.ml.logger import logger
|
|
64
|
+
from truefoundry.ml.session import _get_api_client
|
|
65
|
+
from truefoundry.pydantic_v1 import BaseModel, root_validator
|
|
66
|
+
|
|
67
|
+
# Files smaller than this threshold are uploaded with a single signed-URL PUT;
# larger files go through the multipart upload path (unless disabled below).
_MIN_BYTES_REQUIRED_FOR_MULTIPART = 100 * 1024 * 1024
# Escape hatch: set the DISABLE_MULTIPART_UPLOAD env var to "true" to force
# single-part uploads regardless of file size.
_MULTIPART_DISABLED = os.getenv(DISABLE_MULTIPART_UPLOAD, "").lower() == "true"
# GCP/S3 Maximum number of parts per upload 10,000
# Maximum number of blocks in a block blob 50,000 blocks
# TODO: This number is artificially limited now. Later
# we will ask for parts signed URI in batches rather than in a single
# API Calls:
# Create Multipart Upload (Returns maximum number of parts, size limit of
# a single part, upload id for s3 etc )
# Get me signed uris for first 500 parts
# Upload 500 parts
# Get me signed uris for the next 500 parts
# Upload 500 parts
# ...
# Finalize the Multipart upload using the finalize signed url returned
# by Create Multipart Upload or get a new one.
_MAX_NUM_PARTS_FOR_MULTIPART = 1000
# Azure Maximum size of a block in a block blob 4000 MiB
# GCP/S3 Maximum size of an individual part in a multipart upload 5 GiB
_MAX_PART_SIZE_BYTES_FOR_MULTIPART = 4 * 1024 * 1024 * 1000
# Fall back to 2 when the CPU count cannot be determined.
_cpu_count = os.cpu_count() or 2
# Thread-pool sizes for parallel transfers: 2x CPUs, clamped to [4, 32].
_MAX_WORKERS_FOR_UPLOAD = max(min(32, _cpu_count * 2), 4)
_MAX_WORKERS_FOR_DOWNLOAD = max(min(32, _cpu_count * 2), 4)
# Page size used when listing files of an artifact version / dataset.
_LIST_FILES_PAGE_SIZE = 500
# Number of signed URLs requested per batched API call.
_GENERATE_SIGNED_URL_BATCH_SIZE = 50
# Default expiry (seconds) for presigned URLs.
DEFAULT_PRESIGNED_URL_EXPIRY_TIME = 3600
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _get_relpath_if_in_tempdir(path: str) -> str:
|
|
96
|
+
tempdir = tempfile.gettempdir()
|
|
97
|
+
if path.startswith(tempdir):
|
|
98
|
+
return os.path.relpath(path, tempdir)
|
|
99
|
+
return path
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _can_display_progress(user_choice: Optional[bool] = None) -> bool:
    """Decide whether live progress bars should be rendered.

    An explicit ``user_choice`` of False always wins. Otherwise a tty enables
    progress, and a Jupyter environment enables it only when ``ipywidgets`` is
    importable. When the user forced progress on but no capable environment was
    detected, a warning is logged and True is returned anyway.
    """
    if user_choice is False:
        return False

    if sys.stdout.isatty():
        return True

    if _is_jupyter():
        try:
            from IPython.display import display  # noqa: F401
            from ipywidgets import Output  # noqa: F401
        except ImportError:
            logger.warning(
                "Detected Jupyter Environment. Install `ipywidgets` to display live progress bars.",
            )
        else:
            return True

    if user_choice is True:
        logger.warning(
            "`progress` argument is set to True but did not detect tty "
            "or jupyter environment with ipywidgets installed. "
            "Progress bars may not be displayed. "
        )
        return True
    return False
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def relative_path_to_artifact_path(path):
    """Convert a relative local filesystem path into a POSIX-style artifact path.

    On POSIX systems the path is already in the right form. On other systems
    (e.g. Windows) separators are translated via the URL form.

    Raises:
        Exception: If *path* is absolute.
    """
    # POSIX paths need no translation.
    if os.path == posixpath:
        return path
    is_absolute = os.path.abspath(path) == path
    if is_absolute:
        raise Exception("This method only works with relative paths.")
    return unquote(pathname2url(path))
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _align_part_size_with_mmap_allocation_granularity(part_size: int) -> int:
|
|
137
|
+
modulo = part_size % mmap.ALLOCATIONGRANULARITY
|
|
138
|
+
if modulo == 0:
|
|
139
|
+
return part_size
|
|
140
|
+
|
|
141
|
+
part_size += mmap.ALLOCATIONGRANULARITY - modulo
|
|
142
|
+
return part_size
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
# Can not be less than 5 * 1024 * 1024
# Default part size (~10 MiB), rounded up so mmap offsets stay aligned to
# the platform allocation granularity.
_PART_SIZE_BYTES_FOR_MULTIPART = _align_part_size_with_mmap_allocation_granularity(
    10 * 1024 * 1024
)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def bad_path_message(name):
    """Explain why *name* is an unsafe artifact name, showing its resolution."""
    resolved = posixpath.normpath(name)
    template = (
        "Names may be treated as files in certain cases, and must not resolve to other names"
        " when treated as such. This name would resolve to '%s'"
    )
    return template % resolved
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def path_not_unique(name):
    """Return True when *name* is not a clean, unique relative path.

    Flags names that normalize to something different, the current directory,
    a parent-directory escape, or an absolute path.
    """
    normalized = posixpath.normpath(name)
    if normalized != name:
        return True
    return normalized == "." or normalized.startswith("..") or normalized.startswith("/")
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def verify_artifact_path(artifact_path):
    """Raise ``MlFoundryException`` when *artifact_path* is non-empty and unsafe.

    Empty/None paths are accepted as-is; otherwise the path must be a clean,
    unique relative path per ``path_not_unique``.
    """
    if not artifact_path:
        return
    if path_not_unique(artifact_path):
        raise MlFoundryException(
            f"Invalid artifact path: {artifact_path!r}. {bad_path_message(artifact_path)!r}"
        )
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class _PartNumberEtag(NamedTuple):
    """One uploaded part of a multipart upload, used to build the completion request."""

    # Part number as sent to the storage provider (call sites pass 1-based numbers).
    part_number: int
    # ETag returned by the storage provider for this uploaded part.
    etag: str
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _get_s3_compatible_completion_body(multi_parts: List[_PartNumberEtag]) -> str:
    """Build the XML ``CompleteMultipartUpload`` body for S3-compatible storage.

    Args:
        multi_parts: Uploaded parts (part number + ETag) in the order they
            should appear in the completion request.

    Returns:
        The XML document expected by the S3-compatible completion endpoint.
    """
    # str.join instead of repeated += keeps this linear for many parts.
    part_entries = "".join(
        "  <Part>\n"
        f"    <PartNumber>{part.part_number}</PartNumber>\n"
        f"    <ETag>{part.etag}</ETag>\n"
        "  </Part>\n"
        for part in multi_parts
    )
    return "<CompleteMultipartUpload>\n" + part_entries + "</CompleteMultipartUpload>"
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _get_azure_blob_completion_body(block_ids: List[str]) -> str:
|
|
187
|
+
body = "<BlockList>\n"
|
|
188
|
+
for block_id in block_ids:
|
|
189
|
+
body += f"<Uncommitted>{block_id}</Uncommitted> "
|
|
190
|
+
body += "</BlockList>"
|
|
191
|
+
return body
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class _FileMultiPartInfo(NamedTuple):
    """Chunking plan for uploading a single file."""

    # Number of parts the file will be split into (1 means single-part upload).
    num_parts: int
    # Size of each part in bytes (the final part may be smaller).
    part_size: int
    # Total size of the file in bytes.
    file_size: int
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _decide_file_parts(file_path: str) -> _FileMultiPartInfo:
    """Choose how many parts, and of what size, to use when uploading *file_path*.

    Small files (or any file when multipart is disabled) are uploaded in a
    single part. Otherwise the default part size is used, unless that would
    exceed the maximum part count — in which case the part size is grown
    (mmap-aligned) to fit the limit.

    Raises:
        ValueError: If even the grown part size exceeds the provider maximum.
    """
    size = os.path.getsize(file_path)
    if _MULTIPART_DISABLED or size < _MIN_BYTES_REQUIRED_FOR_MULTIPART:
        return _FileMultiPartInfo(1, part_size=size, file_size=size)

    preferred_parts = math.ceil(size / _PART_SIZE_BYTES_FOR_MULTIPART)
    if preferred_parts <= _MAX_NUM_PARTS_FOR_MULTIPART:
        return _FileMultiPartInfo(
            preferred_parts,
            part_size=_PART_SIZE_BYTES_FOR_MULTIPART,
            file_size=size,
        )

    # Too many parts at the default size: grow the part size so the part count
    # fits the limit, keeping it aligned for mmap offsets.
    grown_part_size = _align_part_size_with_mmap_allocation_granularity(
        math.ceil(size / _MAX_NUM_PARTS_FOR_MULTIPART)
    )
    if grown_part_size > _MAX_PART_SIZE_BYTES_FOR_MULTIPART:
        raise ValueError(
            f"file {file_path!r} is too big for upload. Multipart chunk"
            f" size {grown_part_size} is higher"
            f" than {_MAX_PART_SIZE_BYTES_FOR_MULTIPART}"
        )
    return _FileMultiPartInfo(
        math.ceil(size / grown_part_size), part_size=grown_part_size, file_size=size
    )
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def _signed_url_upload_file(
    signed_url: SignedURLDto,
    local_file: str,
    progress_bar: Progress,
    abort_event: Optional[Event] = None,
):
    """Upload one local file to storage via a pre-signed PUT url.

    Args:
        signed_url: Pre-signed url to PUT the file contents to.
        local_file: Path of the file to upload.
        progress_bar: Shared rich progress bar; a task is added for this file.
        abort_event: When set (e.g. by a failing sibling upload), the in-flight
            transfer is aborted by raising from the read callback.

    Raises:
        requests.HTTPError: If the storage provider returns an error response.
    """
    if os.stat(local_file).st_size == 0:
        # Empty file: a plain PUT with an empty body; no progress task needed.
        with cloud_storage_http_request(
            method="put", url=signed_url.signed_url, data=""
        ) as response:
            # Bug fix: previously this passed `response.raise_for_status()`
            # (which returns None) instead of the response object, unlike
            # every other call site in this module.
            augmented_raise_for_status(response)
        return

    task_progress_bar = progress_bar.add_task(
        f"[green]Uploading {_get_relpath_if_in_tempdir(local_file)}:", start=True
    )

    def callback(length):
        # Advance the progress bar and give in-flight reads a chance to abort.
        progress_bar.update(
            task_progress_bar, advance=length, total=os.stat(local_file).st_size
        )
        if abort_event and abort_event.is_set():
            raise Exception("aborting upload")

    with open(local_file, "rb") as file:
        # NOTE: Azure Put Blob does not support Transfer Encoding header.
        wrapped_file = CallbackIOWrapper(callback, file, "read")
        with cloud_storage_http_request(
            method="put", url=signed_url.signed_url, data=wrapped_file
        ) as response:
            augmented_raise_for_status(response)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def _download_file_using_http_uri(
    http_uri,
    download_path,
    chunk_size=100000000,
    callback: Optional[Callable[[int, int], Any]] = None,
):
    """Stream the resource at *http_uri* (a presigned cloud url) to *download_path*.

    The response body is consumed in chunks of *chunk_size* bytes so a large
    file never has to fit in memory. *callback*, when given, receives
    ``(bytes_in_chunk, total_file_size)`` for each chunk (total comes from the
    Content-Length header, 0 when absent).
    """
    with cloud_storage_http_request(
        method="get", url=http_uri, stream=True
    ) as response:
        augmented_raise_for_status(response)
        total_bytes = int(response.headers.get("Content-Length", 0))
        with open(download_path, "wb") as sink:
            for chunk in response.iter_content(chunk_size=chunk_size):
                if callback is not None:
                    callback(len(chunk), total_bytes)
                # An empty chunk marks the end of the stream.
                if not chunk:
                    break
                sink.write(chunk)
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
class _CallbackIOWrapperForMultiPartUpload(CallbackIOWrapper):
    """tqdm ``CallbackIOWrapper`` that also reports a fixed ``len()``.

    The explicit length lets HTTP clients size the request body (presumably
    for a Content-Length header — confirm against the HTTP client used) when
    uploading an mmap-backed file part.
    """

    def __init__(self, callback, stream, method, length: int):
        # wrapper_setattr stores the attribute on the wrapper itself rather
        # than proxying it to the wrapped stream.
        self.wrapper_setattr("_length", length)
        super().__init__(callback, stream, method)

    def __len__(self):
        # Fixed length of the wrapped region, set at construction time.
        return self.wrapper_getattr("_length")
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def _file_part_upload(
    url: str,
    file_path: str,
    seek: int,
    length: int,
    file_size: int,
    abort_event: Optional[Event] = None,
    method: str = "put",
):
    """Upload one byte range (part) of a file to a pre-signed url.

    The range is memory-mapped rather than read into memory, so only the pages
    actually sent get paged in.

    Args:
        url: Pre-signed url for this part.
        file_path: Path of the local file being uploaded.
        seek: Byte offset of this part within the file. NOTE(review): mmap
            requires offsets to be multiples of mmap.ALLOCATIONGRANULARITY;
            part sizes are aligned elsewhere in this module — confirm callers
            only pass aligned offsets.
        length: Nominal part size in bytes; clamped for the final part.
        file_size: Total size of the file in bytes.
        abort_event: When set, the transfer aborts from the read callback.
        method: HTTP method to use (defaults to "put").

    Returns:
        The HTTP response from the storage provider.
    """
    def callback(*_, **__):
        # Invoked by the IO wrapper on each read; lets a failing sibling part
        # abort this in-flight transfer.
        if abort_event and abort_event.is_set():
            raise Exception("aborting upload")

    with open(file_path, "rb") as file:
        with mmap.mmap(
            file.fileno(),
            # The final part may be shorter than `length`.
            length=min(file_size - seek, length),
            offset=seek,
            access=mmap.ACCESS_READ,
        ) as mapped_file:
            wrapped_file = _CallbackIOWrapperForMultiPartUpload(
                callback, mapped_file, "read", len(mapped_file)
            )
            with cloud_storage_http_request(
                method=method,
                url=url,
                data=wrapped_file,
            ) as response:
                augmented_raise_for_status(response)
                return response
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def _s3_compatible_multipart_upload(
    multipart_upload: MultiPartUploadDto,
    local_file: str,
    multipart_info: _FileMultiPartInfo,
    executor: ThreadPoolExecutor,
    progress_bar: Progress,
    abort_event: Optional[Event] = None,
):
    """Upload *local_file* to S3/GCS-compatible storage as a multipart upload.

    Parts are uploaded concurrently on *executor*. On the first failure the
    remaining parts are aborted (via *abort_event*) or cancelled, and the
    first exception is re-raised. After all parts succeed, the upload is
    finalized by POSTing the part-number/ETag manifest to the finalize url.

    Args:
        multipart_upload: Signed urls for each part and for finalization.
        local_file: Path of the file to upload.
        multipart_info: Precomputed part count/size plan for this file.
        executor: Thread pool the per-part uploads are submitted to.
        progress_bar: Shared rich progress bar; a task is added for this file.
        abort_event: Shared abort flag; created locally when not provided.

    Raises:
        Exception: The first exception raised by any part upload.
        requests.HTTPError: If the finalize request fails.
    """
    abort_event = abort_event or Event()
    # Completed parts, appended from worker threads as each part finishes.
    parts = []

    multi_part_upload_progress = progress_bar.add_task(
        f"[green]Uploading {_get_relpath_if_in_tempdir(local_file)}:", start=True
    )

    def upload(part_number: int, seek: int) -> None:
        # Upload one part (byte range starting at `seek`) and record its ETag.
        logger.debug(
            "Uploading part %d/%d of %s",
            part_number,
            multipart_info.num_parts,
            local_file,
        )
        response = _file_part_upload(
            url=multipart_upload.part_signed_urls[part_number].signed_url,
            file_path=local_file,
            seek=seek,
            length=multipart_info.part_size,
            file_size=multipart_info.file_size,
            abort_event=abort_event,
        )
        logger.debug(
            "Uploaded part %d/%d of %s",
            part_number,
            multipart_info.num_parts,
            local_file,
        )
        progress_bar.update(
            multi_part_upload_progress,
            advance=multipart_info.part_size,
            total=multipart_info.file_size,
        )
        # S3-compatible APIs identify parts by 1-based part number + ETag.
        etag = response.headers["ETag"]
        parts.append(_PartNumberEtag(etag=etag, part_number=part_number + 1))

    futures: List[Future] = []
    for part_number, seek in enumerate(
        range(0, multipart_info.file_size, multipart_info.part_size)
    ):
        future = executor.submit(upload, part_number=part_number, seek=seek)
        futures.append(future)

    done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
    if len(not_done) > 0:
        # At least one part failed: tell in-flight uploads to abort and cancel
        # parts that have not started yet.
        abort_event.set()
        for future in not_done:
            future.cancel()
    for future in done:
        # Re-raise the first recorded failure, if any.
        if future.exception() is not None:
            raise future.exception()

    logger.debug("Finalizing multipart upload of %s", local_file)
    # The completion manifest must list parts in ascending part-number order.
    parts = sorted(parts, key=lambda part: part.part_number)
    response = requests.post(
        multipart_upload.finalize_signed_url.signed_url,
        data=_get_s3_compatible_completion_body(parts),
        timeout=2 * 60,
    )
    response.raise_for_status()
    logger.debug("Multipart upload of %s completed", local_file)
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def _azure_multi_part_upload(
    multipart_upload: MultiPartUploadDto,
    local_file: str,
    multipart_info: _FileMultiPartInfo,
    executor: ThreadPoolExecutor,
    progress_bar: Progress,
    abort_event: Optional[Event] = None,
):
    """Upload ``local_file`` to Azure Blob Storage in parts via pre-signed URLs.

    Each part is PUT concurrently on ``executor`` using the per-part signed
    URLs in ``multipart_upload``; the upload is then finalized with a single
    request to the finalize signed URL.
    """
    abort_event = abort_event or Event()

    multi_part_upload_progress = progress_bar.add_task(
        f"[green]Uploading {_get_relpath_if_in_tempdir(local_file)}:", start=True
    )

    def upload(part_number: int, seek: int):
        # Upload one part: read part_size bytes starting at `seek` and PUT
        # them to this part's signed URL. `abort_event` lets a failed sibling
        # part stop this one mid-transfer.
        logger.debug(
            "Uploading part %d/%d of %s",
            part_number,
            multipart_info.num_parts,
            local_file,
        )
        _file_part_upload(
            url=multipart_upload.part_signed_urls[part_number].signed_url,
            file_path=local_file,
            seek=seek,
            length=multipart_info.part_size,
            file_size=multipart_info.file_size,
            abort_event=abort_event,
        )
        progress_bar.update(
            multi_part_upload_progress,
            advance=multipart_info.part_size,
            total=multipart_info.file_size,
        )
        logger.debug(
            "Uploaded part %d/%d of %s",
            part_number,
            multipart_info.num_parts,
            local_file,
        )

    # Fan out one task per part; `seek` is the byte offset of the part.
    futures: List[Future] = []
    for part_number, seek in enumerate(
        range(0, multipart_info.file_size, multipart_info.part_size)
    ):
        future = executor.submit(upload, part_number=part_number, seek=seek)
        futures.append(future)

    # Wait until all parts finish or the first one fails; on failure, signal
    # the abort event so in-flight parts stop, cancel unstarted ones, then
    # re-raise the first recorded exception.
    done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
    if len(not_done) > 0:
        abort_event.set()
    for future in not_done:
        future.cancel()
    for future in done:
        if future.exception() is not None:
            raise future.exception()

    logger.debug("Finalizing multipart upload of %s", local_file)
    # NOTE(review): finalization is skipped entirely when
    # `azure_blob_block_ids` is empty/None — confirm the backend always
    # supplies block ids for Azure uploads, otherwise the blob is never
    # committed even though all parts uploaded successfully.
    if multipart_upload.azure_blob_block_ids:
        response = requests.put(
            multipart_upload.finalize_signed_url.signed_url,
            data=_get_azure_blob_completion_body(
                block_ids=multipart_upload.azure_blob_block_ids
            ),
            timeout=2 * 60,
        )
        response.raise_for_status()
    logger.debug("Multipart upload of %s completed", local_file)
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
def _any_future_has_failed(futures) -> bool:
|
|
471
|
+
return any(
|
|
472
|
+
future.done() and not future.cancelled() and future.exception() is not None
|
|
473
|
+
for future in futures
|
|
474
|
+
)
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
class ArtifactIdentifier(BaseModel):
    """Identifies the target of artifact operations.

    Exactly one of ``artifact_version_id`` or ``dataset_fqn`` must be set;
    construction fails otherwise.
    """

    # UUID of an artifact version, mutually exclusive with dataset_fqn
    artifact_version_id: Optional[uuid.UUID] = None
    # Fully qualified name of a dataset, mutually exclusive with artifact_version_id
    dataset_fqn: Optional[str] = None

    @root_validator
    def _check_identifier_type(cls, values: Dict[str, Any]):
        # Enforce the exactly-one invariant at model construction time.
        has_version_id = bool(values.get("artifact_version_id", False))
        has_dataset_fqn = bool(values.get("dataset_fqn", False))
        if not has_version_id and not has_dataset_fqn:
            raise MlFoundryException(
                "One of the version_id or dataset_fqn should be passed"
            )
        if has_version_id and has_dataset_fqn:
            raise MlFoundryException(
                "Exactly one of version_id or dataset_fqn should be passed"
            )
        return values
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
class MlFoundryArtifactsRepository:
|
|
499
|
+
def __init__(
    self,
    artifact_identifier: ArtifactIdentifier,
    api_client: Optional[ApiClient] = None,
):
    """Create a repository bound to one artifact version or dataset.

    Args:
        artifact_identifier: The artifact version or dataset to operate on.
        api_client: Optional pre-configured API client; a default client is
            created when omitted.
    """
    self.artifact_identifier = artifact_identifier
    if api_client:
        self._api_client = api_client
    else:
        self._api_client = _get_api_client()
    self._run_artifacts_api = RunArtifactsApi(api_client=self._api_client)
    self._mlfoundry_artifacts_api = MlfoundryArtifactsApi(
        api_client=self._api_client
    )
|
|
510
|
+
|
|
511
|
+
def _create_download_destination(
|
|
512
|
+
self, src_artifact_path, dst_local_dir_path=None
|
|
513
|
+
) -> str:
|
|
514
|
+
"""
|
|
515
|
+
Creates a local filesystem location to be used as a destination for downloading the artifact
|
|
516
|
+
specified by `src_artifact_path`. The destination location is a subdirectory of the
|
|
517
|
+
specified `dst_local_dir_path`, which is determined according to the structure of
|
|
518
|
+
`src_artifact_path`. For example, if `src_artifact_path` is `dir1/file1.txt`, then the
|
|
519
|
+
resulting destination path is `<dst_local_dir_path>/dir1/file1.txt`. Local directories are
|
|
520
|
+
created for the resulting destination location if they do not exist.
|
|
521
|
+
|
|
522
|
+
:param src_artifact_path: A relative, POSIX-style path referring to an artifact stored
|
|
523
|
+
within the repository's artifact root location.
|
|
524
|
+
`src_artifact_path` should be specified relative to the
|
|
525
|
+
repository's artifact root location.
|
|
526
|
+
:param dst_local_dir_path: The absolute path to a local filesystem directory in which the
|
|
527
|
+
local destination path will be contained. The local destination
|
|
528
|
+
path may be contained in a subdirectory of `dst_root_dir` if
|
|
529
|
+
`src_artifact_path` contains subdirectories.
|
|
530
|
+
:return: The absolute path to a local filesystem location to be used as a destination
|
|
531
|
+
for downloading the artifact specified by `src_artifact_path`.
|
|
532
|
+
"""
|
|
533
|
+
src_artifact_path = src_artifact_path.rstrip(
|
|
534
|
+
"/"
|
|
535
|
+
) # Ensure correct dirname for trailing '/'
|
|
536
|
+
dirpath = posixpath.dirname(src_artifact_path)
|
|
537
|
+
local_dir_path = os.path.join(dst_local_dir_path, dirpath)
|
|
538
|
+
local_file_path = os.path.join(dst_local_dir_path, src_artifact_path)
|
|
539
|
+
if not os.path.exists(local_dir_path):
|
|
540
|
+
os.makedirs(local_dir_path, exist_ok=True)
|
|
541
|
+
return local_file_path
|
|
542
|
+
|
|
543
|
+
# these methods should be named list_files, log_directory, log_file, etc
|
|
544
|
+
def list_artifacts(
    self, path=None, page_size=_LIST_FILES_PAGE_SIZE, **kwargs
) -> Iterator[FileInfoDto]:
    """Yield file entries under ``path``, transparently following pagination."""
    page_token = None
    while True:
        page = self.list_files(
            artifact_identifier=self.artifact_identifier,
            path=path,
            page_size=page_size,
            page_token=page_token,
        )
        yield from page.files
        page_token = page.next_page_token
        if page_token is None:
            break
|
|
560
|
+
|
|
561
|
+
def log_artifacts(  # noqa: C901
    self, local_dir, artifact_path=None, progress=None
):
    """Upload every file under ``local_dir`` to ``artifact_path``.

    Files are partitioned by size: single-part files are uploaded
    concurrently (signed URLs fetched in batches), while multipart files are
    uploaded one at a time, with their parts fanned out on the shared
    executor.
    """
    show_progress = _can_display_progress(progress)

    dest_path = artifact_path or ""
    dest_path = dest_path.lstrip(posixpath.sep)

    # (final remote path, local path, part info) tuples per upload strategy.
    files_for_normal_upload: List[Tuple[str, str, _FileMultiPartInfo]] = []
    files_for_multipart_upload: List[Tuple[str, str, _FileMultiPartInfo]] = []

    # Walk the tree once, computing each file's destination path and whether
    # its size requires a multipart upload.
    for root, _, file_names in os.walk(local_dir):
        upload_path = dest_path
        if root != local_dir:
            rel_path = os.path.relpath(root, local_dir)
            rel_path = relative_path_to_artifact_path(rel_path)
            upload_path = posixpath.join(dest_path, rel_path)
        for file_name in file_names:
            local_file = os.path.join(root, file_name)
            multipart_info = _decide_file_parts(local_file)

            final_upload_path = upload_path or ""
            final_upload_path = final_upload_path.lstrip(posixpath.sep)
            final_upload_path = posixpath.join(
                final_upload_path, os.path.basename(local_file)
            )

            if multipart_info.num_parts == 1:
                files_for_normal_upload.append(
                    (final_upload_path, local_file, multipart_info)
                )
            else:
                files_for_multipart_upload.append(
                    (final_upload_path, local_file, multipart_info)
                )

    # Shared flag: the first failed upload aborts all in-flight siblings.
    abort_event = Event()

    with Progress(
        "[progress.description]{task.description}",
        BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.0f}%",
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
        TimeElapsedColumn(),
        refresh_per_second=1,
        disable=not show_progress,
        expand=True,
    ) as progress_bar, ThreadPoolExecutor(
        max_workers=_MAX_WORKERS_FOR_UPLOAD
    ) as executor:
        futures: List[Future] = []
        # Note: While this batching is beneficial when there is a large number of files, there is also
        # a rare case risk of the signed url expiring before a request is made to it
        _batch_size = _GENERATE_SIGNED_URL_BATCH_SIZE
        for start_idx in range(0, len(files_for_normal_upload), _batch_size):
            end_idx = min(start_idx + _batch_size, len(files_for_normal_upload))
            # Stop requesting new signed-url batches once any upload failed.
            if _any_future_has_failed(futures):
                break
            logger.debug("Generating write signed urls for a batch ...")
            remote_file_paths = [
                files_for_normal_upload[idx][0] for idx in range(start_idx, end_idx)
            ]
            signed_urls = self.get_signed_urls_for_write(
                artifact_identifier=self.artifact_identifier,
                paths=remote_file_paths,
            )
            for idx, signed_url in zip(range(start_idx, end_idx), signed_urls):
                (
                    upload_path,
                    local_file,
                    multipart_info,
                ) = files_for_normal_upload[idx]
                future = executor.submit(
                    self._log_artifact,
                    local_file=local_file,
                    artifact_path=upload_path,
                    multipart_info=multipart_info,
                    signed_url=signed_url,
                    abort_event=abort_event,
                    executor_for_multipart_upload=None,
                    progress_bar=progress_bar,
                )
                futures.append(future)

        # Wait for all single-part uploads; on the first failure, abort
        # in-flight uploads, cancel unstarted ones and re-raise.
        done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
        if len(not_done) > 0:
            abort_event.set()
        for future in not_done:
            future.cancel()
        for future in done:
            if future.exception() is not None:
                raise future.exception()

        # Large files go one at a time; each multipart upload spreads its
        # parts across the shared executor.
        for (
            upload_path,
            local_file,
            multipart_info,
        ) in files_for_multipart_upload:
            self._log_artifact(
                local_file=local_file,
                artifact_path=upload_path,
                signed_url=None,
                multipart_info=multipart_info,
                executor_for_multipart_upload=executor,
                progress_bar=progress_bar,
            )
|
|
669
|
+
|
|
670
|
+
def _normal_upload(
    self,
    local_file: str,
    artifact_path: str,
    signed_url: Optional[SignedURLDto],
    progress_bar: Progress,
    abort_event: Optional[Event] = None,
):
    """Upload a file in a single request using a write signed URL.

    A signed URL is fetched on demand when the caller did not supply one.
    """
    if not signed_url:
        urls = self.get_signed_urls_for_write(
            artifact_identifier=self.artifact_identifier, paths=[artifact_path]
        )
        signed_url = urls[0]

    # The progress bar is hidden, so surface the upload with a log line.
    if progress_bar.disable:
        logger.info(
            "Uploading %s to %s",
            _get_relpath_if_in_tempdir(local_file),
            artifact_path,
        )

    _signed_url_upload_file(
        signed_url=signed_url,
        local_file=local_file,
        abort_event=abort_event,
        progress_bar=progress_bar,
    )
    logger.debug("Uploaded %s to %s", local_file, artifact_path)
|
|
697
|
+
|
|
698
|
+
def _multipart_upload(
    self,
    local_file: str,
    artifact_path: str,
    multipart_info: _FileMultiPartInfo,
    executor: ThreadPoolExecutor,
    progress_bar: Progress,
    abort_event: Optional[Event] = None,
):
    """Upload a large file in parts, dispatching on the storage provider."""
    if progress_bar.disable:
        logger.info(
            "Uploading %s to %s using multipart upload",
            _get_relpath_if_in_tempdir(local_file),
            artifact_path,
        )

    multipart_upload = self.create_multipart_upload_for_identifier(
        artifact_identifier=self.artifact_identifier,
        path=artifact_path,
        num_parts=multipart_info.num_parts,
    )

    # Pick the provider-specific uploader; unknown providers are rejected.
    provider = multipart_upload.storage_provider
    if provider is MultiPartUploadStorageProvider.S3_COMPATIBLE:
        uploader = _s3_compatible_multipart_upload
    elif provider is MultiPartUploadStorageProvider.AZURE_BLOB:
        uploader = _azure_multi_part_upload
    else:
        raise NotImplementedError()

    uploader(
        multipart_upload=multipart_upload,
        local_file=local_file,
        executor=executor,
        multipart_info=multipart_info,
        abort_event=abort_event,
        progress_bar=progress_bar,
    )
|
|
745
|
+
|
|
746
|
+
def _log_artifact(
    self,
    local_file: str,
    artifact_path: str,
    multipart_info: _FileMultiPartInfo,
    progress_bar: Progress,
    signed_url: Optional[SignedURLDto] = None,
    abort_event: Optional[Event] = None,
    executor_for_multipart_upload: Optional[ThreadPoolExecutor] = None,
):
    """Upload one file, choosing single-shot or multipart based on its size."""
    # Small files take the single-request path.
    if multipart_info.num_parts == 1:
        return self._normal_upload(
            local_file=local_file,
            artifact_path=artifact_path,
            signed_url=signed_url,
            abort_event=abort_event,
            progress_bar=progress_bar,
        )

    # Reuse the caller's executor when one was supplied.
    if executor_for_multipart_upload:
        return self._multipart_upload(
            local_file=local_file,
            artifact_path=artifact_path,
            executor=executor_for_multipart_upload,
            multipart_info=multipart_info,
            progress_bar=progress_bar,
        )

    # Otherwise create a short-lived pool just for this one upload.
    with ThreadPoolExecutor(max_workers=_MAX_WORKERS_FOR_UPLOAD) as executor:
        return self._multipart_upload(
            local_file=local_file,
            artifact_path=artifact_path,
            executor=executor,
            multipart_info=multipart_info,
            progress_bar=progress_bar,
        )
|
|
782
|
+
|
|
783
|
+
def log_artifact(self, local_file: str, artifact_path: Optional[str] = None):
    """Upload a single local file under ``artifact_path``.

    The file keeps its basename; the progress bar is always disabled here.
    """
    dest = (artifact_path or "").lstrip(posixpath.sep)
    dest = posixpath.join(dest, os.path.basename(local_file))
    with Progress(
        "[progress.description]{task.description}",
        BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.0f}%",
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
        TimeElapsedColumn(),
        refresh_per_second=1,
        disable=True,
        expand=True,
    ) as progress_bar:
        self._log_artifact(
            local_file=local_file,
            artifact_path=dest,
            multipart_info=_decide_file_parts(local_file),
            progress_bar=progress_bar,
        )
|
|
805
|
+
|
|
806
|
+
def _is_directory(self, artifact_path):
|
|
807
|
+
for _ in self.list_artifacts(artifact_path, page_size=3):
|
|
808
|
+
return True
|
|
809
|
+
return False
|
|
810
|
+
|
|
811
|
+
def download_artifacts(  # noqa: C901
    self,
    artifact_path: str,
    dst_path: Optional[str] = None,
    overwrite: bool = False,
    progress: Optional[bool] = None,
) -> str:
    """
    Download an artifact file or directory to a local directory if applicable, and return a
    local path for it. The caller is responsible for managing the lifecycle of the downloaded artifacts.

    Args:
        artifact_path: Relative source path to the desired artifacts.
        dst_path: Absolute path of the local filesystem destination directory to which to
                  download the specified artifacts. This directory must already exist.
                  If unspecified, the artifacts will either be downloaded to a new
                  uniquely-named directory.
        overwrite: if to overwrite the files at/inside `dst_path` if they exist
        progress: Show or hide progress bar

    Returns:
        str: Absolute path of the local filesystem location containing the desired artifacts.
    """

    # NOTE(review): the `progress` parameter is accepted but not forwarded to
    # `_can_display_progress()` — confirm whether this should be
    # `_can_display_progress(progress)` as done in `log_artifacts`.
    show_progress = _can_display_progress()

    is_dir_temp = False
    if dst_path is None:
        # No destination given: download into a fresh temporary directory.
        dst_path = tempfile.mkdtemp()
        is_dir_temp = True

    dst_path = os.path.abspath(dst_path)
    if is_dir_temp:
        logger.info(
            f"Using temporary directory {dst_path} as the download directory"
        )

    # Validate the destination before doing any network work.
    if not os.path.exists(dst_path):
        raise MlFoundryException(
            message=(
                "The destination path for downloaded artifacts does not"
                " exist! Destination path: {dst_path}".format(dst_path=dst_path)
            ),
        )
    elif not os.path.isdir(dst_path):
        raise MlFoundryException(
            message=(
                "The destination path for downloaded artifacts must be a directory!"
                " Destination path: {dst_path}".format(dst_path=dst_path)
            ),
        )

    progress_bar = Progress(
        "[progress.description]{task.description}",
        BarColumn(),
        "[progress.percentage]{task.percentage:>3.0f}%",
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
        TimeElapsedColumn(),
        refresh_per_second=1,
        disable=not show_progress,
        expand=True,
    )

    try:
        progress_bar.start()
        # Check if the artifacts points to a directory
        if self._is_directory(artifact_path):
            futures: List[Future] = []
            file_paths: List[Tuple[str, str]] = []
            # Shared flag: the first failed download aborts the others.
            abort_event = Event()

            # Check if any file is being overwritten before downloading them
            for file_path, download_dest_path in self._get_file_paths_recur(
                src_artifact_dir_path=artifact_path, dst_local_dir_path=dst_path
            ):
                final_file_path = os.path.join(download_dest_path, file_path)

                # There would be no overwrite if temp directory is being used
                if (
                    not is_dir_temp
                    and os.path.exists(final_file_path)
                    and not overwrite
                ):
                    raise MlFoundryException(
                        f"File already exists at {final_file_path}, aborting download "
                        f"(set `overwrite` flag to overwrite this and any subsequent files)"
                    )
                file_paths.append((file_path, download_dest_path))

            with ThreadPoolExecutor(_MAX_WORKERS_FOR_DOWNLOAD) as executor:
                # Note: While this batching is beneficial when there is a large number of files, there is also
                # a rare case risk of the signed url expiring before a request is made to it
                batch_size = _GENERATE_SIGNED_URL_BATCH_SIZE
                for start_idx in range(0, len(file_paths), batch_size):
                    end_idx = min(start_idx + batch_size, len(file_paths))
                    # Stop scheduling new batches once any download failed.
                    if _any_future_has_failed(futures):
                        break
                    logger.debug("Generating read signed urls for a batch ...")
                    remote_file_paths = [
                        file_paths[idx][0] for idx in range(start_idx, end_idx)
                    ]
                    signed_urls = self.get_signed_urls_for_read(
                        artifact_identifier=self.artifact_identifier,
                        paths=remote_file_paths,
                    )
                    for idx, signed_url in zip(
                        range(start_idx, end_idx), signed_urls
                    ):
                        file_path, download_dest_path = file_paths[idx]
                        future = executor.submit(
                            self._download_artifact,
                            src_artifact_path=file_path,
                            dst_local_dir_path=download_dest_path,
                            signed_url=signed_url,
                            abort_event=abort_event,
                            progress_bar=progress_bar,
                        )
                        futures.append(future)

                # Wait for all downloads; on the first failure, signal abort,
                # cancel unstarted futures, and re-raise the error.
                done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
                if len(not_done) > 0:
                    abort_event.set()
                for future in not_done:
                    future.cancel()
                for future in done:
                    if future.exception() is not None:
                        raise future.exception()

            output_dir = os.path.join(dst_path, artifact_path)
            return output_dir
        else:
            # Single-file artifact: download directly.
            return self._download_artifact(
                src_artifact_path=artifact_path,
                dst_local_dir_path=dst_path,
                signed_url=None,
                progress_bar=progress_bar,
            )
    except Exception as err:
        if is_dir_temp:
            logger.info(
                f"Error encountered, removing temporary download directory at {dst_path}"
            )
            rmtree(dst_path)  # remove temp directory alongside it's contents
        raise err

    finally:
        progress_bar.stop()
|
|
960
|
+
|
|
961
|
+
# noinspection PyMethodOverriding
|
|
962
|
+
def _download_file(
    self,
    remote_file_path: str,
    local_path: str,
    progress_bar: Optional[Progress],
    signed_url: Optional[SignedURLDto],
    abort_event: Optional[Event] = None,
):
    """Download one remote file to ``local_path`` via a read signed URL.

    A signed URL is fetched on demand when the caller did not supply one.
    """
    if not remote_file_path:
        raise MlFoundryException(
            f"remote_file_path cannot be None or empty str {remote_file_path}"
        )
    if not signed_url:
        urls = self.get_signed_urls_for_read(
            artifact_identifier=self.artifact_identifier, paths=[remote_file_path]
        )
        signed_url = urls[0]

    if progress_bar is None or not progress_bar.disable:
        logger.info("Downloading %s to %s", remote_file_path, local_path)

    task_id = None
    if progress_bar is not None:
        task_id = progress_bar.add_task(
            f"[green]Downloading to {remote_file_path}:", start=True
        )

    def callback(chunk, total_file_size):
        # Advance the progress bar and honour cooperative cancellation.
        if progress_bar is not None:
            progress_bar.update(
                task_id,
                advance=chunk,
                total=total_file_size,
            )
        if abort_event and abort_event.is_set():
            raise Exception("aborting download")

    _download_file_using_http_uri(
        http_uri=signed_url.signed_url,
        download_path=local_path,
        callback=callback,
    )
    logger.debug("Downloaded %s to %s", remote_file_path, local_path)
|
|
1003
|
+
|
|
1004
|
+
def _download_artifact(
    self,
    src_artifact_path,
    dst_local_dir_path,
    signed_url: Optional[SignedURLDto],
    progress_bar: Optional[Progress] = None,
    abort_event=None,
) -> str:
    """Download the file artifact ``src_artifact_path`` into ``dst_local_dir_path``.

    The file is written under a sub-path of ``dst_local_dir_path`` that
    mirrors the relative artifact path; any intermediate directories are
    created first.

    :param src_artifact_path: A relative, POSIX-style path referring to a file
                              artifact within the repository's artifact root.
    :param dst_local_dir_path: Absolute path of the local destination directory.
    :param progress_bar: Optional Rich progress bar to display download progress.
    :return: A local filesystem path referring to the downloaded file.
    """
    destination = self._create_download_destination(
        src_artifact_path=src_artifact_path, dst_local_dir_path=dst_local_dir_path
    )
    self._download_file(
        remote_file_path=src_artifact_path,
        local_path=destination,
        signed_url=signed_url,
        abort_event=abort_event,
        progress_bar=progress_bar,
    )
    return destination
|
|
1039
|
+
|
|
1040
|
+
def _get_file_paths_recur(self, src_artifact_dir_path, dst_local_dir_path):
|
|
1041
|
+
local_dir = os.path.join(dst_local_dir_path, src_artifact_dir_path)
|
|
1042
|
+
dir_content = [ # prevent infinite loop, sometimes the dir is recursively included
|
|
1043
|
+
file_info
|
|
1044
|
+
for file_info in self.list_artifacts(src_artifact_dir_path)
|
|
1045
|
+
if file_info.path != "." and file_info.path != src_artifact_dir_path
|
|
1046
|
+
]
|
|
1047
|
+
if not dir_content: # empty dir
|
|
1048
|
+
if not os.path.exists(local_dir):
|
|
1049
|
+
os.makedirs(local_dir, exist_ok=True)
|
|
1050
|
+
else:
|
|
1051
|
+
for file_info in dir_content:
|
|
1052
|
+
if file_info.is_dir:
|
|
1053
|
+
yield from self._get_file_paths_recur(
|
|
1054
|
+
src_artifact_dir_path=file_info.path,
|
|
1055
|
+
dst_local_dir_path=dst_local_dir_path,
|
|
1056
|
+
)
|
|
1057
|
+
else:
|
|
1058
|
+
yield file_info.path, dst_local_dir_path
|
|
1059
|
+
|
|
1060
|
+
# TODO (chiragjn): Refactor these methods - if else is very inconvenient
|
|
1061
|
+
def get_signed_urls_for_read(
    self,
    artifact_identifier: ArtifactIdentifier,
    paths,
) -> List[SignedURLDto]:
    """Fetch read signed URLs for ``paths``, dispatching on the identifier kind."""
    if artifact_identifier.artifact_version_id:
        request = GetSignedURLsForArtifactVersionReadRequestDto(
            id=str(artifact_identifier.artifact_version_id), paths=paths
        )
        response = self._mlfoundry_artifacts_api.get_signed_urls_for_read_post(
            get_signed_urls_for_artifact_version_read_request_dto=request
        )
        return response.signed_urls
    if artifact_identifier.dataset_fqn:
        request = GetSignedURLsForDatasetReadRequestDto(
            dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
        )
        response = self._mlfoundry_artifacts_api.get_signed_urls_dataset_read_post(
            get_signed_urls_for_dataset_read_request_dto=request
        )
        return response.signed_urls
    raise ValueError(
        "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
    )
|
|
1085
|
+
|
|
1086
|
+
def get_signed_urls_for_write(
    self,
    artifact_identifier: ArtifactIdentifier,
    paths: List[str],
) -> List[SignedURLDto]:
    """Fetch write signed URLs for ``paths``, dispatching on the identifier kind."""
    if artifact_identifier.artifact_version_id:
        request = GetSignedURLsForArtifactVersionWriteRequestDto(
            id=str(artifact_identifier.artifact_version_id), paths=paths
        )
        response = self._mlfoundry_artifacts_api.get_signed_urls_for_write_post(
            get_signed_urls_for_artifact_version_write_request_dto=request
        )
        return response.signed_urls
    if artifact_identifier.dataset_fqn:
        request = GetSignedURLForDatasetWriteRequestDto(
            dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
        )
        response = self._mlfoundry_artifacts_api.get_signed_urls_for_dataset_write_post(
            get_signed_url_for_dataset_write_request_dto=request
        )
        return response.signed_urls
    raise ValueError(
        "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
    )
|
|
1110
|
+
|
|
1111
|
+
def create_multipart_upload_for_identifier(
    self,
    artifact_identifier: ArtifactIdentifier,
    path,
    num_parts,
) -> MultiPartUploadDto:
    """Initiate a multipart upload for ``path`` and return its signed-URL bundle."""
    if artifact_identifier.artifact_version_id:
        request = CreateMultiPartUploadRequestDto(
            artifact_version_id=str(artifact_identifier.artifact_version_id),
            path=path,
            num_parts=num_parts,
        )
        response: MultiPartUploadResponseDto = (
            self._mlfoundry_artifacts_api.create_multi_part_upload_post(
                create_multi_part_upload_request_dto=request
            )
        )
        return response.multipart_upload
    if artifact_identifier.dataset_fqn:
        request = CreateMultiPartUploadForDatasetRequestDto(
            dataset_fqn=artifact_identifier.dataset_fqn,
            path=path,
            num_parts=num_parts,
        )
        response = self._mlfoundry_artifacts_api.create_multipart_upload_for_dataset_post(
            create_multi_part_upload_for_dataset_request_dto=request
        )
        return response.multipart_upload
    raise ValueError(
        "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
    )
|
|
1140
|
+
|
|
1141
|
+
def list_files(
    self, artifact_identifier: ArtifactIdentifier, path, page_size, page_token
) -> Union[ListFilesForDatasetResponseDto, ListFilesForArtifactVersionsResponseDto]:
    """List one page of files for the given dataset or artifact version.

    Args:
        artifact_identifier: Target dataset or artifact version.
        path: Remote path to list under.
        page_size: Maximum number of entries per page.
        page_token: Opaque continuation token from a previous page, or None.

    Returns:
        The raw list-files response DTO, including ``files`` and
        ``next_page_token``.

    Raises:
        ValueError: If the identifier carries neither a dataset FQN nor an
            artifact version id.
    """
    if artifact_identifier.dataset_fqn:
        return self._mlfoundry_artifacts_api.list_files_for_dataset_post(
            list_files_for_dataset_request_dto=ListFilesForDatasetRequestDto(
                dataset_fqn=artifact_identifier.dataset_fqn,
                path=path,
                max_results=page_size,
                page_token=page_token,
            )
        )
    elif artifact_identifier.artifact_version_id:
        return self._mlfoundry_artifacts_api.list_files_for_artifact_version_post(
            list_files_for_artifact_version_request_dto=ListFilesForArtifactVersionRequestDto(
                id=str(artifact_identifier.artifact_version_id),
                path=path,
                max_results=page_size,
                page_token=page_token,
            )
        )
    else:
        # Match the validation of the other identifier-dispatch helpers
        # instead of silently sending id="None" to the artifact-version API.
        raise ValueError(
            "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
        )
|