truefoundry 0.3.3__py3-none-any.whl → 0.4.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truefoundry might be problematic. Click here for more details.

Files changed (224) hide show
  1. truefoundry/cli/__main__.py +3 -17
  2. truefoundry/common/__init__.py +0 -0
  3. truefoundry/common/request_utils.py +56 -0
  4. truefoundry/deploy/cli/cli.py +1 -1
  5. truefoundry/deploy/lib/auth/credential_provider.py +2 -12
  6. truefoundry/deploy/lib/clients/servicefoundry_client.py +0 -9
  7. truefoundry/deploy/lib/exceptions.py +1 -6
  8. truefoundry/deploy/lib/session.py +1 -16
  9. truefoundry/langchain/truefoundry_chat.py +1 -1
  10. truefoundry/langchain/truefoundry_embeddings.py +1 -1
  11. truefoundry/langchain/truefoundry_llm.py +1 -1
  12. truefoundry/langchain/utils.py +0 -41
  13. truefoundry/ml/__init__.py +46 -6
  14. truefoundry/ml/artifact/__init__.py +0 -0
  15. truefoundry/ml/artifact/truefoundry_artifact_repo.py +1120 -0
  16. truefoundry/ml/autogen/__init__.py +0 -0
  17. truefoundry/ml/autogen/client/__init__.py +373 -0
  18. truefoundry/ml/autogen/client/api/__init__.py +16 -0
  19. truefoundry/ml/autogen/client/api/auth_api.py +184 -0
  20. truefoundry/ml/autogen/client/api/deprecated_api.py +605 -0
  21. truefoundry/ml/autogen/client/api/experiments_api.py +2109 -0
  22. truefoundry/ml/autogen/client/api/health_api.py +299 -0
  23. truefoundry/ml/autogen/client/api/metrics_api.py +371 -0
  24. truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +7213 -0
  25. truefoundry/ml/autogen/client/api/python_deployment_config_api.py +201 -0
  26. truefoundry/ml/autogen/client/api/run_artifacts_api.py +231 -0
  27. truefoundry/ml/autogen/client/api/runs_api.py +2919 -0
  28. truefoundry/ml/autogen/client/api_client.py +822 -0
  29. truefoundry/ml/autogen/client/api_response.py +30 -0
  30. truefoundry/ml/autogen/client/configuration.py +489 -0
  31. truefoundry/ml/autogen/client/exceptions.py +161 -0
  32. truefoundry/ml/autogen/client/models/__init__.py +344 -0
  33. truefoundry/ml/autogen/client/models/add_custom_metrics_to_model_version_request_dto.py +69 -0
  34. truefoundry/ml/autogen/client/models/add_features_to_model_version_request_dto.py +83 -0
  35. truefoundry/ml/autogen/client/models/agent.py +125 -0
  36. truefoundry/ml/autogen/client/models/agent_app.py +118 -0
  37. truefoundry/ml/autogen/client/models/agent_open_api_tool.py +143 -0
  38. truefoundry/ml/autogen/client/models/agent_open_api_tool_with_fqn.py +144 -0
  39. truefoundry/ml/autogen/client/models/agent_with_fqn.py +127 -0
  40. truefoundry/ml/autogen/client/models/artifact_dto.py +115 -0
  41. truefoundry/ml/autogen/client/models/artifact_response_dto.py +75 -0
  42. truefoundry/ml/autogen/client/models/artifact_type.py +39 -0
  43. truefoundry/ml/autogen/client/models/artifact_version_dto.py +141 -0
  44. truefoundry/ml/autogen/client/models/artifact_version_response_dto.py +77 -0
  45. truefoundry/ml/autogen/client/models/artifact_version_status.py +35 -0
  46. truefoundry/ml/autogen/client/models/assistant_message.py +89 -0
  47. truefoundry/ml/autogen/client/models/authorize_user_for_model_request_dto.py +69 -0
  48. truefoundry/ml/autogen/client/models/authorize_user_for_model_version_request_dto.py +69 -0
  49. truefoundry/ml/autogen/client/models/backfill_default_storage_integration_id_request_dto.py +67 -0
  50. truefoundry/ml/autogen/client/models/blob_storage_reference.py +93 -0
  51. truefoundry/ml/autogen/client/models/body_get_search_runs_get.py +72 -0
  52. truefoundry/ml/autogen/client/models/chat_prompt.py +156 -0
  53. truefoundry/ml/autogen/client/models/chat_prompt_messages_inner.py +171 -0
  54. truefoundry/ml/autogen/client/models/columns_dto.py +73 -0
  55. truefoundry/ml/autogen/client/models/content.py +153 -0
  56. truefoundry/ml/autogen/client/models/content1.py +153 -0
  57. truefoundry/ml/autogen/client/models/content2.py +174 -0
  58. truefoundry/ml/autogen/client/models/content2_any_of_inner.py +150 -0
  59. truefoundry/ml/autogen/client/models/create_artifact_request_dto.py +74 -0
  60. truefoundry/ml/autogen/client/models/create_artifact_response_dto.py +66 -0
  61. truefoundry/ml/autogen/client/models/create_artifact_version_request_dto.py +74 -0
  62. truefoundry/ml/autogen/client/models/create_artifact_version_response_dto.py +66 -0
  63. truefoundry/ml/autogen/client/models/create_dataset_request_dto.py +76 -0
  64. truefoundry/ml/autogen/client/models/create_experiment_request_dto.py +94 -0
  65. truefoundry/ml/autogen/client/models/create_experiment_response_dto.py +67 -0
  66. truefoundry/ml/autogen/client/models/create_model_version_request_dto.py +95 -0
  67. truefoundry/ml/autogen/client/models/create_multi_part_upload_for_dataset_request_dto.py +73 -0
  68. truefoundry/ml/autogen/client/models/create_multi_part_upload_for_dataset_response_dto.py +79 -0
  69. truefoundry/ml/autogen/client/models/create_multi_part_upload_request_dto.py +73 -0
  70. truefoundry/ml/autogen/client/models/create_python_deployment_config_request_dto.py +72 -0
  71. truefoundry/ml/autogen/client/models/create_python_deployment_config_response_dto.py +68 -0
  72. truefoundry/ml/autogen/client/models/create_run_request_dto.py +97 -0
  73. truefoundry/ml/autogen/client/models/create_run_response_dto.py +76 -0
  74. truefoundry/ml/autogen/client/models/dataset_dto.py +112 -0
  75. truefoundry/ml/autogen/client/models/dataset_response_dto.py +75 -0
  76. truefoundry/ml/autogen/client/models/delete_artifact_versions_request_dto.py +65 -0
  77. truefoundry/ml/autogen/client/models/delete_dataset_request_dto.py +74 -0
  78. truefoundry/ml/autogen/client/models/delete_model_version_request_dto.py +65 -0
  79. truefoundry/ml/autogen/client/models/delete_run_request.py +65 -0
  80. truefoundry/ml/autogen/client/models/delete_tag_request_dto.py +68 -0
  81. truefoundry/ml/autogen/client/models/experiment_dto.py +127 -0
  82. truefoundry/ml/autogen/client/models/experiment_id_request_dto.py +67 -0
  83. truefoundry/ml/autogen/client/models/experiment_response_dto.py +75 -0
  84. truefoundry/ml/autogen/client/models/experiment_tag_dto.py +69 -0
  85. truefoundry/ml/autogen/client/models/feature_dto.py +68 -0
  86. truefoundry/ml/autogen/client/models/feature_value_type.py +35 -0
  87. truefoundry/ml/autogen/client/models/file_info_dto.py +76 -0
  88. truefoundry/ml/autogen/client/models/finalize_artifact_version_request_dto.py +101 -0
  89. truefoundry/ml/autogen/client/models/get_experiment_response_dto.py +88 -0
  90. truefoundry/ml/autogen/client/models/get_latest_run_log_response_dto.py +76 -0
  91. truefoundry/ml/autogen/client/models/get_metric_history_response.py +79 -0
  92. truefoundry/ml/autogen/client/models/get_signed_url_for_dataset_write_request_dto.py +68 -0
  93. truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_read_request_dto.py +68 -0
  94. truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_read_response_dto.py +81 -0
  95. truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_write_request_dto.py +69 -0
  96. truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_write_response_dto.py +83 -0
  97. truefoundry/ml/autogen/client/models/get_signed_urls_for_dataset_read_request_dto.py +68 -0
  98. truefoundry/ml/autogen/client/models/get_signed_urls_for_dataset_read_response_dto.py +81 -0
  99. truefoundry/ml/autogen/client/models/get_signed_urls_for_dataset_write_response_dto.py +81 -0
  100. truefoundry/ml/autogen/client/models/get_tenant_id_response_dto.py +74 -0
  101. truefoundry/ml/autogen/client/models/http_validation_error.py +82 -0
  102. truefoundry/ml/autogen/client/models/image_content_part.py +87 -0
  103. truefoundry/ml/autogen/client/models/image_url.py +75 -0
  104. truefoundry/ml/autogen/client/models/internal_metadata.py +180 -0
  105. truefoundry/ml/autogen/client/models/latest_run_log_dto.py +78 -0
  106. truefoundry/ml/autogen/client/models/list_artifact_versions_request_dto.py +107 -0
  107. truefoundry/ml/autogen/client/models/list_artifact_versions_response_dto.py +87 -0
  108. truefoundry/ml/autogen/client/models/list_artifacts_request_dto.py +96 -0
  109. truefoundry/ml/autogen/client/models/list_artifacts_response_dto.py +86 -0
  110. truefoundry/ml/autogen/client/models/list_colums_response_dto.py +75 -0
  111. truefoundry/ml/autogen/client/models/list_datasets_request_dto.py +78 -0
  112. truefoundry/ml/autogen/client/models/list_datasets_response_dto.py +86 -0
  113. truefoundry/ml/autogen/client/models/list_experiments_response_dto.py +86 -0
  114. truefoundry/ml/autogen/client/models/list_files_for_artifact_version_request_dto.py +76 -0
  115. truefoundry/ml/autogen/client/models/list_files_for_artifact_versions_response_dto.py +82 -0
  116. truefoundry/ml/autogen/client/models/list_files_for_dataset_request_dto.py +76 -0
  117. truefoundry/ml/autogen/client/models/list_files_for_dataset_response_dto.py +82 -0
  118. truefoundry/ml/autogen/client/models/list_latest_run_logs_response_dto.py +82 -0
  119. truefoundry/ml/autogen/client/models/list_metric_history_request_dto.py +69 -0
  120. truefoundry/ml/autogen/client/models/list_metric_history_response_dto.py +84 -0
  121. truefoundry/ml/autogen/client/models/list_model_version_response_dto.py +87 -0
  122. truefoundry/ml/autogen/client/models/list_model_versions_request_dto.py +93 -0
  123. truefoundry/ml/autogen/client/models/list_models_request_dto.py +89 -0
  124. truefoundry/ml/autogen/client/models/list_models_response_dto.py +84 -0
  125. truefoundry/ml/autogen/client/models/list_run_artifacts_response_dto.py +84 -0
  126. truefoundry/ml/autogen/client/models/list_run_logs_response_dto.py +82 -0
  127. truefoundry/ml/autogen/client/models/list_seed_experiments_response_dto.py +81 -0
  128. truefoundry/ml/autogen/client/models/log_batch_request_dto.py +106 -0
  129. truefoundry/ml/autogen/client/models/log_metric_request_dto.py +80 -0
  130. truefoundry/ml/autogen/client/models/log_param_request_dto.py +76 -0
  131. truefoundry/ml/autogen/client/models/method.py +37 -0
  132. truefoundry/ml/autogen/client/models/metric_collection_dto.py +82 -0
  133. truefoundry/ml/autogen/client/models/metric_dto.py +76 -0
  134. truefoundry/ml/autogen/client/models/mime_type.py +37 -0
  135. truefoundry/ml/autogen/client/models/model_configuration.py +103 -0
  136. truefoundry/ml/autogen/client/models/model_dto.py +122 -0
  137. truefoundry/ml/autogen/client/models/model_response_dto.py +75 -0
  138. truefoundry/ml/autogen/client/models/model_schema_dto.py +85 -0
  139. truefoundry/ml/autogen/client/models/model_version_dto.py +163 -0
  140. truefoundry/ml/autogen/client/models/model_version_response_dto.py +75 -0
  141. truefoundry/ml/autogen/client/models/multi_part_upload_dto.py +107 -0
  142. truefoundry/ml/autogen/client/models/multi_part_upload_response_dto.py +79 -0
  143. truefoundry/ml/autogen/client/models/multi_part_upload_storage_provider.py +34 -0
  144. truefoundry/ml/autogen/client/models/notify_artifact_version_failure_dto.py +65 -0
  145. truefoundry/ml/autogen/client/models/openapi_spec.py +152 -0
  146. truefoundry/ml/autogen/client/models/param_dto.py +66 -0
  147. truefoundry/ml/autogen/client/models/parameters.py +84 -0
  148. truefoundry/ml/autogen/client/models/prediction_type.py +34 -0
  149. truefoundry/ml/autogen/client/models/resolve_agent_app_response_dto.py +75 -0
  150. truefoundry/ml/autogen/client/models/restore_run_request_dto.py +65 -0
  151. truefoundry/ml/autogen/client/models/run_data_dto.py +104 -0
  152. truefoundry/ml/autogen/client/models/run_dto.py +84 -0
  153. truefoundry/ml/autogen/client/models/run_info_dto.py +105 -0
  154. truefoundry/ml/autogen/client/models/run_log_dto.py +90 -0
  155. truefoundry/ml/autogen/client/models/run_log_input_dto.py +80 -0
  156. truefoundry/ml/autogen/client/models/run_response_dto.py +75 -0
  157. truefoundry/ml/autogen/client/models/run_tag_dto.py +66 -0
  158. truefoundry/ml/autogen/client/models/search_runs_request_dto.py +94 -0
  159. truefoundry/ml/autogen/client/models/search_runs_response_dto.py +84 -0
  160. truefoundry/ml/autogen/client/models/set_experiment_tag_request_dto.py +73 -0
  161. truefoundry/ml/autogen/client/models/set_tag_request_dto.py +76 -0
  162. truefoundry/ml/autogen/client/models/signed_url_dto.py +69 -0
  163. truefoundry/ml/autogen/client/models/stop.py +152 -0
  164. truefoundry/ml/autogen/client/models/store_run_logs_request_dto.py +83 -0
  165. truefoundry/ml/autogen/client/models/system_message.py +89 -0
  166. truefoundry/ml/autogen/client/models/text.py +153 -0
  167. truefoundry/ml/autogen/client/models/text_content_part.py +84 -0
  168. truefoundry/ml/autogen/client/models/update_artifact_version_request_dto.py +74 -0
  169. truefoundry/ml/autogen/client/models/update_dataset_request_dto.py +74 -0
  170. truefoundry/ml/autogen/client/models/update_experiment_request_dto.py +74 -0
  171. truefoundry/ml/autogen/client/models/update_model_version_request_dto.py +93 -0
  172. truefoundry/ml/autogen/client/models/update_run_request_dto.py +78 -0
  173. truefoundry/ml/autogen/client/models/update_run_response_dto.py +76 -0
  174. truefoundry/ml/autogen/client/models/url.py +153 -0
  175. truefoundry/ml/autogen/client/models/user_message.py +89 -0
  176. truefoundry/ml/autogen/client/models/validation_error.py +87 -0
  177. truefoundry/ml/autogen/client/models/validation_error_loc_inner.py +154 -0
  178. truefoundry/ml/autogen/client/rest.py +426 -0
  179. truefoundry/ml/autogen/client_README.md +322 -0
  180. truefoundry/ml/cli/__init__.py +0 -0
  181. truefoundry/ml/cli/cli.py +18 -0
  182. truefoundry/ml/cli/commands/__init__.py +3 -0
  183. truefoundry/ml/cli/commands/download.py +87 -0
  184. truefoundry/ml/constants.py +84 -0
  185. truefoundry/ml/enums.py +70 -0
  186. truefoundry/ml/env_vars.py +13 -0
  187. truefoundry/ml/exceptions.py +8 -0
  188. truefoundry/ml/git_info.py +60 -0
  189. truefoundry/ml/internal_namespace.py +52 -0
  190. truefoundry/ml/log_types/__init__.py +4 -0
  191. truefoundry/ml/log_types/artifacts/artifact.py +427 -0
  192. truefoundry/ml/log_types/artifacts/constants.py +33 -0
  193. truefoundry/ml/log_types/artifacts/dataset.py +383 -0
  194. truefoundry/ml/log_types/artifacts/general_artifact.py +110 -0
  195. truefoundry/ml/log_types/artifacts/model.py +628 -0
  196. truefoundry/ml/log_types/artifacts/model_extras.py +48 -0
  197. truefoundry/ml/log_types/artifacts/utils.py +161 -0
  198. truefoundry/ml/log_types/image/__init__.py +3 -0
  199. truefoundry/ml/log_types/image/constants.py +8 -0
  200. truefoundry/ml/log_types/image/image.py +358 -0
  201. truefoundry/ml/log_types/image/image_normalizer.py +101 -0
  202. truefoundry/ml/log_types/image/types.py +68 -0
  203. truefoundry/ml/log_types/plot.py +281 -0
  204. truefoundry/ml/log_types/pydantic_base.py +10 -0
  205. truefoundry/ml/log_types/utils.py +12 -0
  206. truefoundry/ml/logger.py +17 -0
  207. truefoundry/ml/login.py +241 -0
  208. truefoundry/ml/mlfoundry_api.py +1620 -0
  209. truefoundry/ml/mlfoundry_run.py +1238 -0
  210. truefoundry/ml/run_utils.py +102 -0
  211. truefoundry/ml/services/__init__.py +0 -0
  212. truefoundry/ml/services/auth_service.py +109 -0
  213. truefoundry/ml/services/entities.py +108 -0
  214. truefoundry/ml/services/servicefoundry_service.py +35 -0
  215. truefoundry/ml/services/utils.py +122 -0
  216. truefoundry/ml/session.py +271 -0
  217. truefoundry/ml/validation_utils.py +346 -0
  218. truefoundry/pydantic_v1.py +5 -1
  219. {truefoundry-0.3.3.dist-info → truefoundry-0.4.0.dev0.dist-info}/METADATA +19 -12
  220. truefoundry-0.4.0.dev0.dist-info/RECORD +342 -0
  221. truefoundry-0.3.3.dist-info/RECORD +0 -136
  222. /truefoundry/{python_deploy_codegen.py → deploy/python_deploy_codegen.py} +0 -0
  223. {truefoundry-0.3.3.dist-info → truefoundry-0.4.0.dev0.dist-info}/WHEEL +0 -0
  224. {truefoundry-0.3.3.dist-info → truefoundry-0.4.0.dev0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,1120 @@
1
+ import math
2
+ import mmap
3
+ import os
4
+ import posixpath
5
+ import sys
6
+ import tempfile
7
+ import uuid
8
+ from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait
9
+ from shutil import rmtree
10
+ from threading import Event
11
+ from typing import (
12
+ Any,
13
+ Callable,
14
+ Dict,
15
+ Iterator,
16
+ List,
17
+ NamedTuple,
18
+ Optional,
19
+ Tuple,
20
+ Union,
21
+ )
22
+ from urllib.parse import unquote
23
+ from urllib.request import pathname2url
24
+
25
+ import requests
26
+ from rich.progress import (
27
+ BarColumn,
28
+ DownloadColumn,
29
+ Progress,
30
+ TimeElapsedColumn,
31
+ TimeRemainingColumn,
32
+ TransferSpeedColumn,
33
+ )
34
+ from tqdm.utils import CallbackIOWrapper
35
+
36
+ from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
37
+ ApiClient,
38
+ CreateMultiPartUploadForDatasetRequestDto,
39
+ CreateMultiPartUploadRequestDto,
40
+ FileInfoDto,
41
+ GetSignedURLForDatasetWriteRequestDto,
42
+ GetSignedURLsForArtifactVersionReadRequestDto,
43
+ GetSignedURLsForArtifactVersionWriteRequestDto,
44
+ GetSignedURLsForDatasetReadRequestDto,
45
+ ListFilesForArtifactVersionRequestDto,
46
+ ListFilesForArtifactVersionsResponseDto,
47
+ ListFilesForDatasetRequestDto,
48
+ ListFilesForDatasetResponseDto,
49
+ MlfoundryArtifactsApi,
50
+ MultiPartUploadDto,
51
+ MultiPartUploadResponseDto,
52
+ MultiPartUploadStorageProvider,
53
+ RunArtifactsApi,
54
+ SignedURLDto,
55
+ )
56
+ from truefoundry.ml.env_vars import DISABLE_MULTIPART_UPLOAD
57
+ from truefoundry.ml.exceptions import MlFoundryException
58
+ from truefoundry.ml.logger import logger
59
+ from truefoundry.ml.services.utils import (
60
+ augmented_raise_for_status,
61
+ cloud_storage_http_request,
62
+ )
63
+ from truefoundry.ml.session import _get_api_client
64
+ from truefoundry.pydantic_v1 import BaseModel, root_validator
65
+
66
# Files smaller than this (100 MiB) are always uploaded as a single part.
_MIN_BYTES_REQUIRED_FOR_MULTIPART = 100 * 1024 * 1024
# Multipart uploads are switched off when the environment variable named by
# DISABLE_MULTIPART_UPLOAD holds "true" (compared case-insensitively).
_MULTIPART_DISABLED = os.getenv(DISABLE_MULTIPART_UPLOAD, "").lower() == "true"
# Provider limits on the number of parts:
#   GCP/S3: at most 10,000 parts per upload
#   Block blob: at most 50,000 blocks
# TODO: The limit below is artificially low for now. Later the signed URLs
# for the parts will be requested in batches instead of in a single API call:
#   1. Create Multipart Upload (returns the maximum number of parts, the size
#      limit of a single part, the upload id for S3, etc.)
#   2. Get signed URLs for the first 500 parts
#   3. Upload 500 parts
#   4. Get signed URLs for the next 500 parts
#   5. Upload 500 parts
#   ...
#   N. Finalize the multipart upload using the finalize signed URL returned
#      by Create Multipart Upload, or request a new one.
_MAX_NUM_PARTS_FOR_MULTIPART = 1000
# Provider limits on the size of one part:
#   Azure: a block in a block blob can be at most 4000 MiB
#   GCP/S3: an individual part of a multipart upload can be at most 5 GiB
# The smaller of the two (4000 MiB) is used.
_MAX_PART_SIZE_BYTES_FOR_MULTIPART = 4 * 1024 * 1024 * 1000
# Fall back to 2 when the CPU count cannot be determined.
_cpu_count = os.cpu_count() or 2
# Thread pool sizes: twice the CPU count, clamped to the range [4, 32].
_MAX_WORKERS_FOR_UPLOAD = max(min(32, _cpu_count * 2), 4)
_MAX_WORKERS_FOR_DOWNLOAD = max(min(32, _cpu_count * 2), 4)
_LIST_FILES_PAGE_SIZE = 500
_GENERATE_SIGNED_URL_BATCH_SIZE = 50
# NOTE(review): presumably expressed in seconds (one hour) - confirm with the API.
DEFAULT_PRESIGNED_URL_EXPIRY_TIME = 3600
92
+
93
+
94
def relative_path_to_artifact_path(path):
    """
    Convert a relative local filesystem path into an artifact path.

    When the platform's ``os.path`` implementation is ``posixpath`` the path
    is returned untouched. Otherwise the path is converted with
    ``pathname2url`` and the result is percent-decoded with ``unquote``.

    :param path: A relative local filesystem path.
    :return: The corresponding artifact path.
    :raises Exception: If ``path`` equals its own absolute form (only checked
                       when ``os.path`` is not ``posixpath``).
    """
    local_paths_are_posix = os.path == posixpath
    if local_paths_are_posix:
        return path

    already_absolute = os.path.abspath(path) == path
    if already_absolute:
        raise Exception("This method only works with relative paths.")

    url_form = pathname2url(path)
    return unquote(url_form)
100
+
101
+
102
+ def _align_part_size_with_mmap_allocation_granularity(part_size: int) -> int:
103
+ modulo = part_size % mmap.ALLOCATIONGRANULARITY
104
+ if modulo == 0:
105
+ return part_size
106
+
107
+ part_size += mmap.ALLOCATIONGRANULARITY - modulo
108
+ return part_size
109
+
110
+
111
# Default size of a single part: 10 MiB rounded up to the mmap allocation
# granularity. Can not be less than 5 * 1024 * 1024.
_PART_SIZE_BYTES_FOR_MULTIPART = _align_part_size_with_mmap_allocation_granularity(
    part_size=10 * 1024 * 1024
)
115
+
116
+
117
def bad_path_message(name):
    """
    Build the explanation shown for a name that is not in normalized form.

    :param name: The offending name.
    :return: A message that includes ``posixpath.normpath(name)``, i.e. the
             name this one would resolve to.
    """
    resolved = posixpath.normpath(name)
    template = (
        "Names may be treated as files in certain cases, and must not resolve to other names"
        " when treated as such. This name would resolve to '%s'"
    )
    return template % resolved
122
+
123
+
124
def path_not_unique(name):
    """
    Tell whether ``name`` is unsuitable as a canonical artifact path.

    A name is rejected when it differs from its ``posixpath.normpath`` form,
    when it normalizes to ``"."``, or when the normalized form starts with
    ``".."`` or ``"/"``.

    :param name: The candidate name.
    :return: ``True`` if the name is rejected, ``False`` otherwise.
    """
    normalized = posixpath.normpath(name)
    if normalized != name:
        return True
    if normalized == ".":
        return True
    return normalized.startswith(("..", "/"))
127
+
128
+
129
def verify_artifact_path(artifact_path):
    """
    Validate an artifact path, raising when it is not in canonical form.

    Falsy values (``None``, ``""``) are accepted without any check.

    :param artifact_path: The artifact path to validate.
    :raises MlFoundryException: If ``path_not_unique(artifact_path)`` is true.
    """
    if not artifact_path:
        return
    if path_not_unique(artifact_path):
        raise MlFoundryException(
            f"Invalid artifact path: {artifact_path!r}. {bad_path_message(artifact_path)!r}"
        )
134
+
135
+
136
+ class _PartNumberEtag(NamedTuple):
137
+ part_number: int
138
+ etag: str
139
+
140
+
141
def _get_s3_compatible_completion_body(multi_parts: List[_PartNumberEtag]) -> str:
    """
    Render the XML body that finalizes an S3-compatible multipart upload.

    One ``<Part>`` element, holding the part number and its ETag, is emitted
    per entry of ``multi_parts``, in the order given.

    :param multi_parts: The uploaded parts.
    :return: The ``<CompleteMultipartUpload>`` document as a string.
    """
    pieces = ["<CompleteMultipartUpload>\n"]
    for entry in multi_parts:
        pieces.extend(
            (
                " <Part>\n",
                f" <PartNumber>{entry.part_number}</PartNumber>\n",
                f" <ETag>{entry.etag}</ETag>\n",
                " </Part>\n",
            )
        )
    pieces.append("</CompleteMultipartUpload>")
    return "".join(pieces)
150
+
151
+
152
+ def _get_azure_blob_completion_body(block_ids: List[str]) -> str:
153
+ body = "<BlockList>\n"
154
+ for block_id in block_ids:
155
+ body += f"<Uncommitted>{block_id}</Uncommitted> "
156
+ body += "</BlockList>"
157
+ return body
158
+
159
+
160
+ class _FileMultiPartInfo(NamedTuple):
161
+ num_parts: int
162
+ part_size: int
163
+ file_size: int
164
+
165
+
166
def _decide_file_parts(file_path: str) -> _FileMultiPartInfo:
    """
    Work out how the file at ``file_path`` should be split for upload.

    * Files below ``_MIN_BYTES_REQUIRED_FOR_MULTIPART``, or any file when
      multipart uploads are disabled, are treated as a single part.
    * Otherwise the default part size is used when it yields no more than
      ``_MAX_NUM_PARTS_FOR_MULTIPART`` parts.
    * Otherwise the part size is enlarged (and aligned to the mmap allocation
      granularity) so the number of parts stays within that limit.

    :param file_path: Path of the local file to upload.
    :return: The number of parts, the part size and the file size.
    :raises ValueError: If even the enlarged part size would exceed
                        ``_MAX_PART_SIZE_BYTES_FOR_MULTIPART``.
    """
    size = os.path.getsize(file_path)

    if _MULTIPART_DISABLED or size < _MIN_BYTES_REQUIRED_FOR_MULTIPART:
        return _FileMultiPartInfo(num_parts=1, part_size=size, file_size=size)

    default_part_count = math.ceil(size / _PART_SIZE_BYTES_FOR_MULTIPART)
    if default_part_count <= _MAX_NUM_PARTS_FOR_MULTIPART:
        return _FileMultiPartInfo(
            num_parts=default_part_count,
            part_size=_PART_SIZE_BYTES_FOR_MULTIPART,
            file_size=size,
        )

    # Too many default-sized parts: grow the part size just enough to fit
    # within the maximum part count, keeping it aligned for mmap.
    enlarged_part_size = _align_part_size_with_mmap_allocation_granularity(
        math.ceil(size / _MAX_NUM_PARTS_FOR_MULTIPART)
    )
    if enlarged_part_size > _MAX_PART_SIZE_BYTES_FOR_MULTIPART:
        raise ValueError(
            f"file {file_path!r} is too big for upload. Multipart chunk"
            f" size {enlarged_part_size} is higher"
            f" than {_MAX_PART_SIZE_BYTES_FOR_MULTIPART}"
        )
    return _FileMultiPartInfo(
        num_parts=math.ceil(size / enlarged_part_size),
        part_size=enlarged_part_size,
        file_size=size,
    )
193
+
194
+
195
def _signed_url_upload_file(
    signed_url: SignedURLDto,
    local_file: str,
    progress_bar: Progress,
    abort_event: Optional[Event] = None,
):
    """
    Upload ``local_file`` with a single PUT request to ``signed_url.signed_url``.

    An empty file is uploaded with an empty body and no progress task. For any
    other file a task is added to ``progress_bar`` and advanced as the file is
    read by the HTTP client.

    :param signed_url: Holder of the pre-signed URL to PUT the content to.
    :param local_file: Path of the local file to upload.
    :param progress_bar: Progress display that receives the upload task.
    :param abort_event: Optional event; once it is set, the next read of the
                        file raises and the upload is aborted.
    :raises Exception: With the message "aborting upload" when ``abort_event``
                       is set while the file is being read.
    """
    # Stat the file once; the same size serves the emptiness check and the
    # progress total instead of hitting the filesystem on every chunk read.
    file_size = os.stat(local_file).st_size
    if file_size == 0:
        with cloud_storage_http_request(
            method="put", url=signed_url.signed_url, data=""
        ) as response:
            # Hand the response object itself to the helper, as every other
            # call site does. `response.raise_for_status()` returns None, so
            # passing its result gave the helper nothing to inspect.
            augmented_raise_for_status(response)
        return

    task_progress_bar = progress_bar.add_task(
        f"[green]Uploading {local_file}:", start=True
    )

    def callback(length):
        # Invoked by the wrapper below around each `read`; `length` is used
        # as the amount to advance the progress task by.
        progress_bar.update(task_progress_bar, advance=length, total=file_size)
        if abort_event and abort_event.is_set():
            raise Exception("aborting upload")

    with open(local_file, "rb") as file:
        # NOTE: Azure Put Blob does not support Transfer Encoding header.
        wrapped_file = CallbackIOWrapper(callback, file, "read")
        with cloud_storage_http_request(
            method="put", url=signed_url.signed_url, data=wrapped_file
        ) as response:
            augmented_raise_for_status(response)
226
+
227
+
228
def _download_file_using_http_uri(
    http_uri,
    download_path,
    chunk_size=100000000,
    callback: Optional[Callable[[int, int], Any]] = None,
):
    """
    Stream the content at ``http_uri`` into the local file ``download_path``.

    The response body is consumed ``chunk_size`` bytes at a time so that a
    large file never has to be held in memory as a whole.
    Note: this function is meant for downloading files through pre-signed
    URLs from various cloud providers.

    :param http_uri: URL to issue the GET request against.
    :param download_path: Local path the content is written to.
    :param chunk_size: Number of bytes requested per chunk.
    :param callback: Optional callable invoked for every chunk with the chunk
                     length and the value of the ``Content-Length`` header
                     (0 when the header is absent).
    """
    with cloud_storage_http_request(
        method="get", url=http_uri, stream=True
    ) as response:
        augmented_raise_for_status(response)
        total_bytes = int(response.headers.get("Content-Length", 0))
        with open(download_path, "wb") as sink:
            for piece in response.iter_content(chunk_size=chunk_size):
                if callback:
                    callback(len(piece), total_bytes)
                if not piece:
                    # An empty chunk ends the download.
                    break
                sink.write(piece)
252
+
253
+
254
class _CallbackIOWrapperForMultiPartUpload(CallbackIOWrapper):
    """
    ``CallbackIOWrapper`` that also answers ``len()`` with a fixed length.

    NOTE(review): presumably the length lets the HTTP client learn the body
    size up front - confirm against the client in use.
    """

    def __init__(self, callback, stream, method, length: int):
        # The length is stored on the wrapper itself (not on the wrapped
        # stream) before the base class is initialised.
        self.wrapper_setattr("_length", length)
        super().__init__(callback, stream, method)

    def __len__(self):
        stored_length = self.wrapper_getattr("_length")
        return stored_length
261
+
262
+
263
def _file_part_upload(
    url: str,
    file_path: str,
    seek: int,
    length: int,
    file_size: int,
    abort_event: Optional[Event] = None,
    method: str = "put",
):
    """
    Upload one slice of ``file_path`` to ``url``.

    The slice starts at byte offset ``seek`` and spans ``length`` bytes, or
    fewer when the end of the file comes first. It is memory-mapped read-only
    rather than read into a buffer.

    :param url: Destination URL for this part.
    :param file_path: Path of the local file the part is taken from.
    :param seek: Byte offset of the part within the file.
    :param length: Nominal size of the part in bytes.
    :param file_size: Total size of the file in bytes.
    :param abort_event: Optional event; once set, the next read raises.
    :param method: HTTP method used for the request.
    :return: The HTTP response of the upload request.
    :raises Exception: With the message "aborting upload" when ``abort_event``
                       is set while the part is being read.
    """

    def _raise_if_aborted(*_, **__):
        if abort_event and abort_event.is_set():
            raise Exception("aborting upload")

    # The final part may be shorter than the nominal part length.
    bytes_to_map = min(file_size - seek, length)

    with open(file_path, "rb") as source:
        with mmap.mmap(
            source.fileno(),
            length=bytes_to_map,
            offset=seek,
            access=mmap.ACCESS_READ,
        ) as mapped_file:
            body = _CallbackIOWrapperForMultiPartUpload(
                _raise_if_aborted, mapped_file, "read", len(mapped_file)
            )
            with cloud_storage_http_request(
                method=method, url=url, data=body
            ) as response:
                augmented_raise_for_status(response)
                return response
293
+
294
+
295
def _s3_compatible_multipart_upload(
    multipart_upload: MultiPartUploadDto,
    local_file: str,
    multipart_info: _FileMultiPartInfo,
    executor: ThreadPoolExecutor,
    progress_bar: Progress,
    abort_event: Optional[Event] = None,
):
    """
    Upload ``local_file`` in parts through S3-compatible signed URLs.

    Every part is submitted to ``executor``. When all parts are uploaded, the
    upload is finalized by POSTing the list of part numbers and ETags to
    ``multipart_upload.finalize_signed_url``.

    :param multipart_upload: Holder of the per-part signed URLs and of the
                             finalize signed URL.
    :param local_file: Path of the local file to upload.
    :param multipart_info: Part count, part size and file size to use.
    :param executor: Thread pool the part uploads run on.
    :param progress_bar: Progress display that receives the upload task.
    :param abort_event: Optional event used to stop in-flight part uploads;
                        a fresh one is created when omitted.
    :raises Exception: The exception of a failed part upload, or the HTTP
                       error of the finalize request.
    """
    abort_event = abort_event or Event()
    # Filled in by the worker threads, one entry per uploaded part.
    parts = []

    multi_part_upload_progress = progress_bar.add_task(
        f"[green]Uploading {local_file}:", start=True
    )

    def upload(part_number: int, seek: int) -> None:
        # `part_number` is the 0-based index into `part_signed_urls`; logs
        # and the completion body use the 1-based number.
        logger.debug(
            "Uploading part %d/%d of %s",
            part_number + 1,
            multipart_info.num_parts,
            local_file,
        )
        response = _file_part_upload(
            url=multipart_upload.part_signed_urls[part_number].url,
            file_path=local_file,
            seek=seek,
            length=multipart_info.part_size,
            file_size=multipart_info.file_size,
            abort_event=abort_event,
        )
        logger.debug(
            "Uploaded part %d/%d of %s",
            part_number + 1,
            multipart_info.num_parts,
            local_file,
        )
        progress_bar.update(
            multi_part_upload_progress,
            # The final part can be shorter than `part_size`; advance by the
            # bytes actually sent so the task never overshoots its total.
            advance=min(
                multipart_info.part_size, multipart_info.file_size - seek
            ),
            total=multipart_info.file_size,
        )
        etag = response.headers["ETag"]
        parts.append(_PartNumberEtag(etag=etag, part_number=part_number + 1))

    futures: List[Future] = []
    for part_number, seek in enumerate(
        range(0, multipart_info.file_size, multipart_info.part_size)
    ):
        future = executor.submit(upload, part_number=part_number, seek=seek)
        futures.append(future)

    # Returns as soon as one part fails, or once every part has finished.
    done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
    if len(not_done) > 0:
        # A part failed while others were still pending: signal running
        # uploads to stop and cancel the ones that have not started.
        abort_event.set()
    for future in not_done:
        future.cancel()
    for future in done:
        error = future.exception()
        if error is not None:
            raise error

    logger.debug("Finalizing multipart upload of %s", local_file)
    # Workers finish in arbitrary order; the completion body lists the parts
    # in ascending part-number order.
    parts = sorted(parts, key=lambda part: part.part_number)
    response = requests.post(
        multipart_upload.finalize_signed_url.signed_url,
        data=_get_s3_compatible_completion_body(parts),
        timeout=2 * 60,
    )
    response.raise_for_status()
    logger.debug("Multipart upload of %s completed", local_file)
364
+
365
+
366
def _azure_multi_part_upload(
    multipart_upload: MultiPartUploadDto,
    local_file: str,
    multipart_info: _FileMultiPartInfo,
    executor: ThreadPoolExecutor,
    progress_bar: Progress,
    abort_event: Optional[Event] = None,
):
    """
    Upload ``local_file`` in blocks through Azure signed URLs.

    Every block is submitted to ``executor``. When all blocks are uploaded and
    ``multipart_upload.azure_blob_block_ids`` is non-empty, the block list is
    committed with a PUT to ``multipart_upload.finalize_signed_url``.

    :param multipart_upload: Holder of the per-part signed URLs, the block ids
                             and the finalize signed URL.
    :param local_file: Path of the local file to upload.
    :param multipart_info: Part count, part size and file size to use.
    :param executor: Thread pool the block uploads run on.
    :param progress_bar: Progress display that receives the upload task.
    :param abort_event: Optional event used to stop in-flight block uploads;
                        a fresh one is created when omitted.
    :raises Exception: The exception of a failed block upload, or the HTTP
                       error of the finalize request.
    """
    abort_event = abort_event or Event()

    multi_part_upload_progress = progress_bar.add_task(
        f"[green]Uploading {local_file}:", start=True
    )

    def upload(part_number: int, seek: int):
        # `part_number` is the 0-based index into `part_signed_urls`; the
        # logs report the 1-based number.
        logger.debug(
            "Uploading part %d/%d of %s",
            part_number + 1,
            multipart_info.num_parts,
            local_file,
        )
        _file_part_upload(
            url=multipart_upload.part_signed_urls[part_number].url,
            file_path=local_file,
            seek=seek,
            length=multipart_info.part_size,
            file_size=multipart_info.file_size,
            abort_event=abort_event,
        )
        progress_bar.update(
            multi_part_upload_progress,
            # The final block can be shorter than `part_size`; advance by the
            # bytes actually sent so the task never overshoots its total.
            advance=min(
                multipart_info.part_size, multipart_info.file_size - seek
            ),
            total=multipart_info.file_size,
        )
        logger.debug(
            "Uploaded part %d/%d of %s",
            part_number + 1,
            multipart_info.num_parts,
            local_file,
        )

    futures: List[Future] = []
    for part_number, seek in enumerate(
        range(0, multipart_info.file_size, multipart_info.part_size)
    ):
        future = executor.submit(upload, part_number=part_number, seek=seek)
        futures.append(future)

    # Returns as soon as one block fails, or once every block has finished.
    done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
    if len(not_done) > 0:
        # A block failed while others were still pending: signal running
        # uploads to stop and cancel the ones that have not started.
        abort_event.set()
    for future in not_done:
        future.cancel()
    for future in done:
        error = future.exception()
        if error is not None:
            raise error

    logger.debug("Finalizing multipart upload of %s", local_file)
    if multipart_upload.azure_blob_block_ids:
        response = requests.put(
            multipart_upload.finalize_signed_url.signed_url,
            data=_get_azure_blob_completion_body(
                block_ids=multipart_upload.azure_blob_block_ids
            ),
            timeout=2 * 60,
        )
        response.raise_for_status()
    logger.debug("Multipart upload of %s completed", local_file)
434
+
435
+
436
+ def _any_future_has_failed(futures) -> bool:
437
+ return any(
438
+ future.done() and not future.cancelled() and future.exception() is not None
439
+ for future in futures
440
+ )
441
+
442
+
443
class ArtifactIdentifier(BaseModel):
    """
    Identifies what a repository operates on: either an artifact version
    (by id) or a dataset (by FQN). Exactly one of the two must be provided.
    """

    artifact_version_id: Optional[uuid.UUID] = None
    dataset_fqn: Optional[str] = None

    @root_validator
    def _check_identifier_type(cls, values: Dict[str, Any]):
        """Reject instances where neither or both identifiers are set."""
        has_version_id = bool(values.get("artifact_version_id", False))
        has_dataset_fqn = bool(values.get("dataset_fqn", False))

        if not (has_version_id or has_dataset_fqn):
            raise MlFoundryException(
                "One of the version_id or dataset_fqn should be passed"
            )
        if has_version_id and has_dataset_fqn:
            raise MlFoundryException(
                "Exactly one of version_id or dataset_fqn should be passed"
            )
        return values
462
+
463
+
464
+ class MlFoundryArtifactsRepository:
465
def __init__(
    self,
    artifact_identifier: ArtifactIdentifier,
    api_client: Optional[ApiClient] = None,
):
    """
    Bind the repository to an artifact identifier and set up its API clients.

    :param artifact_identifier: The artifact version or dataset to operate on.
    :param api_client: Client used for the API calls; when omitted (or falsy),
                       one is obtained from ``_get_api_client()``.
    """
    self.artifact_identifier = artifact_identifier
    if not api_client:
        api_client = _get_api_client()
    self._api_client = api_client
    # Both API facades share the same underlying client.
    self._run_artifacts_api = RunArtifactsApi(api_client=api_client)
    self._mlfoundry_artifacts_api = MlfoundryArtifactsApi(api_client=api_client)
476
+
477
def _create_download_destination(
    self, src_artifact_path, dst_local_dir_path=None
) -> str:
    """
    Compute the local path an artifact should be downloaded to and make sure the
    directory that will contain it exists.

    The remote layout is mirrored under `dst_local_dir_path`: for a
    `src_artifact_path` of `dir1/file1.txt` the result is
    `<dst_local_dir_path>/dir1/file1.txt` and `<dst_local_dir_path>/dir1` is created
    when missing. Only the parent directory is created, never the file itself.

    :param src_artifact_path: A relative, POSIX-style path of the artifact inside the
                              repository. Trailing `/` characters are ignored.
    :param dst_local_dir_path: Local directory under which the destination is placed.
                               Although the parameter defaults to None, a real path
                               must be supplied: joining onto None raises TypeError.
    :return: `dst_local_dir_path` joined with `src_artifact_path` (without trailing `/`).
    """
    # Strip trailing separators first; otherwise dirname("a/b/") would be "a/b".
    relative_path = src_artifact_path.rstrip("/")
    parent_dir = os.path.join(dst_local_dir_path, posixpath.dirname(relative_path))
    destination = os.path.join(dst_local_dir_path, relative_path)
    if not os.path.exists(parent_dir):
        os.makedirs(parent_dir, exist_ok=True)
    return destination
508
+
509
+ # these methods should be named list_files, log_directory, log_file, etc
510
def list_artifacts(
    self, path=None, page_size=_LIST_FILES_PAGE_SIZE, **kwargs
) -> Iterator[FileInfoDto]:
    """Lazily yield every file entry listed under `path`, following pagination.

    Pages are fetched through `self.list_files` one at a time; the next page is
    requested only once the previous one has been consumed. Iteration stops
    when a page reports no `next_page_token`.

    Args:
        path: Remote path to list; passed through unchanged.
        page_size: Maximum number of entries requested per page.
        **kwargs: Accepted but not used.
    """
    page_token = None
    while True:
        page = self.list_files(
            artifact_identifier=self.artifact_identifier,
            path=path,
            page_size=page_size,
            page_token=page_token,
        )
        yield from page.files
        page_token = page.next_page_token
        if page_token is None:
            break
526
+
527
def log_artifacts(  # noqa: C901
    self, local_dir, artifact_path=None, progress=None
):
    """
    Upload every file found (recursively) under `local_dir`.

    Files that `_decide_file_parts` reports as a single part are submitted to a thread
    pool, using signed urls that are generated one batch at a time. Files that need
    several parts are uploaded afterwards, one file after another, with their parts
    sharing the same thread pool.

    Args:
        local_dir: Local directory whose contents are uploaded.
        artifact_path: Optional remote directory prefix. Leading `/` are stripped;
            when omitted an empty prefix is used.
        progress: Show or hide the progress bar. When None, the bar is shown only
            if stdout is a TTY.

    Raises:
        Exception: re-raises an exception from a failed single-part upload, after
            setting the abort event and cancelling uploads that had not finished.
    """
    if progress is None:
        progress = sys.stdout.isatty()

    remote_root = (artifact_path or "").lstrip(posixpath.sep)

    # Each entry is (remote path, local path, multipart info)
    single_part_files: List[Tuple[str, str, _FileMultiPartInfo]] = []
    multi_part_files: List[Tuple[str, str, _FileMultiPartInfo]] = []

    for current_dir, _, names in os.walk(local_dir):
        remote_dir = remote_root
        if current_dir != local_dir:
            # Mirror the sub-directory structure below the remote prefix
            remote_dir = posixpath.join(
                remote_root,
                relative_path_to_artifact_path(
                    os.path.relpath(current_dir, local_dir)
                ),
            )
        for name in names:
            source_file = os.path.join(current_dir, name)
            parts_info = _decide_file_parts(source_file)
            remote_file = posixpath.join(
                (remote_dir or "").lstrip(posixpath.sep),
                os.path.basename(source_file),
            )
            entry = (remote_file, source_file, parts_info)
            if parts_info.num_parts == 1:
                single_part_files.append(entry)
            else:
                multi_part_files.append(entry)

    abort_event = Event()

    with Progress(
        "[progress.description]{task.description}",
        BarColumn(),
        "[progress.percentage]{task.percentage:>3.0f}%",
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
        TimeElapsedColumn(),
        refresh_per_second=1,
        disable=not progress,
    ) as progress_bar, ThreadPoolExecutor(
        max_workers=_MAX_WORKERS_FOR_UPLOAD
    ) as executor:
        pending: List[Future] = []
        # Note: batching keeps the number of signed url requests low when there are many
        # files, at the (rare) risk of a signed url expiring before it gets used
        step = _GENERATE_SIGNED_URL_BATCH_SIZE
        for batch_start in range(0, len(single_part_files), step):
            # Do not request more urls once an earlier upload has already failed
            if _any_future_has_failed(pending):
                break
            batch = single_part_files[batch_start : batch_start + step]
            logger.debug("Generating write signed urls for a batch ...")
            signed_urls = self.get_signed_urls_for_write(
                artifact_identifier=self.artifact_identifier,
                paths=[remote_file for remote_file, _, _ in batch],
            )
            for (remote_file, source_file, parts_info), signed_url in zip(
                batch, signed_urls
            ):
                pending.append(
                    executor.submit(
                        self._log_artifact,
                        local_file=source_file,
                        artifact_path=remote_file,
                        multipart_info=parts_info,
                        signed_url=signed_url,
                        abort_event=abort_event,
                        executor_for_multipart_upload=None,
                        progress_bar=progress_bar,
                    )
                )

        finished, unfinished = wait(pending, return_when=FIRST_EXCEPTION)
        if unfinished:
            # wait() came back early, i.e. something failed: tell running uploads to stop
            abort_event.set()
        for leftover in unfinished:
            leftover.cancel()
        for completed in finished:
            failure = completed.exception()
            if failure is not None:
                raise failure

        for remote_file, source_file, parts_info in multi_part_files:
            self._log_artifact(
                local_file=source_file,
                artifact_path=remote_file,
                signed_url=None,
                multipart_info=parts_info,
                executor_for_multipart_upload=executor,
                progress_bar=progress_bar,
            )
635
+
636
def _normal_upload(
    self,
    local_file: str,
    artifact_path: str,
    signed_url: Optional[SignedURLDto],
    progress_bar: Progress,
    abort_event: Optional[Event] = None,
):
    """Upload one file in a single request through a signed url.

    Args:
        local_file: Path of the local file to upload.
        artifact_path: Remote path the file is uploaded to.
        signed_url: Write url for `artifact_path`. When falsy, one is requested
            for this path first.
        progress_bar: Progress display handed to the upload helper. When it is
            disabled, an info log line is emitted instead.
        abort_event: Optional event handed to the upload helper.
    """
    if not signed_url:
        fetched_urls = self.get_signed_urls_for_write(
            artifact_identifier=self.artifact_identifier, paths=[artifact_path]
        )
        signed_url = fetched_urls[0]

    # Without a visible bar, logging is the only feedback the user gets
    if progress_bar.disable:
        logger.info("Uploading %s to %s", local_file, artifact_path)

    _signed_url_upload_file(
        signed_url=signed_url,
        local_file=local_file,
        abort_event=abort_event,
        progress_bar=progress_bar,
    )
    logger.debug("Uploaded %s to %s", local_file, artifact_path)
659
+
660
def _multipart_upload(
    self,
    local_file: str,
    artifact_path: str,
    multipart_info: _FileMultiPartInfo,
    executor: ThreadPoolExecutor,
    progress_bar: Progress,
    abort_event: Optional[Event] = None,
):
    """Upload one file in several parts, using the uploader matching the storage provider.

    A multipart upload is created for `artifact_path` first; the storage provider
    reported for it decides which upload routine runs.

    Args:
        local_file: Path of the local file to upload.
        artifact_path: Remote path the file is uploaded to.
        multipart_info: Part layout of the file; its `num_parts` is sent when
            creating the multipart upload.
        executor: Thread pool handed to the provider specific uploader.
        progress_bar: Progress display handed to the uploader. When it is
            disabled, an info log line is emitted instead.
        abort_event: Optional event handed to the uploader.

    Raises:
        NotImplementedError: if the storage provider is neither S3 compatible
            nor Azure Blob.
    """
    if progress_bar.disable:
        logger.info(
            "Uploading %s to %s using multipart upload", local_file, artifact_path
        )

    multipart_upload = self.create_multipart_upload_for_identifier(
        artifact_identifier=self.artifact_identifier,
        path=artifact_path,
        num_parts=multipart_info.num_parts,
    )

    provider = multipart_upload.storage_provider
    if provider is MultiPartUploadStorageProvider.S3_COMPATIBLE:
        upload_parts = _s3_compatible_multipart_upload
    elif provider is MultiPartUploadStorageProvider.AZURE_BLOB:
        upload_parts = _azure_multi_part_upload
    else:
        raise NotImplementedError()

    upload_parts(
        multipart_upload=multipart_upload,
        local_file=local_file,
        executor=executor,
        multipart_info=multipart_info,
        abort_event=abort_event,
        progress_bar=progress_bar,
    )
705
+
706
def _log_artifact(
    self,
    local_file: str,
    artifact_path: str,
    multipart_info: _FileMultiPartInfo,
    progress_bar: Progress,
    signed_url: Optional[SignedURLDto] = None,
    abort_event: Optional[Event] = None,
    executor_for_multipart_upload: Optional[ThreadPoolExecutor] = None,
):
    """Upload a single file, choosing between a normal and a multipart upload.

    Args:
        local_file: Path of the local file to upload.
        artifact_path: Remote path the file is uploaded to.
        multipart_info: Part layout of the file. Exactly one part means a normal
            upload, anything else a multipart upload.
        progress_bar: Progress display handed to the upload routine.
        signed_url: Optional pre-generated write url; only used by the normal upload.
        abort_event: Optional event handed to the upload routine, for both the
            normal and the multipart path.
        executor_for_multipart_upload: Thread pool used for a multipart upload.
            When not given, a temporary pool is created for the duration of
            the upload.

    Returns:
        Whatever the selected upload routine returns.
    """
    if multipart_info.num_parts == 1:
        return self._normal_upload(
            local_file=local_file,
            artifact_path=artifact_path,
            signed_url=signed_url,
            abort_event=abort_event,
            progress_bar=progress_bar,
        )

    def _upload_with(executor: ThreadPoolExecutor):
        # abort_event used to be dropped on this path; forward it so that a
        # caller supplied abort signal also reaches multipart uploads.
        return self._multipart_upload(
            local_file=local_file,
            artifact_path=artifact_path,
            executor=executor,
            multipart_info=multipart_info,
            progress_bar=progress_bar,
            abort_event=abort_event,
        )

    if executor_for_multipart_upload:
        return _upload_with(executor_for_multipart_upload)

    # No pool supplied by the caller: use a short-lived one for this upload only
    with ThreadPoolExecutor(max_workers=_MAX_WORKERS_FOR_UPLOAD) as executor:
        return _upload_with(executor)
742
+
743
def log_artifact(self, local_file: str, artifact_path: Optional[str] = None):
    """Upload a single local file.

    The file keeps its base name and is placed under `artifact_path` (leading
    `/` stripped). The progress bar is always created disabled here.

    Args:
        local_file: Path of the local file to upload.
        artifact_path: Optional remote directory to upload the file into.
    """
    remote_dir = (artifact_path or "").lstrip(posixpath.sep)
    remote_file = posixpath.join(remote_dir, os.path.basename(local_file))

    progress_bar = Progress(
        "[progress.description]{task.description}",
        BarColumn(),
        "[progress.percentage]{task.percentage:>3.0f}%",
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
        TimeElapsedColumn(),
        refresh_per_second=1,
        disable=True,
    )
    with progress_bar:
        self._log_artifact(
            local_file=local_file,
            artifact_path=remote_file,
            multipart_info=_decide_file_parts(local_file),
            progress_bar=progress_bar,
        )
764
+
765
def _is_directory(self, artifact_path):
    """Return True when listing `artifact_path` yields at least one entry.

    Only the first listed entry is looked at, hence the small page size.
    """
    listing = self.list_artifacts(artifact_path, page_size=3)
    return any(True for _ in listing)
769
+
770
def download_artifacts(  # noqa: C901
    self,
    artifact_path,
    dst_path=None,
    overwrite: bool = False,
    progress: Optional[bool] = None,
):
    """
    Download an artifact file or directory to a local directory and return the local
    path of the result. The caller is responsible for managing the lifecycle of the
    downloaded artifacts.

    Args:
        artifact_path: Relative source path to the desired artifacts.
        dst_path: Path of the local destination directory to download into. It must
            already exist and be a directory. If unspecified, a new temporary
            directory is created and used.
        overwrite: Whether files that already exist at/inside `dst_path` may be
            overwritten. Only checked when downloading a directory into a
            caller supplied `dst_path`.
        progress: Show or hide progress bar. When None, the bar is shown only if
            stdout is a TTY.

    Returns:
        str: Local path of the downloaded file, or of the downloaded directory.

    Raises:
        MlFoundryException: if `dst_path` does not exist or is not a directory, or
            if a file would be overwritten while `overwrite` is False.
    """
    if progress is None:
        progress = sys.stdout.isatty()

    using_temp_dir = dst_path is None
    if using_temp_dir:
        dst_path = tempfile.mkdtemp()

    dst_path = os.path.abspath(dst_path)
    if using_temp_dir:
        logger.info(
            f"Using temporary directory {dst_path} as the download directory"
        )

    if not os.path.exists(dst_path):
        raise MlFoundryException(
            message=(
                "The destination path for downloaded artifacts does not"
                " exist! Destination path: {dst_path}".format(dst_path=dst_path)
            ),
        )
    if not os.path.isdir(dst_path):
        raise MlFoundryException(
            message=(
                "The destination path for downloaded artifacts must be a directory!"
                " Destination path: {dst_path}".format(dst_path=dst_path)
            ),
        )

    progress_bar = Progress(
        "[progress.description]{task.description}",
        BarColumn(),
        "[progress.percentage]{task.percentage:>3.0f}%",
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
        TimeElapsedColumn(),
        refresh_per_second=1,
        disable=not progress,
    )

    try:
        progress_bar.start()

        # A path that lists no entries is treated as a single file
        if not self._is_directory(artifact_path):
            return self._download_artifact(
                src_artifact_path=artifact_path,
                dst_local_dir_path=dst_path,
                signed_url=None,
                progress_bar=progress_bar,
            )

        pending: List[Future] = []
        downloads: List[Tuple[str, str]] = []
        abort_event = Event()

        # Collect all files up front so an overwrite conflict is detected
        # before anything gets downloaded
        for remote_file, local_root in self._get_file_paths_recur(
            src_artifact_dir_path=artifact_path, dst_local_dir_path=dst_path
        ):
            target_file = os.path.join(local_root, remote_file)

            # A freshly created temporary directory cannot contain conflicts
            if (
                not using_temp_dir
                and os.path.exists(target_file)
                and not overwrite
            ):
                raise MlFoundryException(
                    f"File already exists at {target_file}, aborting download "
                    f"(set `overwrite` flag to overwrite this and any subsequent files)"
                )
            downloads.append((remote_file, local_root))

        with ThreadPoolExecutor(_MAX_WORKERS_FOR_DOWNLOAD) as executor:
            # Note: batching keeps the number of signed url requests low when there are
            # many files, at the (rare) risk of a signed url expiring before it gets used
            step = _GENERATE_SIGNED_URL_BATCH_SIZE
            for batch_start in range(0, len(downloads), step):
                # Do not request more urls once an earlier download has already failed
                if _any_future_has_failed(pending):
                    break
                batch = downloads[batch_start : batch_start + step]
                logger.debug("Generating read signed urls for a batch ...")
                signed_urls = self.get_signed_urls_for_read(
                    artifact_identifier=self.artifact_identifier,
                    paths=[remote_file for remote_file, _ in batch],
                )
                for (remote_file, local_root), signed_url in zip(batch, signed_urls):
                    pending.append(
                        executor.submit(
                            self._download_artifact,
                            src_artifact_path=remote_file,
                            dst_local_dir_path=local_root,
                            signed_url=signed_url,
                            abort_event=abort_event,
                            progress_bar=progress_bar,
                        )
                    )

            finished, unfinished = wait(pending, return_when=FIRST_EXCEPTION)
            if unfinished:
                # wait() came back early, i.e. something failed: tell running downloads to stop
                abort_event.set()
            for leftover in unfinished:
                leftover.cancel()
            for completed in finished:
                failure = completed.exception()
                if failure is not None:
                    raise failure

            return os.path.join(dst_path, artifact_path)
    except Exception as err:
        if using_temp_dir:
            logger.info(
                f"Error encountered, removing temporary download directory at {dst_path}"
            )
            rmtree(dst_path)  # remove temp directory alongside its contents
        raise err

    finally:
        progress_bar.stop()
919
+
920
+ # noinspection PyMethodOverriding
921
def _download_file(
    self,
    remote_file_path: str,
    local_path: str,
    progress_bar: Optional[Progress],
    signed_url: Optional[SignedURLDto],
    abort_event: Optional[Event] = None,
):
    """Download one remote file to `local_path` through a signed url.

    Args:
        remote_file_path: Remote path of the file; must be a non-empty string.
        local_path: Local file path the content is written to.
        progress_bar: Optional progress display. When given, a task is added
            to it and advanced as chunks arrive.
        signed_url: Read url for `remote_file_path`. When falsy, one is requested
            for this path first.
        abort_event: Optional event; once it is set, the download is aborted
            from within the progress callback.

    Raises:
        MlFoundryException: if `remote_file_path` is None or empty.
        Exception: with message "aborting download" when `abort_event` is set
            while the download is in progress.
    """
    if not remote_file_path:
        raise MlFoundryException(
            f"remote_file_path cannot be None or empty str {remote_file_path}"
        )
    if not signed_url:
        signed_url = self.get_signed_urls_for_read(
            artifact_identifier=self.artifact_identifier, paths=[remote_file_path]
        )[0]

    # Log only when no progress bar is being rendered, consistent with the
    # upload methods. The previous condition was inverted: it logged while the
    # bar was active and stayed silent when the bar was disabled.
    if progress_bar is None or progress_bar.disable:
        logger.info("Downloading %s to %s", remote_file_path, local_path)

    if progress_bar is not None:
        download_progress_bar = progress_bar.add_task(
            f"[green]Downloading to {remote_file_path}:", start=True
        )

    def callback(chunk, total_file_size):
        # Invoked by the download helper as data arrives
        if progress_bar is not None:
            progress_bar.update(
                download_progress_bar,
                advance=chunk,
                total=total_file_size,
            )
        if abort_event and abort_event.is_set():
            raise Exception("aborting download")

    _download_file_using_http_uri(
        http_uri=signed_url.signed_url,
        download_path=local_path,
        callback=callback,
    )
    logger.debug("Downloaded %s to %s", remote_file_path, local_path)
962
+
963
def _download_artifact(
    self,
    src_artifact_path,
    dst_local_dir_path,
    signed_url: Optional[SignedURLDto],
    progress_bar: Optional[Progress] = None,
    abort_event=None,
) -> str:
    """
    Download the file artifact `src_artifact_path` below the local directory
    `dst_local_dir_path` and return where it was written.

    :param src_artifact_path: A relative, POSIX-style path referring to a file artifact
                              inside the repository.
    :param dst_local_dir_path: Local destination directory. The file ends up in a
                               subdirectory of it when `src_artifact_path` contains
                               subdirectories; missing ones are created.
    :param signed_url: Read url for the file, or None to have one requested.
    :param progress_bar: Optional Rich progress bar used to display the download progress.
    :param abort_event: Optional event passed on to the file download.
    :return: A local filesystem path referring to the downloaded file.
    """
    destination = self._create_download_destination(
        src_artifact_path=src_artifact_path, dst_local_dir_path=dst_local_dir_path
    )
    download_kwargs = dict(
        remote_file_path=src_artifact_path,
        local_path=destination,
        signed_url=signed_url,
        abort_event=abort_event,
        progress_bar=progress_bar,
    )
    self._download_file(**download_kwargs)
    return destination
998
+
999
def _get_file_paths_recur(self, src_artifact_dir_path, dst_local_dir_path):
    """Walk a remote directory and yield `(remote file path, dst_local_dir_path)` pairs.

    Sub-directories are descended into recursively. A remote directory without
    any entries yields nothing, but its local counterpart is created so that
    empty directories are mirrored too.
    """
    local_dir = os.path.join(dst_local_dir_path, src_artifact_dir_path)

    entries = []
    for entry in self.list_artifacts(src_artifact_dir_path):
        # prevent infinite loop, sometimes the dir is recursively included
        if entry.path == "." or entry.path == src_artifact_dir_path:
            continue
        entries.append(entry)

    if not entries:  # empty dir
        if not os.path.exists(local_dir):
            os.makedirs(local_dir, exist_ok=True)
        return

    for entry in entries:
        if not entry.is_dir:
            yield entry.path, dst_local_dir_path
            continue
        yield from self._get_file_paths_recur(
            src_artifact_dir_path=entry.path,
            dst_local_dir_path=dst_local_dir_path,
        )
1018
+
1019
+ # TODO (chiragjn): Refactor these methods - if else is very inconvenient
1020
def get_signed_urls_for_read(
    self,
    artifact_identifier: ArtifactIdentifier,
    paths,
) -> List[SignedURLDto]:
    """Request signed urls for reading `paths` of an artifact version or a dataset.

    The artifact version endpoint is used when `artifact_version_id` is set,
    otherwise the dataset endpoint when `dataset_fqn` is set.

    Raises:
        ValueError: if the identifier carries neither of the two.
    """
    api = self._mlfoundry_artifacts_api

    if artifact_identifier.artifact_version_id:
        version_request = GetSignedURLsForArtifactVersionReadRequestDto(
            id=str(artifact_identifier.artifact_version_id), paths=paths
        )
        return api.get_signed_urls_for_read_post(
            get_signed_urls_for_artifact_version_read_request_dto=version_request
        ).signed_urls

    if artifact_identifier.dataset_fqn:
        dataset_request = GetSignedURLsForDatasetReadRequestDto(
            dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
        )
        return api.get_signed_urls_dataset_read_post(
            get_signed_urls_for_dataset_read_request_dto=dataset_request
        ).signed_urls

    raise ValueError(
        "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
    )
1044
+
1045
def get_signed_urls_for_write(
    self,
    artifact_identifier: ArtifactIdentifier,
    paths: List[str],
) -> List[SignedURLDto]:
    """Request signed urls for writing `paths` of an artifact version or a dataset.

    The artifact version endpoint is used when `artifact_version_id` is set,
    otherwise the dataset endpoint when `dataset_fqn` is set.

    Raises:
        ValueError: if the identifier carries neither of the two.
    """
    api = self._mlfoundry_artifacts_api

    if artifact_identifier.artifact_version_id:
        version_request = GetSignedURLsForArtifactVersionWriteRequestDto(
            id=str(artifact_identifier.artifact_version_id), paths=paths
        )
        return api.get_signed_urls_for_write_post(
            get_signed_urls_for_artifact_version_write_request_dto=version_request
        ).signed_urls

    if artifact_identifier.dataset_fqn:
        dataset_request = GetSignedURLForDatasetWriteRequestDto(
            dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
        )
        return api.get_signed_urls_for_dataset_write_post(
            get_signed_url_for_dataset_write_request_dto=dataset_request
        ).signed_urls

    raise ValueError(
        "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
    )
1069
+
1070
def create_multipart_upload_for_identifier(
    self,
    artifact_identifier: ArtifactIdentifier,
    path,
    num_parts,
) -> MultiPartUploadDto:
    """Create a multipart upload for `path` on an artifact version or a dataset.

    The artifact version endpoint is used when `artifact_version_id` is set,
    otherwise the dataset endpoint when `dataset_fqn` is set.

    Args:
        artifact_identifier: Target artifact version or dataset.
        path: Remote path the upload is created for.
        num_parts: Number of parts the upload is created with.

    Returns:
        The `multipart_upload` of the API response.

    Raises:
        ValueError: if the identifier carries neither of the two.
    """
    api = self._mlfoundry_artifacts_api

    if artifact_identifier.artifact_version_id:
        version_request = CreateMultiPartUploadRequestDto(
            artifact_version_id=str(artifact_identifier.artifact_version_id),
            path=path,
            num_parts=num_parts,
        )
        return api.create_multi_part_upload_post(
            create_multi_part_upload_request_dto=version_request
        ).multipart_upload

    if artifact_identifier.dataset_fqn:
        dataset_request = CreateMultiPartUploadForDatasetRequestDto(
            dataset_fqn=artifact_identifier.dataset_fqn,
            path=path,
            num_parts=num_parts,
        )
        return api.create_multipart_upload_for_dataset_post(
            create_multi_part_upload_for_dataset_request_dto=dataset_request
        ).multipart_upload

    raise ValueError(
        "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
    )
1099
+
1100
def list_files(
    self, artifact_identifier: ArtifactIdentifier, path, page_size, page_token
) -> Union[ListFilesForDatasetResponseDto, ListFilesForArtifactVersionsResponseDto]:
    """Fetch one page of the file listing under `path`.

    The dataset endpoint is used when `dataset_fqn` is set; in every other case
    the artifact version endpoint is called with `artifact_version_id`.

    Args:
        artifact_identifier: Target artifact version or dataset.
        path: Remote path to list.
        page_size: Sent as `max_results` of the request.
        page_token: Token of the page to fetch, None for the first page.
    """
    api = self._mlfoundry_artifacts_api

    if artifact_identifier.dataset_fqn:
        dataset_request = ListFilesForDatasetRequestDto(
            dataset_fqn=artifact_identifier.dataset_fqn,
            path=path,
            max_results=page_size,
            page_token=page_token,
        )
        return api.list_files_for_dataset_post(
            list_files_for_dataset_request_dto=dataset_request
        )

    version_request = ListFilesForArtifactVersionRequestDto(
        id=str(artifact_identifier.artifact_version_id),
        path=path,
        max_results=page_size,
        page_token=page_token,
    )
    return api.list_files_for_artifact_version_post(
        list_files_for_artifact_version_request_dto=version_request
    )