truefoundry 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truefoundry might be problematic. Click here for more details.

Files changed (253)
  1. truefoundry/__init__.py +2 -0
  2. truefoundry/autodeploy/agents/developer.py +1 -1
  3. truefoundry/autodeploy/agents/project_identifier.py +2 -2
  4. truefoundry/autodeploy/agents/tester.py +1 -1
  5. truefoundry/autodeploy/cli.py +1 -1
  6. truefoundry/autodeploy/tools/list_files.py +1 -1
  7. truefoundry/cli/__main__.py +3 -17
  8. truefoundry/common/__init__.py +0 -0
  9. truefoundry/{deploy/lib/auth → common}/auth_service_client.py +50 -40
  10. truefoundry/common/constants.py +12 -0
  11. truefoundry/{deploy/lib/auth → common}/credential_file_manager.py +7 -7
  12. truefoundry/{deploy/lib/auth → common}/credential_provider.py +10 -23
  13. truefoundry/common/entities.py +124 -0
  14. truefoundry/common/exceptions.py +12 -0
  15. truefoundry/common/request_utils.py +84 -0
  16. truefoundry/common/servicefoundry_client.py +91 -0
  17. truefoundry/common/utils.py +56 -0
  18. truefoundry/deploy/auto_gen/models.py +4 -6
  19. truefoundry/deploy/cli/cli.py +3 -1
  20. truefoundry/deploy/cli/commands/apply_command.py +1 -1
  21. truefoundry/deploy/cli/commands/build_command.py +1 -1
  22. truefoundry/deploy/cli/commands/deploy_command.py +1 -1
  23. truefoundry/deploy/cli/commands/login_command.py +2 -2
  24. truefoundry/deploy/cli/commands/patch_application_command.py +1 -1
  25. truefoundry/deploy/cli/commands/patch_command.py +1 -1
  26. truefoundry/deploy/cli/commands/terminate_comand.py +1 -1
  27. truefoundry/deploy/cli/util.py +1 -1
  28. truefoundry/deploy/function_service/remote/remote.py +1 -1
  29. truefoundry/deploy/lib/auth/servicefoundry_session.py +2 -2
  30. truefoundry/deploy/lib/clients/servicefoundry_client.py +120 -159
  31. truefoundry/deploy/lib/const.py +1 -35
  32. truefoundry/deploy/lib/exceptions.py +0 -16
  33. truefoundry/deploy/lib/model/entity.py +1 -112
  34. truefoundry/deploy/lib/session.py +14 -42
  35. truefoundry/deploy/lib/util.py +0 -37
  36. truefoundry/{python_deploy_codegen.py → deploy/python_deploy_codegen.py} +2 -2
  37. truefoundry/deploy/v2/lib/deploy.py +3 -3
  38. truefoundry/deploy/v2/lib/deployable_patched_models.py +1 -1
  39. truefoundry/langchain/truefoundry_chat.py +1 -1
  40. truefoundry/langchain/truefoundry_embeddings.py +1 -1
  41. truefoundry/langchain/truefoundry_llm.py +1 -1
  42. truefoundry/langchain/utils.py +0 -41
  43. truefoundry/ml/__init__.py +37 -6
  44. truefoundry/ml/artifact/__init__.py +0 -0
  45. truefoundry/ml/artifact/truefoundry_artifact_repo.py +1161 -0
  46. truefoundry/ml/autogen/__init__.py +0 -0
  47. truefoundry/ml/autogen/client/__init__.py +370 -0
  48. truefoundry/ml/autogen/client/api/__init__.py +16 -0
  49. truefoundry/ml/autogen/client/api/auth_api.py +184 -0
  50. truefoundry/ml/autogen/client/api/deprecated_api.py +605 -0
  51. truefoundry/ml/autogen/client/api/experiments_api.py +1944 -0
  52. truefoundry/ml/autogen/client/api/health_api.py +299 -0
  53. truefoundry/ml/autogen/client/api/metrics_api.py +371 -0
  54. truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +7213 -0
  55. truefoundry/ml/autogen/client/api/python_deployment_config_api.py +201 -0
  56. truefoundry/ml/autogen/client/api/run_artifacts_api.py +231 -0
  57. truefoundry/ml/autogen/client/api/runs_api.py +2919 -0
  58. truefoundry/ml/autogen/client/api_client.py +822 -0
  59. truefoundry/ml/autogen/client/api_response.py +30 -0
  60. truefoundry/ml/autogen/client/configuration.py +489 -0
  61. truefoundry/ml/autogen/client/exceptions.py +161 -0
  62. truefoundry/ml/autogen/client/models/__init__.py +341 -0
  63. truefoundry/ml/autogen/client/models/add_custom_metrics_to_model_version_request_dto.py +69 -0
  64. truefoundry/ml/autogen/client/models/add_features_to_model_version_request_dto.py +83 -0
  65. truefoundry/ml/autogen/client/models/agent.py +125 -0
  66. truefoundry/ml/autogen/client/models/agent_app.py +118 -0
  67. truefoundry/ml/autogen/client/models/agent_open_api_tool.py +143 -0
  68. truefoundry/ml/autogen/client/models/agent_open_api_tool_with_fqn.py +144 -0
  69. truefoundry/ml/autogen/client/models/agent_with_fqn.py +127 -0
  70. truefoundry/ml/autogen/client/models/artifact_dto.py +115 -0
  71. truefoundry/ml/autogen/client/models/artifact_response_dto.py +75 -0
  72. truefoundry/ml/autogen/client/models/artifact_type.py +39 -0
  73. truefoundry/ml/autogen/client/models/artifact_version_dto.py +141 -0
  74. truefoundry/ml/autogen/client/models/artifact_version_response_dto.py +77 -0
  75. truefoundry/ml/autogen/client/models/artifact_version_status.py +35 -0
  76. truefoundry/ml/autogen/client/models/assistant_message.py +89 -0
  77. truefoundry/ml/autogen/client/models/authorize_user_for_model_request_dto.py +69 -0
  78. truefoundry/ml/autogen/client/models/authorize_user_for_model_version_request_dto.py +69 -0
  79. truefoundry/ml/autogen/client/models/blob_storage_reference.py +93 -0
  80. truefoundry/ml/autogen/client/models/body_get_search_runs_get.py +72 -0
  81. truefoundry/ml/autogen/client/models/chat_prompt.py +156 -0
  82. truefoundry/ml/autogen/client/models/chat_prompt_messages_inner.py +171 -0
  83. truefoundry/ml/autogen/client/models/columns_dto.py +73 -0
  84. truefoundry/ml/autogen/client/models/content.py +153 -0
  85. truefoundry/ml/autogen/client/models/content1.py +153 -0
  86. truefoundry/ml/autogen/client/models/content2.py +174 -0
  87. truefoundry/ml/autogen/client/models/content2_any_of_inner.py +150 -0
  88. truefoundry/ml/autogen/client/models/create_artifact_request_dto.py +74 -0
  89. truefoundry/ml/autogen/client/models/create_artifact_response_dto.py +65 -0
  90. truefoundry/ml/autogen/client/models/create_artifact_version_request_dto.py +74 -0
  91. truefoundry/ml/autogen/client/models/create_artifact_version_response_dto.py +65 -0
  92. truefoundry/ml/autogen/client/models/create_dataset_request_dto.py +76 -0
  93. truefoundry/ml/autogen/client/models/create_experiment_request_dto.py +94 -0
  94. truefoundry/ml/autogen/client/models/create_experiment_response_dto.py +67 -0
  95. truefoundry/ml/autogen/client/models/create_model_version_request_dto.py +95 -0
  96. truefoundry/ml/autogen/client/models/create_multi_part_upload_for_dataset_request_dto.py +73 -0
  97. truefoundry/ml/autogen/client/models/create_multi_part_upload_for_dataset_response_dto.py +79 -0
  98. truefoundry/ml/autogen/client/models/create_multi_part_upload_request_dto.py +73 -0
  99. truefoundry/ml/autogen/client/models/create_python_deployment_config_request_dto.py +72 -0
  100. truefoundry/ml/autogen/client/models/create_python_deployment_config_response_dto.py +67 -0
  101. truefoundry/ml/autogen/client/models/create_run_request_dto.py +97 -0
  102. truefoundry/ml/autogen/client/models/create_run_response_dto.py +75 -0
  103. truefoundry/ml/autogen/client/models/dataset_dto.py +112 -0
  104. truefoundry/ml/autogen/client/models/dataset_response_dto.py +75 -0
  105. truefoundry/ml/autogen/client/models/delete_artifact_versions_request_dto.py +65 -0
  106. truefoundry/ml/autogen/client/models/delete_dataset_request_dto.py +74 -0
  107. truefoundry/ml/autogen/client/models/delete_model_version_request_dto.py +65 -0
  108. truefoundry/ml/autogen/client/models/delete_run_request.py +65 -0
  109. truefoundry/ml/autogen/client/models/delete_tag_request_dto.py +68 -0
  110. truefoundry/ml/autogen/client/models/experiment_dto.py +127 -0
  111. truefoundry/ml/autogen/client/models/experiment_id_request_dto.py +67 -0
  112. truefoundry/ml/autogen/client/models/experiment_response_dto.py +75 -0
  113. truefoundry/ml/autogen/client/models/experiment_tag_dto.py +69 -0
  114. truefoundry/ml/autogen/client/models/feature_dto.py +68 -0
  115. truefoundry/ml/autogen/client/models/feature_value_type.py +35 -0
  116. truefoundry/ml/autogen/client/models/file_info_dto.py +76 -0
  117. truefoundry/ml/autogen/client/models/finalize_artifact_version_request_dto.py +101 -0
  118. truefoundry/ml/autogen/client/models/get_experiment_response_dto.py +88 -0
  119. truefoundry/ml/autogen/client/models/get_latest_run_log_response_dto.py +75 -0
  120. truefoundry/ml/autogen/client/models/get_metric_history_response.py +79 -0
  121. truefoundry/ml/autogen/client/models/get_signed_url_for_dataset_write_request_dto.py +68 -0
  122. truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_read_request_dto.py +68 -0
  123. truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_read_response_dto.py +81 -0
  124. truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_write_request_dto.py +69 -0
  125. truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_write_response_dto.py +83 -0
  126. truefoundry/ml/autogen/client/models/get_signed_urls_for_dataset_read_request_dto.py +68 -0
  127. truefoundry/ml/autogen/client/models/get_signed_urls_for_dataset_read_response_dto.py +81 -0
  128. truefoundry/ml/autogen/client/models/get_signed_urls_for_dataset_write_response_dto.py +81 -0
  129. truefoundry/ml/autogen/client/models/get_tenant_id_response_dto.py +73 -0
  130. truefoundry/ml/autogen/client/models/http_validation_error.py +82 -0
  131. truefoundry/ml/autogen/client/models/image_content_part.py +87 -0
  132. truefoundry/ml/autogen/client/models/image_url.py +75 -0
  133. truefoundry/ml/autogen/client/models/internal_metadata.py +180 -0
  134. truefoundry/ml/autogen/client/models/latest_run_log_dto.py +78 -0
  135. truefoundry/ml/autogen/client/models/list_artifact_versions_request_dto.py +107 -0
  136. truefoundry/ml/autogen/client/models/list_artifact_versions_response_dto.py +87 -0
  137. truefoundry/ml/autogen/client/models/list_artifacts_request_dto.py +96 -0
  138. truefoundry/ml/autogen/client/models/list_artifacts_response_dto.py +86 -0
  139. truefoundry/ml/autogen/client/models/list_colums_response_dto.py +75 -0
  140. truefoundry/ml/autogen/client/models/list_datasets_request_dto.py +78 -0
  141. truefoundry/ml/autogen/client/models/list_datasets_response_dto.py +86 -0
  142. truefoundry/ml/autogen/client/models/list_experiments_response_dto.py +86 -0
  143. truefoundry/ml/autogen/client/models/list_files_for_artifact_version_request_dto.py +76 -0
  144. truefoundry/ml/autogen/client/models/list_files_for_artifact_versions_response_dto.py +82 -0
  145. truefoundry/ml/autogen/client/models/list_files_for_dataset_request_dto.py +76 -0
  146. truefoundry/ml/autogen/client/models/list_files_for_dataset_response_dto.py +82 -0
  147. truefoundry/ml/autogen/client/models/list_latest_run_logs_response_dto.py +82 -0
  148. truefoundry/ml/autogen/client/models/list_metric_history_request_dto.py +69 -0
  149. truefoundry/ml/autogen/client/models/list_metric_history_response_dto.py +84 -0
  150. truefoundry/ml/autogen/client/models/list_model_version_response_dto.py +87 -0
  151. truefoundry/ml/autogen/client/models/list_model_versions_request_dto.py +93 -0
  152. truefoundry/ml/autogen/client/models/list_models_request_dto.py +89 -0
  153. truefoundry/ml/autogen/client/models/list_models_response_dto.py +84 -0
  154. truefoundry/ml/autogen/client/models/list_run_artifacts_response_dto.py +84 -0
  155. truefoundry/ml/autogen/client/models/list_run_logs_response_dto.py +82 -0
  156. truefoundry/ml/autogen/client/models/list_seed_experiments_response_dto.py +81 -0
  157. truefoundry/ml/autogen/client/models/log_batch_request_dto.py +106 -0
  158. truefoundry/ml/autogen/client/models/log_metric_request_dto.py +80 -0
  159. truefoundry/ml/autogen/client/models/log_param_request_dto.py +76 -0
  160. truefoundry/ml/autogen/client/models/method.py +37 -0
  161. truefoundry/ml/autogen/client/models/metric_collection_dto.py +82 -0
  162. truefoundry/ml/autogen/client/models/metric_dto.py +76 -0
  163. truefoundry/ml/autogen/client/models/mime_type.py +37 -0
  164. truefoundry/ml/autogen/client/models/model_configuration.py +103 -0
  165. truefoundry/ml/autogen/client/models/model_dto.py +122 -0
  166. truefoundry/ml/autogen/client/models/model_response_dto.py +75 -0
  167. truefoundry/ml/autogen/client/models/model_schema_dto.py +85 -0
  168. truefoundry/ml/autogen/client/models/model_version_dto.py +170 -0
  169. truefoundry/ml/autogen/client/models/model_version_response_dto.py +75 -0
  170. truefoundry/ml/autogen/client/models/multi_part_upload_dto.py +107 -0
  171. truefoundry/ml/autogen/client/models/multi_part_upload_response_dto.py +79 -0
  172. truefoundry/ml/autogen/client/models/multi_part_upload_storage_provider.py +34 -0
  173. truefoundry/ml/autogen/client/models/notify_artifact_version_failure_dto.py +65 -0
  174. truefoundry/ml/autogen/client/models/openapi_spec.py +152 -0
  175. truefoundry/ml/autogen/client/models/param_dto.py +66 -0
  176. truefoundry/ml/autogen/client/models/parameters.py +84 -0
  177. truefoundry/ml/autogen/client/models/prediction_type.py +34 -0
  178. truefoundry/ml/autogen/client/models/resolve_agent_app_response_dto.py +75 -0
  179. truefoundry/ml/autogen/client/models/restore_run_request_dto.py +65 -0
  180. truefoundry/ml/autogen/client/models/run_data_dto.py +104 -0
  181. truefoundry/ml/autogen/client/models/run_dto.py +84 -0
  182. truefoundry/ml/autogen/client/models/run_info_dto.py +105 -0
  183. truefoundry/ml/autogen/client/models/run_log_dto.py +90 -0
  184. truefoundry/ml/autogen/client/models/run_log_input_dto.py +80 -0
  185. truefoundry/ml/autogen/client/models/run_response_dto.py +75 -0
  186. truefoundry/ml/autogen/client/models/run_tag_dto.py +66 -0
  187. truefoundry/ml/autogen/client/models/search_runs_request_dto.py +94 -0
  188. truefoundry/ml/autogen/client/models/search_runs_response_dto.py +84 -0
  189. truefoundry/ml/autogen/client/models/set_experiment_tag_request_dto.py +73 -0
  190. truefoundry/ml/autogen/client/models/set_tag_request_dto.py +76 -0
  191. truefoundry/ml/autogen/client/models/signed_url_dto.py +69 -0
  192. truefoundry/ml/autogen/client/models/stop.py +152 -0
  193. truefoundry/ml/autogen/client/models/store_run_logs_request_dto.py +83 -0
  194. truefoundry/ml/autogen/client/models/system_message.py +89 -0
  195. truefoundry/ml/autogen/client/models/text.py +153 -0
  196. truefoundry/ml/autogen/client/models/text_content_part.py +84 -0
  197. truefoundry/ml/autogen/client/models/update_artifact_version_request_dto.py +74 -0
  198. truefoundry/ml/autogen/client/models/update_dataset_request_dto.py +74 -0
  199. truefoundry/ml/autogen/client/models/update_experiment_request_dto.py +74 -0
  200. truefoundry/ml/autogen/client/models/update_model_version_request_dto.py +93 -0
  201. truefoundry/ml/autogen/client/models/update_run_request_dto.py +78 -0
  202. truefoundry/ml/autogen/client/models/update_run_response_dto.py +75 -0
  203. truefoundry/ml/autogen/client/models/url.py +153 -0
  204. truefoundry/ml/autogen/client/models/user_message.py +89 -0
  205. truefoundry/ml/autogen/client/models/validation_error.py +87 -0
  206. truefoundry/ml/autogen/client/models/validation_error_loc_inner.py +154 -0
  207. truefoundry/ml/autogen/client/rest.py +426 -0
  208. truefoundry/ml/autogen/client_README.md +320 -0
  209. truefoundry/ml/cli/__init__.py +0 -0
  210. truefoundry/ml/cli/cli.py +18 -0
  211. truefoundry/ml/cli/commands/__init__.py +3 -0
  212. truefoundry/ml/cli/commands/download.py +87 -0
  213. truefoundry/ml/clients/__init__.py +0 -0
  214. truefoundry/ml/clients/entities.py +8 -0
  215. truefoundry/ml/clients/servicefoundry_client.py +45 -0
  216. truefoundry/ml/clients/utils.py +122 -0
  217. truefoundry/ml/constants.py +84 -0
  218. truefoundry/ml/entities.py +62 -0
  219. truefoundry/ml/enums.py +70 -0
  220. truefoundry/ml/env_vars.py +9 -0
  221. truefoundry/ml/exceptions.py +8 -0
  222. truefoundry/ml/git_info.py +60 -0
  223. truefoundry/ml/internal_namespace.py +52 -0
  224. truefoundry/ml/log_types/__init__.py +4 -0
  225. truefoundry/ml/log_types/artifacts/artifact.py +431 -0
  226. truefoundry/ml/log_types/artifacts/constants.py +33 -0
  227. truefoundry/ml/log_types/artifacts/dataset.py +384 -0
  228. truefoundry/ml/log_types/artifacts/general_artifact.py +110 -0
  229. truefoundry/ml/log_types/artifacts/model.py +611 -0
  230. truefoundry/ml/log_types/artifacts/model_extras.py +48 -0
  231. truefoundry/ml/log_types/artifacts/utils.py +161 -0
  232. truefoundry/ml/log_types/image/__init__.py +3 -0
  233. truefoundry/ml/log_types/image/constants.py +8 -0
  234. truefoundry/ml/log_types/image/image.py +357 -0
  235. truefoundry/ml/log_types/image/image_normalizer.py +102 -0
  236. truefoundry/ml/log_types/image/types.py +68 -0
  237. truefoundry/ml/log_types/plot.py +281 -0
  238. truefoundry/ml/log_types/pydantic_base.py +10 -0
  239. truefoundry/ml/log_types/utils.py +12 -0
  240. truefoundry/ml/logger.py +17 -0
  241. truefoundry/ml/mlfoundry_api.py +1575 -0
  242. truefoundry/ml/mlfoundry_run.py +1203 -0
  243. truefoundry/ml/run_utils.py +93 -0
  244. truefoundry/ml/session.py +168 -0
  245. truefoundry/ml/validation_utils.py +346 -0
  246. truefoundry/pydantic_v1.py +6 -2
  247. truefoundry/workflow/__init__.py +16 -1
  248. {truefoundry-0.3.4.dist-info → truefoundry-0.4.0.dist-info}/METADATA +21 -14
  249. truefoundry-0.4.0.dist-info/RECORD +344 -0
  250. truefoundry/deploy/lib/clients/utils.py +0 -41
  251. truefoundry-0.3.4.dist-info/RECORD +0 -136
  252. {truefoundry-0.3.4.dist-info → truefoundry-0.4.0.dist-info}/WHEEL +0 -0
  253. {truefoundry-0.3.4.dist-info → truefoundry-0.4.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,1161 @@
1
+ import math
2
+ import mmap
3
+ import os
4
+ import posixpath
5
+ import sys
6
+ import tempfile
7
+ import uuid
8
+ from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait
9
+ from shutil import rmtree
10
+ from threading import Event
11
+ from typing import (
12
+ Any,
13
+ Callable,
14
+ Dict,
15
+ Iterator,
16
+ List,
17
+ NamedTuple,
18
+ Optional,
19
+ Tuple,
20
+ Union,
21
+ )
22
+ from urllib.parse import unquote
23
+ from urllib.request import pathname2url
24
+
25
+ import requests
26
+ from rich.console import _is_jupyter
27
+ from rich.progress import (
28
+ BarColumn,
29
+ DownloadColumn,
30
+ Progress,
31
+ TimeElapsedColumn,
32
+ TimeRemainingColumn,
33
+ TransferSpeedColumn,
34
+ )
35
+ from tqdm.utils import CallbackIOWrapper
36
+
37
+ from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
38
+ ApiClient,
39
+ CreateMultiPartUploadForDatasetRequestDto,
40
+ CreateMultiPartUploadRequestDto,
41
+ FileInfoDto,
42
+ GetSignedURLForDatasetWriteRequestDto,
43
+ GetSignedURLsForArtifactVersionReadRequestDto,
44
+ GetSignedURLsForArtifactVersionWriteRequestDto,
45
+ GetSignedURLsForDatasetReadRequestDto,
46
+ ListFilesForArtifactVersionRequestDto,
47
+ ListFilesForArtifactVersionsResponseDto,
48
+ ListFilesForDatasetRequestDto,
49
+ ListFilesForDatasetResponseDto,
50
+ MlfoundryArtifactsApi,
51
+ MultiPartUploadDto,
52
+ MultiPartUploadResponseDto,
53
+ MultiPartUploadStorageProvider,
54
+ RunArtifactsApi,
55
+ SignedURLDto,
56
+ )
57
+ from truefoundry.ml.clients.utils import (
58
+ augmented_raise_for_status,
59
+ cloud_storage_http_request,
60
+ )
61
+ from truefoundry.ml.env_vars import DISABLE_MULTIPART_UPLOAD
62
+ from truefoundry.ml.exceptions import MlFoundryException
63
+ from truefoundry.ml.logger import logger
64
+ from truefoundry.ml.session import _get_api_client
65
+ from truefoundry.pydantic_v1 import BaseModel, root_validator
66
+
67
# Files smaller than this (100 MiB) are not split into multiple parts.
_MIN_BYTES_REQUIRED_FOR_MULTIPART = 100 * 2**20
# Setting the environment variable named by DISABLE_MULTIPART_UPLOAD to "true"
# (case-insensitive) forces single-part uploads regardless of file size.
_MULTIPART_DISABLED = os.environ.get(DISABLE_MULTIPART_UPLOAD, "").lower() == "true"

# Provider limits on the number of parts:
#   * GCP / S3: at most 10,000 parts per multipart upload.
#   * Azure: at most 50,000 blocks per block blob.
# TODO: The limit below is deliberately lower than those. The plan is to
# request part signed URLs in batches instead of in one API call:
#   1. Create the multipart upload (returns the max number of parts, the size
#      limit of a single part, the upload id for S3, etc.).
#   2. Fetch signed URLs for the first 500 parts and upload them.
#   3. Fetch signed URLs for the next 500 parts and upload them; repeat.
#   4. Finalize using the finalize signed URL returned in step 1, or fetch a
#      fresh one.
_MAX_NUM_PARTS_FOR_MULTIPART = 1000

# Provider limits on the size of one part:
#   * Azure: a block in a block blob may be at most 4000 MiB.
#   * GCP / S3: a part in a multipart upload may be at most 5 GiB.
# The smaller (Azure) limit is used for all providers.
_MAX_PART_SIZE_BYTES_FOR_MULTIPART = 4000 * 2**20

# Thread-pool sizes: twice the CPU count, clamped to the range [4, 32].
_cpu_count = os.cpu_count() or 2
_MAX_WORKERS_FOR_UPLOAD = min(max(_cpu_count * 2, 4), 32)
_MAX_WORKERS_FOR_DOWNLOAD = min(max(_cpu_count * 2, 4), 32)

_LIST_FILES_PAGE_SIZE = 500
_GENERATE_SIGNED_URL_BATCH_SIZE = 50
# NOTE(review): presumably seconds; confirm against the signed-URL API.
DEFAULT_PRESIGNED_URL_EXPIRY_TIME = 3600
93
+
94
+
95
def _get_relpath_if_in_tempdir(path: str) -> str:
    """Shorten ``path`` for display when it lives inside the system temp dir.

    Used to build progress-bar labels: files under ``tempfile.gettempdir()``
    are shown relative to that directory, every other path is returned as is.

    Args:
        path: Local file path to display.

    Returns:
        ``path`` relative to the temp directory when ``path`` is that directory
        or lies inside it, otherwise ``path`` unchanged.
    """
    tempdir = tempfile.gettempdir()
    # A bare ``path.startswith(tempdir)`` would also match sibling directories
    # such as ``/tmpdata/x`` for a tempdir of ``/tmp``, producing labels like
    # ``../tmpdata/x``; require a path-separator boundary instead.
    prefix = tempdir if tempdir.endswith(os.sep) else tempdir + os.sep
    if path == tempdir or path.startswith(prefix):
        return os.path.relpath(path, tempdir)
    return path
100
+
101
+
102
def _can_display_progress(user_choice: Optional[bool] = None) -> bool:
    """Decide whether live progress bars should be rendered.

    Args:
        user_choice: Caller preference. ``False`` always disables progress
            bars. ``True`` enables them even when no suitable display is
            detected, logging a warning in that case. ``None`` auto-detects.

    Returns:
        ``True`` when stdout is a TTY, or when running under Jupyter with
        ``IPython.display`` and ``ipywidgets`` importable, or when
        ``user_choice`` is ``True``. ``False`` otherwise.
    """
    if user_choice is False:
        return False

    if sys.stdout.isatty():
        return True

    if _is_jupyter():
        try:
            # Imported only to probe that the widget stack is available.
            from IPython.display import display  # noqa: F401
            from ipywidgets import Output  # noqa: F401
        except ImportError:
            logger.warning(
                "Detected Jupyter Environment. Install `ipywidgets` to display live progress bars.",
            )
        else:
            return True

    if user_choice is not True:
        return False

    logger.warning(
        "`progress` argument is set to True but did not detect tty "
        "or jupyter environment with ipywidgets installed. "
        "Progress bars may not be displayed. "
    )
    return True
126
+
127
+
128
def relative_path_to_artifact_path(path):
    """Convert a relative local path into a POSIX-style artifact path.

    On platforms whose ``os.path`` is ``posixpath`` the path is returned
    unchanged. Elsewhere it is passed through ``pathname2url`` and then
    URL-unquoted.

    Raises:
        Exception: On non-POSIX platforms, when ``path`` already equals its
            absolute form.
    """
    uses_posix_paths = os.path is posixpath
    if uses_posix_paths:
        return path

    is_absolute = os.path.abspath(path) == path
    if is_absolute:
        raise Exception("This method only works with relative paths.")

    return unquote(pathname2url(path))
134
+
135
+
136
def _align_part_size_with_mmap_allocation_granularity(part_size: int) -> int:
    """Round ``part_size`` up to the next multiple of ``mmap.ALLOCATIONGRANULARITY``.

    Part uploads mmap the file at offsets that are multiples of the part size,
    and ``mmap`` requires offsets to be multiples of the allocation granularity.

    Args:
        part_size: Desired part size in bytes.

    Returns:
        The smallest multiple of ``mmap.ALLOCATIONGRANULARITY`` that is
        ``>= part_size``. Values already aligned are returned unchanged.
    """
    granularity = mmap.ALLOCATIONGRANULARITY
    # Ceiling division expressed with floor division on the negated value.
    return -(-part_size // granularity) * granularity


# Must not be less than 5 * 1024 * 1024 bytes (5 MiB).
_PART_SIZE_BYTES_FOR_MULTIPART = _align_part_size_with_mmap_allocation_granularity(
    part_size=10 * 1024 * 1024
)
149
+
150
+
151
def bad_path_message(name):
    """Build the explanation shown when ``name`` is rejected as an artifact path.

    The message quotes ``posixpath.normpath(name)`` so the user can see what
    the name would collapse to.
    """
    resolved = posixpath.normpath(name)
    return (
        "Names may be treated as files in certain cases, and must not resolve to other names"
        f" when treated as such. This name would resolve to '{resolved}'"
    )
156
+
157
+
158
def path_not_unique(name):
    """Return True when ``name`` is not a canonical, relative, in-tree POSIX path.

    A name is rejected if normalizing it changes it, if it normalizes to ``.``,
    or if it starts with ``..`` or ``/``.
    """
    normalized = posixpath.normpath(name)
    if normalized != name:
        return True
    return normalized == "." or normalized.startswith(("..", "/"))
161
+
162
+
163
def verify_artifact_path(artifact_path):
    """Reject artifact paths that are not canonical, relative POSIX paths.

    Empty or ``None`` paths are accepted without any check.

    Raises:
        MlFoundryException: If ``path_not_unique(artifact_path)`` is true.
    """
    if not artifact_path:
        return
    if path_not_unique(artifact_path):
        explanation = bad_path_message(artifact_path)
        raise MlFoundryException(
            f"Invalid artifact path: {artifact_path!r}. {explanation!r}"
        )
168
+
169
+
170
# One uploaded part of a multipart upload: its part number as sent in the
# completion body, and the ETag returned by storage for that part.
_PartNumberEtag = NamedTuple(
    "_PartNumberEtag",
    [
        ("part_number", int),
        ("etag", str),
    ],
)
173
+
174
+
175
def _get_s3_compatible_completion_body(multi_parts: List[_PartNumberEtag]) -> str:
    """Render the XML body of an S3-compatible CompleteMultipartUpload request.

    Args:
        multi_parts: Uploaded parts, emitted in the given order as
            ``<Part>`` elements with their ``PartNumber`` and ``ETag``.

    Returns:
        The ``<CompleteMultipartUpload>`` XML document as a string.
    """
    part_elements = "".join(
        " <Part>\n"
        f" <PartNumber>{part.part_number}</PartNumber>\n"
        f" <ETag>{part.etag}</ETag>\n"
        " </Part>\n"
        for part in multi_parts
    )
    return f"<CompleteMultipartUpload>\n{part_elements}</CompleteMultipartUpload>"
184
+
185
+
186
def _get_azure_blob_completion_body(block_ids: List[str]) -> str:
    """Render the XML ``<BlockList>`` body that commits uploaded Azure blocks.

    Args:
        block_ids: Block ids to commit, each emitted as an ``<Uncommitted>``
            entry in the given order.

    Returns:
        The ``<BlockList>`` XML document as a string.
    """
    entries = "".join(
        f"<Uncommitted>{block_id}</Uncommitted> " for block_id in block_ids
    )
    return "<BlockList>\n" + entries + "</BlockList>"
192
+
193
+
194
# How a file is split for upload: number of parts, bytes per part (the last
# part may be shorter), and the total file size in bytes.
_FileMultiPartInfo = NamedTuple(
    "_FileMultiPartInfo",
    [
        ("num_parts", int),
        ("part_size", int),
        ("file_size", int),
    ],
)
198
+
199
+
200
def _decide_file_parts(file_path: str) -> _FileMultiPartInfo:
    """Choose how ``file_path`` is split into upload parts.

    * Files smaller than ``_MIN_BYTES_REQUIRED_FOR_MULTIPART``, or any file
      when multipart is disabled, get a single part covering the whole file.
    * Otherwise ``_PART_SIZE_BYTES_FOR_MULTIPART`` is used if that needs no
      more than ``_MAX_NUM_PARTS_FOR_MULTIPART`` parts.
    * Otherwise the part size is enlarged (and mmap-aligned) so the file fits
      in at most ``_MAX_NUM_PARTS_FOR_MULTIPART`` parts.

    Raises:
        ValueError: If the enlarged part size exceeds
            ``_MAX_PART_SIZE_BYTES_FOR_MULTIPART``.
    """
    file_size = os.path.getsize(file_path)
    use_single_part = _MULTIPART_DISABLED or file_size < _MIN_BYTES_REQUIRED_FOR_MULTIPART
    if use_single_part:
        return _FileMultiPartInfo(1, part_size=file_size, file_size=file_size)

    default_part_count = math.ceil(file_size / _PART_SIZE_BYTES_FOR_MULTIPART)
    if default_part_count <= _MAX_NUM_PARTS_FOR_MULTIPART:
        return _FileMultiPartInfo(
            default_part_count,
            part_size=_PART_SIZE_BYTES_FOR_MULTIPART,
            file_size=file_size,
        )

    # Too many parts at the default size: grow each part so the whole file fits
    # within the part-count limit, keeping sizes valid as mmap offsets.
    enlarged_part_size = _align_part_size_with_mmap_allocation_granularity(
        math.ceil(file_size / _MAX_NUM_PARTS_FOR_MULTIPART)
    )
    if enlarged_part_size > _MAX_PART_SIZE_BYTES_FOR_MULTIPART:
        raise ValueError(
            f"file {file_path!r} is too big for upload. Multipart chunk"
            f" size {enlarged_part_size} is higher"
            f" than {_MAX_PART_SIZE_BYTES_FOR_MULTIPART}"
        )

    return _FileMultiPartInfo(
        math.ceil(file_size / enlarged_part_size),
        part_size=enlarged_part_size,
        file_size=file_size,
    )
227
+
228
+
229
def _signed_url_upload_file(
    signed_url: SignedURLDto,
    local_file: str,
    progress_bar: Progress,
    abort_event: Optional[Event] = None,
):
    """Upload ``local_file`` with a single PUT request to a pre-signed URL.

    Empty files are uploaded with an empty body and no progress task. For other
    files a task is added to ``progress_bar`` and advanced as the file is read.

    Args:
        signed_url: DTO whose ``signed_url`` attribute is the target URL.
        local_file: Path of the local file to upload.
        progress_bar: Rich progress instance used to display the upload.
        abort_event: When set, the next read of the file raises and the upload
            stops.

    Raises:
        Exception: ``"aborting upload"`` when ``abort_event`` is set mid-upload.
        Errors raised by ``augmented_raise_for_status`` for a failed response.
    """
    file_size = os.stat(local_file).st_size
    if file_size == 0:
        with cloud_storage_http_request(
            method="put", url=signed_url.signed_url, data=""
        ) as response:
            # Pass the response itself. ``response.raise_for_status()`` returns
            # None, so passing its result would hand None to the helper.
            augmented_raise_for_status(response)
        return

    task_progress_bar = progress_bar.add_task(
        f"[green]Uploading {_get_relpath_if_in_tempdir(local_file)}:", start=True
    )

    def callback(length):
        progress_bar.update(task_progress_bar, advance=length, total=file_size)
        if abort_event and abort_event.is_set():
            raise Exception("aborting upload")

    with open(local_file, "rb") as file:
        # NOTE: Azure Put Blob does not support Transfer Encoding header.
        wrapped_file = CallbackIOWrapper(callback, file, "read")
        with cloud_storage_http_request(
            method="put", url=signed_url.signed_url, data=wrapped_file
        ) as response:
            augmented_raise_for_status(response)
260
+
261
+
262
def _download_file_using_http_uri(
    http_uri,
    download_path,
    chunk_size=100000000,
    callback: Optional[Callable[[int, int], Any]] = None,
):
    """Stream the file at ``http_uri`` into ``download_path`` chunk by chunk.

    The body is read in pieces of at most ``chunk_size`` bytes, so a large file
    is never held in memory all at once. Meant for downloading through
    pre-signed URLs issued by cloud storage providers.

    Args:
        http_uri: URL to download from.
        download_path: Local file path to write; it is overwritten.
        chunk_size: Maximum bytes per chunk requested from ``iter_content``.
        callback: Optional hook called once per received chunk as
            ``callback(len(chunk), total)``. ``total`` is the
            ``Content-Length`` header value, or 0 when the header is absent.
    """
    with cloud_storage_http_request(
        method="get", url=http_uri, stream=True
    ) as response:
        augmented_raise_for_status(response)
        expected_size = int(response.headers.get("Content-Length", 0))
        chunks = response.iter_content(chunk_size=chunk_size)
        with open(download_path, "wb") as destination:
            for chunk in chunks:
                if callback:
                    callback(len(chunk), expected_size)
                # An empty chunk ends the transfer.
                if not chunk:
                    break
                destination.write(chunk)
286
+
287
+
288
class _CallbackIOWrapperForMultiPartUpload(CallbackIOWrapper):
    """tqdm ``CallbackIOWrapper`` that additionally reports a fixed ``len()``.

    The wrapped stream is a memory-mapped slice of the file. ``length`` is the
    value returned by ``len()`` on the wrapper. NOTE(review): presumably this
    lets the HTTP client compute a Content-Length for the request body instead
    of using chunked transfer encoding; confirm.
    """

    def __init__(self, callback, stream, method, length: int):
        # Store on the wrapper object itself. ``wrapper_setattr`` is used rather
        # than plain assignment, presumably because tqdm's ObjectWrapper forwards
        # ordinary attribute writes to the wrapped stream.
        self.wrapper_setattr("_length", length)
        super().__init__(callback, stream, method)

    def __len__(self):
        stored_length = self.wrapper_getattr("_length")
        return stored_length
295
+
296
+
297
def _file_part_upload(
    url: str,
    file_path: str,
    seek: int,
    length: int,
    file_size: int,
    abort_event: Optional[Event] = None,
    method: str = "put",
):
    """Upload one byte range of ``file_path`` to ``url``.

    The range starts at ``seek`` and spans ``length`` bytes, clipped to the end
    of the file (``file_size``). It is memory-mapped read-only and sent as the
    request body. ``mmap`` requires ``seek`` to be a multiple of
    ``mmap.ALLOCATIONGRANULARITY``.

    Args:
        url: Signed URL for this part.
        file_path: Local file to read from.
        seek: Byte offset where the part starts.
        length: Nominal part size in bytes.
        file_size: Total size of ``file_path`` in bytes.
        abort_event: When set, the next read of the part raises.
        method: HTTP method for the request (default ``"put"``).

    Returns:
        The HTTP response for the part upload, after the request context exits.

    Raises:
        Exception: ``"aborting upload"`` when ``abort_event`` is set.
        Errors raised by ``augmented_raise_for_status`` for a failed response.
    """

    def _raise_if_aborted(*_, **__):
        if abort_event and abort_event.is_set():
            raise Exception("aborting upload")

    bytes_in_part = min(file_size - seek, length)
    with open(file_path, "rb") as file, mmap.mmap(
        file.fileno(),
        length=bytes_in_part,
        offset=seek,
        access=mmap.ACCESS_READ,
    ) as mapped_file:
        body = _CallbackIOWrapperForMultiPartUpload(
            _raise_if_aborted, mapped_file, "read", len(mapped_file)
        )
        with cloud_storage_http_request(
            method=method, url=url, data=body
        ) as response:
            augmented_raise_for_status(response)
            return response
327
+
328
+
329
def _s3_compatible_multipart_upload(
    multipart_upload: MultiPartUploadDto,
    local_file: str,
    multipart_info: _FileMultiPartInfo,
    executor: ThreadPoolExecutor,
    progress_bar: Progress,
    abort_event: Optional[Event] = None,
):
    """Upload ``local_file`` as an S3-compatible multipart upload and complete it.

    Each part is uploaded concurrently on ``executor`` to the matching entry of
    ``multipart_upload.part_signed_urls``. The ``ETag`` header of every part
    response is collected. The upload is then completed by POSTing a
    ``<CompleteMultipartUpload>`` body, with parts sorted by part number, to
    ``multipart_upload.finalize_signed_url``.

    Args:
        multipart_upload: Signed URLs for each part and for the finalize call.
        local_file: Path of the file to upload.
        multipart_info: Part count, part size and total size of ``local_file``.
        executor: Thread pool on which part uploads are scheduled.
        progress_bar: Rich progress instance; one task tracks the whole file.
        abort_event: Shared flag, set when a part fails so that in-flight part
            uploads stop reading. A fresh ``Event`` is used when omitted.

    Raises:
        The first exception raised by a part upload. Before raising, the other
        parts are signalled to abort and the ones not yet started are cancelled.
        ``requests.HTTPError`` if the finalize request returns an error status.
    """
    abort_event = abort_event or Event()
    parts: List[_PartNumberEtag] = []

    multi_part_upload_progress = progress_bar.add_task(
        f"[green]Uploading {_get_relpath_if_in_tempdir(local_file)}:", start=True
    )

    def upload(part_number: int, seek: int) -> None:
        logger.debug(
            "Uploading part %d/%d of %s",
            part_number,
            multipart_info.num_parts,
            local_file,
        )
        response = _file_part_upload(
            url=multipart_upload.part_signed_urls[part_number].signed_url,
            file_path=local_file,
            seek=seek,
            length=multipart_info.part_size,
            file_size=multipart_info.file_size,
            abort_event=abort_event,
        )
        logger.debug(
            "Uploaded part %d/%d of %s",
            part_number,
            multipart_info.num_parts,
            local_file,
        )
        # The last part is usually shorter than ``part_size``. Advance by the
        # bytes actually sent so progress never runs past ``file_size``.
        part_bytes = min(multipart_info.part_size, multipart_info.file_size - seek)
        progress_bar.update(
            multi_part_upload_progress,
            advance=part_bytes,
            total=multipart_info.file_size,
        )
        etag = response.headers["ETag"]
        # S3 part numbers start at 1; ``part_number`` indexes the 0-based list
        # of signed URLs.
        parts.append(_PartNumberEtag(etag=etag, part_number=part_number + 1))

    futures: List[Future] = []
    for part_number, seek in enumerate(
        range(0, multipart_info.file_size, multipart_info.part_size)
    ):
        futures.append(executor.submit(upload, part_number=part_number, seek=seek))

    done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
    if not_done:
        # Some part failed before all finished: stop running parts and drop
        # the ones that have not started yet.
        abort_event.set()
        for future in not_done:
            future.cancel()
    for future in done:
        error = future.exception()
        if error is not None:
            raise error

    logger.debug("Finalizing multipart upload of %s", local_file)
    parts = sorted(parts, key=lambda part: part.part_number)
    response = requests.post(
        multipart_upload.finalize_signed_url.signed_url,
        data=_get_s3_compatible_completion_body(parts),
        timeout=2 * 60,
    )
    response.raise_for_status()
    logger.debug("Multipart upload of %s completed", local_file)
+ logger.debug("Multipart upload of %s completed", local_file)
398
+
399
+
400
def _azure_multi_part_upload(
    multipart_upload: MultiPartUploadDto,
    local_file: str,
    multipart_info: _FileMultiPartInfo,
    executor: ThreadPoolExecutor,
    progress_bar: Progress,
    abort_event: Optional[Event] = None,
):
    """Upload ``local_file`` as Azure blob blocks and commit the block list.

    Each block is uploaded concurrently on ``executor`` to the matching entry of
    ``multipart_upload.part_signed_urls``. When
    ``multipart_upload.azure_blob_block_ids`` is non-empty, those ids are
    committed by PUTting a ``<BlockList>`` body to
    ``multipart_upload.finalize_signed_url``.

    Args:
        multipart_upload: Signed URLs for each block, the finalize URL and the
            block ids to commit.
        local_file: Path of the file to upload.
        multipart_info: Part count, part size and total size of ``local_file``.
        executor: Thread pool on which block uploads are scheduled.
        progress_bar: Rich progress instance; one task tracks the whole file.
        abort_event: Shared flag, set when a block fails so that in-flight
            uploads stop reading. A fresh ``Event`` is used when omitted.

    Raises:
        The first exception raised by a block upload. Before raising, the other
        uploads are signalled to abort and the ones not yet started are
        cancelled.
        ``requests.HTTPError`` if the commit request returns an error status.
    """
    abort_event = abort_event or Event()

    multi_part_upload_progress = progress_bar.add_task(
        f"[green]Uploading {_get_relpath_if_in_tempdir(local_file)}:", start=True
    )

    def upload(part_number: int, seek: int):
        logger.debug(
            "Uploading part %d/%d of %s",
            part_number,
            multipart_info.num_parts,
            local_file,
        )
        _file_part_upload(
            url=multipart_upload.part_signed_urls[part_number].signed_url,
            file_path=local_file,
            seek=seek,
            length=multipart_info.part_size,
            file_size=multipart_info.file_size,
            abort_event=abort_event,
        )
        # The last block is usually shorter than ``part_size``. Advance by the
        # bytes actually sent so progress never runs past ``file_size``.
        part_bytes = min(multipart_info.part_size, multipart_info.file_size - seek)
        progress_bar.update(
            multi_part_upload_progress,
            advance=part_bytes,
            total=multipart_info.file_size,
        )
        logger.debug(
            "Uploaded part %d/%d of %s",
            part_number,
            multipart_info.num_parts,
            local_file,
        )

    futures: List[Future] = []
    for part_number, seek in enumerate(
        range(0, multipart_info.file_size, multipart_info.part_size)
    ):
        futures.append(executor.submit(upload, part_number=part_number, seek=seek))

    done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
    if not_done:
        # Some block failed before all finished: stop running uploads and drop
        # the ones that have not started yet.
        abort_event.set()
        for future in not_done:
            future.cancel()
    for future in done:
        error = future.exception()
        if error is not None:
            raise error

    logger.debug("Finalizing multipart upload of %s", local_file)
    # NOTE(review): when ``azure_blob_block_ids`` is empty or None, the block
    # list is never committed, yet completion is still logged below. Confirm
    # the server always returns block ids for Azure multipart uploads.
    if multipart_upload.azure_blob_block_ids:
        response = requests.put(
            multipart_upload.finalize_signed_url.signed_url,
            data=_get_azure_blob_completion_body(
                block_ids=multipart_upload.azure_blob_block_ids
            ),
            timeout=2 * 60,
        )
        response.raise_for_status()
    logger.debug("Multipart upload of %s completed", local_file)
468
+
469
+
470
+ def _any_future_has_failed(futures) -> bool:
471
+ return any(
472
+ future.done() and not future.cancelled() and future.exception() is not None
473
+ for future in futures
474
+ )
475
+
476
+
477
class ArtifactIdentifier(BaseModel):
    """Points at either an artifact version (by id) or a dataset (by FQN).

    Exactly one of the two fields must be truthy; the root validator enforces this.
    """

    artifact_version_id: Optional[uuid.UUID] = None
    dataset_fqn: Optional[str] = None

    @root_validator
    def _check_identifier_type(cls, values: Dict[str, Any]):
        has_version_id = bool(values.get("artifact_version_id", False))
        has_dataset_fqn = bool(values.get("dataset_fqn", False))
        # Neither identifier given.
        if not (has_version_id or has_dataset_fqn):
            raise MlFoundryException(
                "One of the version_id or dataset_fqn should be passed"
            )
        # Both identifiers given.
        if has_version_id and has_dataset_fqn:
            raise MlFoundryException(
                "Exactly one of version_id or dataset_fqn should be passed"
            )
        return values
496
+
497
+
498
class MlFoundryArtifactsRepository:
    """Upload, list and download files of one artifact version or dataset."""

    def __init__(
        self,
        artifact_identifier: ArtifactIdentifier,
        api_client: Optional[ApiClient] = None,
    ):
        """Bind the repository to a single artifact version or dataset.

        Args:
            artifact_identifier: the artifact version / dataset this repository serves.
            api_client: client used for every API call; ``_get_api_client()`` is
                used when it is not provided.
        """
        self.artifact_identifier = artifact_identifier
        client = api_client or _get_api_client()
        self._api_client = client
        self._run_artifacts_api = RunArtifactsApi(api_client=client)
        self._mlfoundry_artifacts_api = MlfoundryArtifactsApi(api_client=client)
510
+
511
+ def _create_download_destination(
512
+ self, src_artifact_path, dst_local_dir_path=None
513
+ ) -> str:
514
+ """
515
+ Creates a local filesystem location to be used as a destination for downloading the artifact
516
+ specified by `src_artifact_path`. The destination location is a subdirectory of the
517
+ specified `dst_local_dir_path`, which is determined according to the structure of
518
+ `src_artifact_path`. For example, if `src_artifact_path` is `dir1/file1.txt`, then the
519
+ resulting destination path is `<dst_local_dir_path>/dir1/file1.txt`. Local directories are
520
+ created for the resulting destination location if they do not exist.
521
+
522
+ :param src_artifact_path: A relative, POSIX-style path referring to an artifact stored
523
+ within the repository's artifact root location.
524
+ `src_artifact_path` should be specified relative to the
525
+ repository's artifact root location.
526
+ :param dst_local_dir_path: The absolute path to a local filesystem directory in which the
527
+ local destination path will be contained. The local destination
528
+ path may be contained in a subdirectory of `dst_root_dir` if
529
+ `src_artifact_path` contains subdirectories.
530
+ :return: The absolute path to a local filesystem location to be used as a destination
531
+ for downloading the artifact specified by `src_artifact_path`.
532
+ """
533
+ src_artifact_path = src_artifact_path.rstrip(
534
+ "/"
535
+ ) # Ensure correct dirname for trailing '/'
536
+ dirpath = posixpath.dirname(src_artifact_path)
537
+ local_dir_path = os.path.join(dst_local_dir_path, dirpath)
538
+ local_file_path = os.path.join(dst_local_dir_path, src_artifact_path)
539
+ if not os.path.exists(local_dir_path):
540
+ os.makedirs(local_dir_path, exist_ok=True)
541
+ return local_file_path
542
+
543
+ # these methods should be named list_files, log_directory, log_file, etc
544
+ def list_artifacts(
545
+ self, path=None, page_size=_LIST_FILES_PAGE_SIZE, **kwargs
546
+ ) -> Iterator[FileInfoDto]:
547
+ page_token = None
548
+ started = False
549
+ while not started or page_token is not None:
550
+ started = True
551
+ page = self.list_files(
552
+ artifact_identifier=self.artifact_identifier,
553
+ path=path,
554
+ page_size=page_size,
555
+ page_token=page_token,
556
+ )
557
+ for file_info in page.files:
558
+ yield file_info
559
+ page_token = page.next_page_token
560
+
561
+ def log_artifacts( # noqa: C901
562
+ self, local_dir, artifact_path=None, progress=None
563
+ ):
564
+ show_progress = _can_display_progress(progress)
565
+
566
+ dest_path = artifact_path or ""
567
+ dest_path = dest_path.lstrip(posixpath.sep)
568
+
569
+ files_for_normal_upload: List[Tuple[str, str, _FileMultiPartInfo]] = []
570
+ files_for_multipart_upload: List[Tuple[str, str, _FileMultiPartInfo]] = []
571
+
572
+ for root, _, file_names in os.walk(local_dir):
573
+ upload_path = dest_path
574
+ if root != local_dir:
575
+ rel_path = os.path.relpath(root, local_dir)
576
+ rel_path = relative_path_to_artifact_path(rel_path)
577
+ upload_path = posixpath.join(dest_path, rel_path)
578
+ for file_name in file_names:
579
+ local_file = os.path.join(root, file_name)
580
+ multipart_info = _decide_file_parts(local_file)
581
+
582
+ final_upload_path = upload_path or ""
583
+ final_upload_path = final_upload_path.lstrip(posixpath.sep)
584
+ final_upload_path = posixpath.join(
585
+ final_upload_path, os.path.basename(local_file)
586
+ )
587
+
588
+ if multipart_info.num_parts == 1:
589
+ files_for_normal_upload.append(
590
+ (final_upload_path, local_file, multipart_info)
591
+ )
592
+ else:
593
+ files_for_multipart_upload.append(
594
+ (final_upload_path, local_file, multipart_info)
595
+ )
596
+
597
+ abort_event = Event()
598
+
599
+ with Progress(
600
+ "[progress.description]{task.description}",
601
+ BarColumn(bar_width=None),
602
+ "[progress.percentage]{task.percentage:>3.0f}%",
603
+ DownloadColumn(),
604
+ TransferSpeedColumn(),
605
+ TimeRemainingColumn(),
606
+ TimeElapsedColumn(),
607
+ refresh_per_second=1,
608
+ disable=not show_progress,
609
+ expand=True,
610
+ ) as progress_bar, ThreadPoolExecutor(
611
+ max_workers=_MAX_WORKERS_FOR_UPLOAD
612
+ ) as executor:
613
+ futures: List[Future] = []
614
+ # Note: While this batching is beneficial when there is a large number of files, there is also
615
+ # a rare case risk of the signed url expiring before a request is made to it
616
+ _batch_size = _GENERATE_SIGNED_URL_BATCH_SIZE
617
+ for start_idx in range(0, len(files_for_normal_upload), _batch_size):
618
+ end_idx = min(start_idx + _batch_size, len(files_for_normal_upload))
619
+ if _any_future_has_failed(futures):
620
+ break
621
+ logger.debug("Generating write signed urls for a batch ...")
622
+ remote_file_paths = [
623
+ files_for_normal_upload[idx][0] for idx in range(start_idx, end_idx)
624
+ ]
625
+ signed_urls = self.get_signed_urls_for_write(
626
+ artifact_identifier=self.artifact_identifier,
627
+ paths=remote_file_paths,
628
+ )
629
+ for idx, signed_url in zip(range(start_idx, end_idx), signed_urls):
630
+ (
631
+ upload_path,
632
+ local_file,
633
+ multipart_info,
634
+ ) = files_for_normal_upload[idx]
635
+ future = executor.submit(
636
+ self._log_artifact,
637
+ local_file=local_file,
638
+ artifact_path=upload_path,
639
+ multipart_info=multipart_info,
640
+ signed_url=signed_url,
641
+ abort_event=abort_event,
642
+ executor_for_multipart_upload=None,
643
+ progress_bar=progress_bar,
644
+ )
645
+ futures.append(future)
646
+
647
+ done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
648
+ if len(not_done) > 0:
649
+ abort_event.set()
650
+ for future in not_done:
651
+ future.cancel()
652
+ for future in done:
653
+ if future.exception() is not None:
654
+ raise future.exception()
655
+
656
+ for (
657
+ upload_path,
658
+ local_file,
659
+ multipart_info,
660
+ ) in files_for_multipart_upload:
661
+ self._log_artifact(
662
+ local_file=local_file,
663
+ artifact_path=upload_path,
664
+ signed_url=None,
665
+ multipart_info=multipart_info,
666
+ executor_for_multipart_upload=executor,
667
+ progress_bar=progress_bar,
668
+ )
669
+
670
+ def _normal_upload(
671
+ self,
672
+ local_file: str,
673
+ artifact_path: str,
674
+ signed_url: Optional[SignedURLDto],
675
+ progress_bar: Progress,
676
+ abort_event: Optional[Event] = None,
677
+ ):
678
+ if not signed_url:
679
+ signed_url = self.get_signed_urls_for_write(
680
+ artifact_identifier=self.artifact_identifier, paths=[artifact_path]
681
+ )[0]
682
+
683
+ if progress_bar.disable:
684
+ logger.info(
685
+ "Uploading %s to %s",
686
+ _get_relpath_if_in_tempdir(local_file),
687
+ artifact_path,
688
+ )
689
+
690
+ _signed_url_upload_file(
691
+ signed_url=signed_url,
692
+ local_file=local_file,
693
+ abort_event=abort_event,
694
+ progress_bar=progress_bar,
695
+ )
696
+ logger.debug("Uploaded %s to %s", local_file, artifact_path)
697
+
698
+ def _multipart_upload(
699
+ self,
700
+ local_file: str,
701
+ artifact_path: str,
702
+ multipart_info: _FileMultiPartInfo,
703
+ executor: ThreadPoolExecutor,
704
+ progress_bar: Progress,
705
+ abort_event: Optional[Event] = None,
706
+ ):
707
+ if progress_bar.disable:
708
+ logger.info(
709
+ "Uploading %s to %s using multipart upload",
710
+ _get_relpath_if_in_tempdir(local_file),
711
+ artifact_path,
712
+ )
713
+
714
+ multipart_upload = self.create_multipart_upload_for_identifier(
715
+ artifact_identifier=self.artifact_identifier,
716
+ path=artifact_path,
717
+ num_parts=multipart_info.num_parts,
718
+ )
719
+ if (
720
+ multipart_upload.storage_provider
721
+ is MultiPartUploadStorageProvider.S3_COMPATIBLE
722
+ ):
723
+ _s3_compatible_multipart_upload(
724
+ multipart_upload=multipart_upload,
725
+ local_file=local_file,
726
+ executor=executor,
727
+ multipart_info=multipart_info,
728
+ abort_event=abort_event,
729
+ progress_bar=progress_bar,
730
+ )
731
+ elif (
732
+ multipart_upload.storage_provider
733
+ is MultiPartUploadStorageProvider.AZURE_BLOB
734
+ ):
735
+ _azure_multi_part_upload(
736
+ multipart_upload=multipart_upload,
737
+ local_file=local_file,
738
+ executor=executor,
739
+ multipart_info=multipart_info,
740
+ abort_event=abort_event,
741
+ progress_bar=progress_bar,
742
+ )
743
+ else:
744
+ raise NotImplementedError()
745
+
746
+ def _log_artifact(
747
+ self,
748
+ local_file: str,
749
+ artifact_path: str,
750
+ multipart_info: _FileMultiPartInfo,
751
+ progress_bar: Progress,
752
+ signed_url: Optional[SignedURLDto] = None,
753
+ abort_event: Optional[Event] = None,
754
+ executor_for_multipart_upload: Optional[ThreadPoolExecutor] = None,
755
+ ):
756
+ if multipart_info.num_parts == 1:
757
+ return self._normal_upload(
758
+ local_file=local_file,
759
+ artifact_path=artifact_path,
760
+ signed_url=signed_url,
761
+ abort_event=abort_event,
762
+ progress_bar=progress_bar,
763
+ )
764
+
765
+ if not executor_for_multipart_upload:
766
+ with ThreadPoolExecutor(max_workers=_MAX_WORKERS_FOR_UPLOAD) as executor:
767
+ return self._multipart_upload(
768
+ local_file=local_file,
769
+ artifact_path=artifact_path,
770
+ executor=executor,
771
+ multipart_info=multipart_info,
772
+ progress_bar=progress_bar,
773
+ )
774
+
775
+ return self._multipart_upload(
776
+ local_file=local_file,
777
+ artifact_path=artifact_path,
778
+ executor=executor_for_multipart_upload,
779
+ multipart_info=multipart_info,
780
+ progress_bar=progress_bar,
781
+ )
782
+
783
+ def log_artifact(self, local_file: str, artifact_path: Optional[str] = None):
784
+ upload_path = artifact_path or ""
785
+ upload_path = upload_path.lstrip(posixpath.sep)
786
+ upload_path = posixpath.join(upload_path, os.path.basename(local_file))
787
+ with Progress(
788
+ "[progress.description]{task.description}",
789
+ BarColumn(bar_width=None),
790
+ "[progress.percentage]{task.percentage:>3.0f}%",
791
+ DownloadColumn(),
792
+ TransferSpeedColumn(),
793
+ TimeRemainingColumn(),
794
+ TimeElapsedColumn(),
795
+ refresh_per_second=1,
796
+ disable=True,
797
+ expand=True,
798
+ ) as progress_bar:
799
+ self._log_artifact(
800
+ local_file=local_file,
801
+ artifact_path=upload_path,
802
+ multipart_info=_decide_file_parts(local_file),
803
+ progress_bar=progress_bar,
804
+ )
805
+
806
+ def _is_directory(self, artifact_path):
807
+ for _ in self.list_artifacts(artifact_path, page_size=3):
808
+ return True
809
+ return False
810
+
811
+ def download_artifacts( # noqa: C901
812
+ self,
813
+ artifact_path: str,
814
+ dst_path: Optional[str] = None,
815
+ overwrite: bool = False,
816
+ progress: Optional[bool] = None,
817
+ ) -> str:
818
+ """
819
+ Download an artifact file or directory to a local directory if applicable, and return a
820
+ local path for it. The caller is responsible for managing the lifecycle of the downloaded artifacts.
821
+
822
+ Args:
823
+ artifact_path: Relative source path to the desired artifacts.
824
+ dst_path: Absolute path of the local filesystem destination directory to which to
825
+ download the specified artifacts. This directory must already exist.
826
+ If unspecified, the artifacts will either be downloaded to a new
827
+ uniquely-named directory.
828
+ overwrite: if to overwrite the files at/inside `dst_path` if they exist
829
+ progress: Show or hide progress bar
830
+
831
+ Returns:
832
+ str: Absolute path of the local filesystem location containing the desired artifacts.
833
+ """
834
+
835
+ show_progress = _can_display_progress()
836
+
837
+ is_dir_temp = False
838
+ if dst_path is None:
839
+ dst_path = tempfile.mkdtemp()
840
+ is_dir_temp = True
841
+
842
+ dst_path = os.path.abspath(dst_path)
843
+ if is_dir_temp:
844
+ logger.info(
845
+ f"Using temporary directory {dst_path} as the download directory"
846
+ )
847
+
848
+ if not os.path.exists(dst_path):
849
+ raise MlFoundryException(
850
+ message=(
851
+ "The destination path for downloaded artifacts does not"
852
+ " exist! Destination path: {dst_path}".format(dst_path=dst_path)
853
+ ),
854
+ )
855
+ elif not os.path.isdir(dst_path):
856
+ raise MlFoundryException(
857
+ message=(
858
+ "The destination path for downloaded artifacts must be a directory!"
859
+ " Destination path: {dst_path}".format(dst_path=dst_path)
860
+ ),
861
+ )
862
+
863
+ progress_bar = Progress(
864
+ "[progress.description]{task.description}",
865
+ BarColumn(),
866
+ "[progress.percentage]{task.percentage:>3.0f}%",
867
+ DownloadColumn(),
868
+ TransferSpeedColumn(),
869
+ TimeRemainingColumn(),
870
+ TimeElapsedColumn(),
871
+ refresh_per_second=1,
872
+ disable=not show_progress,
873
+ expand=True,
874
+ )
875
+
876
+ try:
877
+ progress_bar.start()
878
+ # Check if the artifacts points to a directory
879
+ if self._is_directory(artifact_path):
880
+ futures: List[Future] = []
881
+ file_paths: List[Tuple[str, str]] = []
882
+ abort_event = Event()
883
+
884
+ # Check if any file is being overwritten before downloading them
885
+ for file_path, download_dest_path in self._get_file_paths_recur(
886
+ src_artifact_dir_path=artifact_path, dst_local_dir_path=dst_path
887
+ ):
888
+ final_file_path = os.path.join(download_dest_path, file_path)
889
+
890
+ # There would be no overwrite if temp directory is being used
891
+ if (
892
+ not is_dir_temp
893
+ and os.path.exists(final_file_path)
894
+ and not overwrite
895
+ ):
896
+ raise MlFoundryException(
897
+ f"File already exists at {final_file_path}, aborting download "
898
+ f"(set `overwrite` flag to overwrite this and any subsequent files)"
899
+ )
900
+ file_paths.append((file_path, download_dest_path))
901
+
902
+ with ThreadPoolExecutor(_MAX_WORKERS_FOR_DOWNLOAD) as executor:
903
+ # Note: While this batching is beneficial when there is a large number of files, there is also
904
+ # a rare case risk of the signed url expiring before a request is made to it
905
+ batch_size = _GENERATE_SIGNED_URL_BATCH_SIZE
906
+ for start_idx in range(0, len(file_paths), batch_size):
907
+ end_idx = min(start_idx + batch_size, len(file_paths))
908
+ if _any_future_has_failed(futures):
909
+ break
910
+ logger.debug("Generating read signed urls for a batch ...")
911
+ remote_file_paths = [
912
+ file_paths[idx][0] for idx in range(start_idx, end_idx)
913
+ ]
914
+ signed_urls = self.get_signed_urls_for_read(
915
+ artifact_identifier=self.artifact_identifier,
916
+ paths=remote_file_paths,
917
+ )
918
+ for idx, signed_url in zip(
919
+ range(start_idx, end_idx), signed_urls
920
+ ):
921
+ file_path, download_dest_path = file_paths[idx]
922
+ future = executor.submit(
923
+ self._download_artifact,
924
+ src_artifact_path=file_path,
925
+ dst_local_dir_path=download_dest_path,
926
+ signed_url=signed_url,
927
+ abort_event=abort_event,
928
+ progress_bar=progress_bar,
929
+ )
930
+ futures.append(future)
931
+
932
+ done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
933
+ if len(not_done) > 0:
934
+ abort_event.set()
935
+ for future in not_done:
936
+ future.cancel()
937
+ for future in done:
938
+ if future.exception() is not None:
939
+ raise future.exception()
940
+
941
+ output_dir = os.path.join(dst_path, artifact_path)
942
+ return output_dir
943
+ else:
944
+ return self._download_artifact(
945
+ src_artifact_path=artifact_path,
946
+ dst_local_dir_path=dst_path,
947
+ signed_url=None,
948
+ progress_bar=progress_bar,
949
+ )
950
+ except Exception as err:
951
+ if is_dir_temp:
952
+ logger.info(
953
+ f"Error encountered, removing temporary download directory at {dst_path}"
954
+ )
955
+ rmtree(dst_path) # remove temp directory alongside it's contents
956
+ raise err
957
+
958
+ finally:
959
+ progress_bar.stop()
960
+
961
+ # noinspection PyMethodOverriding
962
+ def _download_file(
963
+ self,
964
+ remote_file_path: str,
965
+ local_path: str,
966
+ progress_bar: Optional[Progress],
967
+ signed_url: Optional[SignedURLDto],
968
+ abort_event: Optional[Event] = None,
969
+ ):
970
+ if not remote_file_path:
971
+ raise MlFoundryException(
972
+ f"remote_file_path cannot be None or empty str {remote_file_path}"
973
+ )
974
+ if not signed_url:
975
+ signed_url = self.get_signed_urls_for_read(
976
+ artifact_identifier=self.artifact_identifier, paths=[remote_file_path]
977
+ )[0]
978
+
979
+ if progress_bar is None or not progress_bar.disable:
980
+ logger.info("Downloading %s to %s", remote_file_path, local_path)
981
+
982
+ if progress_bar is not None:
983
+ download_progress_bar = progress_bar.add_task(
984
+ f"[green]Downloading to {remote_file_path}:", start=True
985
+ )
986
+
987
+ def callback(chunk, total_file_size):
988
+ if progress_bar is not None:
989
+ progress_bar.update(
990
+ download_progress_bar,
991
+ advance=chunk,
992
+ total=total_file_size,
993
+ )
994
+ if abort_event and abort_event.is_set():
995
+ raise Exception("aborting download")
996
+
997
+ _download_file_using_http_uri(
998
+ http_uri=signed_url.signed_url,
999
+ download_path=local_path,
1000
+ callback=callback,
1001
+ )
1002
+ logger.debug("Downloaded %s to %s", remote_file_path, local_path)
1003
+
1004
+ def _download_artifact(
1005
+ self,
1006
+ src_artifact_path,
1007
+ dst_local_dir_path,
1008
+ signed_url: Optional[SignedURLDto],
1009
+ progress_bar: Optional[Progress] = None,
1010
+ abort_event=None,
1011
+ ) -> str:
1012
+ """
1013
+ Download the file artifact specified by `src_artifact_path` to the local filesystem
1014
+ directory specified by `dst_local_dir_path`.
1015
+ :param src_artifact_path: A relative, POSIX-style path referring to a file artifact
1016
+ stored within the repository's artifact root location.
1017
+ `src_artifact_path` should be specified relative to the
1018
+ repository's artifact root location.
1019
+ :param dst_local_dir_path: Absolute path of the local filesystem destination directory
1020
+ to which to download the specified artifact. The downloaded
1021
+ artifact may be written to a subdirectory of
1022
+ `dst_local_dir_path` if `src_artifact_path` contains
1023
+ subdirectories.
1024
+ :param progress_bar: An instance of a Rich progress bar used to visually display the
1025
+ progress of the file download.
1026
+ :return: A local filesystem path referring to the downloaded file.
1027
+ """
1028
+ local_destination_file_path = self._create_download_destination(
1029
+ src_artifact_path=src_artifact_path, dst_local_dir_path=dst_local_dir_path
1030
+ )
1031
+ self._download_file(
1032
+ remote_file_path=src_artifact_path,
1033
+ local_path=local_destination_file_path,
1034
+ signed_url=signed_url,
1035
+ abort_event=abort_event,
1036
+ progress_bar=progress_bar,
1037
+ )
1038
+ return local_destination_file_path
1039
+
1040
+ def _get_file_paths_recur(self, src_artifact_dir_path, dst_local_dir_path):
1041
+ local_dir = os.path.join(dst_local_dir_path, src_artifact_dir_path)
1042
+ dir_content = [ # prevent infinite loop, sometimes the dir is recursively included
1043
+ file_info
1044
+ for file_info in self.list_artifacts(src_artifact_dir_path)
1045
+ if file_info.path != "." and file_info.path != src_artifact_dir_path
1046
+ ]
1047
+ if not dir_content: # empty dir
1048
+ if not os.path.exists(local_dir):
1049
+ os.makedirs(local_dir, exist_ok=True)
1050
+ else:
1051
+ for file_info in dir_content:
1052
+ if file_info.is_dir:
1053
+ yield from self._get_file_paths_recur(
1054
+ src_artifact_dir_path=file_info.path,
1055
+ dst_local_dir_path=dst_local_dir_path,
1056
+ )
1057
+ else:
1058
+ yield file_info.path, dst_local_dir_path
1059
+
1060
+ # TODO (chiragjn): Refactor these methods - if else is very inconvenient
1061
+ def get_signed_urls_for_read(
1062
+ self,
1063
+ artifact_identifier: ArtifactIdentifier,
1064
+ paths,
1065
+ ) -> List[SignedURLDto]:
1066
+ if artifact_identifier.artifact_version_id:
1067
+ signed_urls_response = self._mlfoundry_artifacts_api.get_signed_urls_for_read_post(
1068
+ get_signed_urls_for_artifact_version_read_request_dto=GetSignedURLsForArtifactVersionReadRequestDto(
1069
+ id=str(artifact_identifier.artifact_version_id), paths=paths
1070
+ )
1071
+ )
1072
+ signed_urls = signed_urls_response.signed_urls
1073
+ elif artifact_identifier.dataset_fqn:
1074
+ signed_urls_dataset_response = self._mlfoundry_artifacts_api.get_signed_urls_dataset_read_post(
1075
+ get_signed_urls_for_dataset_read_request_dto=GetSignedURLsForDatasetReadRequestDto(
1076
+ dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
1077
+ )
1078
+ )
1079
+ signed_urls = signed_urls_dataset_response.signed_urls
1080
+ else:
1081
+ raise ValueError(
1082
+ "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
1083
+ )
1084
+ return signed_urls
1085
+
1086
+ def get_signed_urls_for_write(
1087
+ self,
1088
+ artifact_identifier: ArtifactIdentifier,
1089
+ paths: List[str],
1090
+ ) -> List[SignedURLDto]:
1091
+ if artifact_identifier.artifact_version_id:
1092
+ signed_urls_response = self._mlfoundry_artifacts_api.get_signed_urls_for_write_post(
1093
+ get_signed_urls_for_artifact_version_write_request_dto=GetSignedURLsForArtifactVersionWriteRequestDto(
1094
+ id=str(artifact_identifier.artifact_version_id), paths=paths
1095
+ )
1096
+ )
1097
+ signed_urls = signed_urls_response.signed_urls
1098
+ elif artifact_identifier.dataset_fqn:
1099
+ signed_urls_dataset_response = self._mlfoundry_artifacts_api.get_signed_urls_for_dataset_write_post(
1100
+ get_signed_url_for_dataset_write_request_dto=GetSignedURLForDatasetWriteRequestDto(
1101
+ dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
1102
+ )
1103
+ )
1104
+ signed_urls = signed_urls_dataset_response.signed_urls
1105
+ else:
1106
+ raise ValueError(
1107
+ "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
1108
+ )
1109
+ return signed_urls
1110
+
1111
+ def create_multipart_upload_for_identifier(
1112
+ self,
1113
+ artifact_identifier: ArtifactIdentifier,
1114
+ path,
1115
+ num_parts,
1116
+ ) -> MultiPartUploadDto:
1117
+ if artifact_identifier.artifact_version_id:
1118
+ create_multipart_response: MultiPartUploadResponseDto = self._mlfoundry_artifacts_api.create_multi_part_upload_post(
1119
+ create_multi_part_upload_request_dto=CreateMultiPartUploadRequestDto(
1120
+ artifact_version_id=str(artifact_identifier.artifact_version_id),
1121
+ path=path,
1122
+ num_parts=num_parts,
1123
+ )
1124
+ )
1125
+ multipart_upload = create_multipart_response.multipart_upload
1126
+ elif artifact_identifier.dataset_fqn:
1127
+ create_multipart_for_dataset_response = self._mlfoundry_artifacts_api.create_multipart_upload_for_dataset_post(
1128
+ create_multi_part_upload_for_dataset_request_dto=CreateMultiPartUploadForDatasetRequestDto(
1129
+ dataset_fqn=artifact_identifier.dataset_fqn,
1130
+ path=path,
1131
+ num_parts=num_parts,
1132
+ )
1133
+ )
1134
+ multipart_upload = create_multipart_for_dataset_response.multipart_upload
1135
+ else:
1136
+ raise ValueError(
1137
+ "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
1138
+ )
1139
+ return multipart_upload
1140
+
1141
+ def list_files(
1142
+ self, artifact_identifier: ArtifactIdentifier, path, page_size, page_token
1143
+ ) -> Union[ListFilesForDatasetResponseDto, ListFilesForArtifactVersionsResponseDto]:
1144
+ if artifact_identifier.dataset_fqn:
1145
+ return self._mlfoundry_artifacts_api.list_files_for_dataset_post(
1146
+ list_files_for_dataset_request_dto=ListFilesForDatasetRequestDto(
1147
+ dataset_fqn=artifact_identifier.dataset_fqn,
1148
+ path=path,
1149
+ max_results=page_size,
1150
+ page_token=page_token,
1151
+ )
1152
+ )
1153
+ else:
1154
+ return self._mlfoundry_artifacts_api.list_files_for_artifact_version_post(
1155
+ list_files_for_artifact_version_request_dto=ListFilesForArtifactVersionRequestDto(
1156
+ id=str(artifact_identifier.artifact_version_id),
1157
+ path=path,
1158
+ max_results=page_size,
1159
+ page_token=page_token,
1160
+ )
1161
+ )