truefoundry 0.5.3rc4__py3-none-any.whl → 0.5.3rc5__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Note: this version of truefoundry has been flagged as a potentially problematic release.

Files changed (47)
  1. truefoundry/__init__.py +10 -1
  2. truefoundry/autodeploy/cli.py +2 -2
  3. truefoundry/cli/__main__.py +0 -4
  4. truefoundry/cli/util.py +12 -3
  5. truefoundry/common/auth_service_client.py +7 -4
  6. truefoundry/common/constants.py +3 -1
  7. truefoundry/common/credential_provider.py +7 -8
  8. truefoundry/common/exceptions.py +11 -7
  9. truefoundry/common/request_utils.py +96 -14
  10. truefoundry/common/servicefoundry_client.py +31 -29
  11. truefoundry/common/session.py +93 -0
  12. truefoundry/common/storage_provider_utils.py +331 -0
  13. truefoundry/common/utils.py +9 -9
  14. truefoundry/common/warnings.py +21 -0
  15. truefoundry/deploy/builder/builders/tfy_python_buildpack/dockerfile_template.py +8 -20
  16. truefoundry/deploy/cli/commands/deploy_command.py +4 -4
  17. truefoundry/deploy/lib/clients/servicefoundry_client.py +13 -14
  18. truefoundry/deploy/lib/dao/application.py +2 -2
  19. truefoundry/deploy/lib/dao/workspace.py +1 -1
  20. truefoundry/deploy/lib/session.py +1 -1
  21. truefoundry/deploy/v2/lib/deploy.py +2 -2
  22. truefoundry/deploy/v2/lib/deploy_workflow.py +1 -1
  23. truefoundry/deploy/v2/lib/patched_models.py +70 -4
  24. truefoundry/deploy/v2/lib/source.py +2 -1
  25. truefoundry/gateway/cli/cli.py +1 -22
  26. truefoundry/gateway/lib/entities.py +3 -8
  27. truefoundry/gateway/lib/models.py +0 -38
  28. truefoundry/ml/artifact/truefoundry_artifact_repo.py +33 -297
  29. truefoundry/ml/clients/servicefoundry_client.py +36 -15
  30. truefoundry/ml/exceptions.py +2 -1
  31. truefoundry/ml/log_types/artifacts/artifact.py +3 -2
  32. truefoundry/ml/log_types/artifacts/model.py +6 -5
  33. truefoundry/ml/log_types/artifacts/utils.py +2 -2
  34. truefoundry/ml/mlfoundry_api.py +6 -38
  35. truefoundry/ml/mlfoundry_run.py +6 -15
  36. truefoundry/ml/model_framework.py +2 -1
  37. truefoundry/ml/session.py +69 -97
  38. truefoundry/workflow/remote_filesystem/tfy_signed_url_client.py +42 -9
  39. truefoundry/workflow/remote_filesystem/tfy_signed_url_fs.py +126 -7
  40. {truefoundry-0.5.3rc4.dist-info → truefoundry-0.5.3rc5.dist-info}/METADATA +1 -1
  41. {truefoundry-0.5.3rc4.dist-info → truefoundry-0.5.3rc5.dist-info}/RECORD +43 -44
  42. truefoundry/deploy/lib/auth/servicefoundry_session.py +0 -61
  43. truefoundry/gateway/lib/client.py +0 -51
  44. truefoundry/ml/clients/entities.py +0 -8
  45. truefoundry/ml/clients/utils.py +0 -122
  46. {truefoundry-0.5.3rc4.dist-info → truefoundry-0.5.3rc5.dist-info}/WHEEL +0 -0
  47. {truefoundry-0.5.3rc4.dist-info → truefoundry-0.5.3rc5.dist-info}/entry_points.txt +0 -0
truefoundry/deploy/v2/lib/patched_models.py

@@ -1,9 +1,11 @@
  import enum
+ import os
  import re
  import warnings
  from collections.abc import Mapping
  from typing import Literal, Optional, Union

+ from truefoundry.common.warnings import TrueFoundryDeprecationWarning
  from truefoundry.deploy.auto_gen import models
  from truefoundry.pydantic_v1 import (
      BaseModel,
@@ -71,6 +73,66 @@ resources:
  """


+ AUTO_DISCOVERED_REQUIREMENTS_TXT_WARNING_MESSAGE_TEMPLATE = """\
+ Using automatically discovered {requirements_txt_path} as the requirements file.
+ This auto discovery behavior is deprecated and will be removed in a future release.
+ Please specify the relative path of requirements file explicitly as
+
+ ```python
+ build=Build(
+     ...
+     build_spec=PythonBuild(
+         ...
+         requirements_path={requirements_txt_path!r}
+     )
+ )
+ ```
+
+ OR
+
+ ```yaml
+ build:
+   type: build
+   build_spec:
+     type: tfy-python-buildpack
+     requirements_path: {requirements_txt_path!r}
+     ...
+   ...
+ ```
+
+ or set it to None if you don't want use any requirements file.
+ """
+
+
+ def _resolve_requirements_path(
+     build_context_path: str,
+     requirements_path: Optional[str],
+ ) -> Optional[str]:
+     if requirements_path:
+         return requirements_path
+
+     # TODO: Deprecated behavior, phase out auto discovery in future release
+     possible_requirements_txt_filename = "requirements.txt"
+     possible_requirements_txt_path = os.path.join(
+         build_context_path, possible_requirements_txt_filename
+     )
+
+     if os.path.isfile(possible_requirements_txt_path):
+         requirements_txt_path = os.path.relpath(
+             possible_requirements_txt_path, start=build_context_path
+         )
+         warnings.warn(
+             AUTO_DISCOVERED_REQUIREMENTS_TXT_WARNING_MESSAGE_TEMPLATE.format(
+                 requirements_txt_path=requirements_txt_path
+             ),
+             category=TrueFoundryDeprecationWarning,
+             stacklevel=2,
+         )
+         return requirements_txt_path
+
+     return None
+
+
  class CUDAVersion(str, enum.Enum):
      CUDA_11_0_CUDNN8 = "11.0-cudnn8"
      CUDA_11_1_CUDNN8 = "11.1-cudnn8"
@@ -139,7 +201,7 @@ class PythonBuild(models.PythonBuild, PatchedModelBase):
      type: Literal["tfy-python-buildpack"] = "tfy-python-buildpack"

      @root_validator
-     def validate_python_version_when_cuda_version(cls, values):
+     def validate_values(cls, values):
          if values.get("cuda_version"):
              python_version = values.get("python_version")
              if python_version and not re.match(r"^3\.\d+$", python_version):
@@ -148,6 +210,10 @@ class PythonBuild(models.PythonBuild, PatchedModelBase):
                      f"provided but got {python_version!r}. If you are adding a "
                      f'patch version, please remove it (e.g. "3.9.2" should be "3.9")'
                  )
+         _resolve_requirements_path(
+             build_context_path=values.get("build_context_path") or "./",
+             requirements_path=values.get("requirements_path"),
+         )
          return values


@@ -251,7 +317,7 @@ class Resources(models.Resources, PatchedModelBase):
                      gpu_type=gpu_type,
                      gpu_count=gpu_count,
                  ),
-                 category=FutureWarning,
+                 category=TrueFoundryDeprecationWarning,
                  stacklevel=2,
              )
          elif gpu_count:
@@ -259,7 +325,7 @@ class Resources(models.Resources, PatchedModelBase):
                  LEGACY_GPU_COUNT_WARNING_MESSAGE_TEMPLATE.format(
                      gpu_count=gpu_count,
                  ),
-                 category=FutureWarning,
+                 category=TrueFoundryDeprecationWarning,
                  stacklevel=2,
              )
          return values
@@ -295,7 +361,7 @@ class Autoscaling(ServiceAutoscaling):
              "`truefoundry.deploy.Autoscaling` is deprecated and will be removed in a future version. "
              "Please use `truefoundry.deploy.ServiceAutoscaling` instead. "
              "You can rename `Autoscaling` to `ServiceAutoscaling` in your script.",
-             category=DeprecationWarning,
+             category=TrueFoundryDeprecationWarning,
              stacklevel=2,
          )
          super().__init__(**kwargs)
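A side effect of moving these warnings from `FutureWarning`/`DeprecationWarning` to a dedicated `TrueFoundryDeprecationWarning` class is that callers can now filter TrueFoundry deprecations precisely, without muting Python's built-in categories. A minimal sketch, assuming only the import path shown in the diff:

```python
import warnings

from truefoundry.common.warnings import TrueFoundryDeprecationWarning

# Turn TrueFoundry deprecations into errors (e.g. in CI) to catch
# deprecated usage early...
warnings.filterwarnings("error", category=TrueFoundryDeprecationWarning)

# ...or silence only these warnings in production scripts while leaving
# ordinary DeprecationWarning handling untouched.
warnings.filterwarnings("ignore", category=TrueFoundryDeprecationWarning)
```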
truefoundry/deploy/v2/lib/source.py

@@ -8,6 +8,7 @@ from typing import Callable, List, Optional
  import gitignorefile
  from tqdm import tqdm

+ from truefoundry.common.warnings import TrueFoundryDeprecationWarning
  from truefoundry.deploy import builder
  from truefoundry.deploy.auto_gen import models
  from truefoundry.deploy.builder.docker_service import (
@@ -102,7 +103,7 @@ def _get_callback_handler_to_ignore_file_path(
          warnings.warn(
              "`.sfyignore` is deprecated and will be ignored in future versions. "
              "Please rename the file to `.tfyignore`",
-             category=DeprecationWarning,
+             category=TrueFoundryDeprecationWarning,
              stacklevel=2,
          )
          return gitignorefile.parse(path=ignorefile_path, base_path=source_dir)
truefoundry/gateway/cli/cli.py

@@ -2,7 +2,7 @@ import click

  from truefoundry.cli.const import COMMAND_CLS, GROUP_CLS
  from truefoundry.cli.display_util import print_entity_list
- from truefoundry.gateway.lib.models import generate_code_for_model, list_models
+ from truefoundry.gateway.lib.models import list_models


  def get_gateway_cli():
@@ -27,25 +27,4 @@ def get_gateway_cli():
          enabled_models = list_models(model_type)
          print_entity_list("Models", enabled_models)

-     @gateway.command("generate-code", cls=COMMAND_CLS, help="Generate code for a model")
-     @click.argument("model_id")
-     @click.option(
-         "--inference-type",
-         type=click.Choice(["chat", "completion", "embedding"]),
-         default="chat",
-         help="Type of inference to generate code for",
-     )
-     @click.option(
-         "--client",
-         type=click.Choice(["openai", "rest", "langchain", "stream", "node", "curl"]),
-         default="curl",
-         help="Language/framework to generate code for",
-     )
-     def generate_code_cli(model_id: str, inference_type: str, client: str):
-         """Generate code for a model"""
-         code = generate_code_for_model(
-             model_id, client=client, inference_type=inference_type
-         )
-         print(code)
-
      return gateway
truefoundry/gateway/lib/entities.py

@@ -17,15 +17,10 @@ class GatewayModel(BaseModel):
      model_fqn: str

      def list_row_data(self) -> Dict[str, Any]:
-         model_display = self.model_fqn
-         provider_display = self.provider
-         if self.model_id:
-             provider_display = f"{self.provider} ({self.model_id})"
-
          return {
-             "model": model_display,
-             "provider": provider_display,
-             "type": self.types if isinstance(self.types, str) else ", ".join(self.types)
+             "model_id": self.model_fqn,
+             "provider": self.provider,
+             "provider_model_id": self.model_id,
          }


truefoundry/gateway/lib/models.py

@@ -3,7 +3,6 @@ from typing import List, Literal, Optional
  from truefoundry.deploy.lib.clients.servicefoundry_client import (
      ServiceFoundryServiceClient,
  )
- from truefoundry.gateway.lib.client import GatewayServiceClient
  from truefoundry.gateway.lib.entities import GatewayModel


@@ -28,40 +27,3 @@ def list_models(
          enabled_models.append(model)

      return enabled_models
-
-
- def generate_code_for_model(
-     model_id: str,
-     client: Literal["openai", "rest", "langchain", "stream", "node", "curl"] = "curl",
-     inference_type: Literal["chat", "completion", "embedding"] = "chat",
- ) -> str:
-     """Generate code snippet for using a model in the specified language/framework
-
-     Args:
-         model_id (str): ID of the model to generate code for
-         language (Literal["openai", "rest", "langchain", "stream", "node", "curl"]): Language/framework to generate code for. Defaults to "curl"
-         inference_type (Literal["chat", "completion", "embedding"]): Type of inference to generate code for. Defaults to "chat"
-
-     Returns:
-         str: Code snippet for using the model in the specified language/framework
-     """
-     gateway_client = GatewayServiceClient()
-     response = gateway_client.generate_code(model_id, inference_type)
-
-     code_map = {
-         "openai": ("openai_code", "Python code using OpenAI SDK for direct API calls"),
-         "rest": ("rest_code", "Python code using requests library for REST API calls"),
-         "langchain": (
-             "langchain_code",
-             "Python code using LangChain framework for LLM integration",
-         ),
-         "stream": ("stream_code", "Python code with streaming response handling"),
-         "node": ("node_code", "Node.js code using Axios for API calls"),
-         "curl": ("curl_code", "cURL command for direct API access via terminal"),
-     }
-
-     code_key, description = code_map[client]
-     if code_key in response and response[code_key]:
-         return f"{description}\n{response[code_key]}"
-
-     return "No code snippet available for the specified language"
truefoundry/ml/artifact/truefoundry_artifact_repo.py

@@ -1,5 +1,3 @@
- import math
- import mmap
  import os
  import posixpath
  import sys
@@ -14,7 +12,6 @@ from typing import (
      Dict,
      Iterator,
      List,
-     NamedTuple,
      Optional,
      Sequence,
      Tuple,
@@ -23,7 +20,6 @@ from typing import (
  from urllib.parse import unquote
  from urllib.request import pathname2url

- import requests
  from rich.console import _is_jupyter
  from rich.progress import (
      BarColumn,
@@ -36,6 +32,17 @@ from rich.progress import (
  from tqdm.utils import CallbackIOWrapper

  from truefoundry.common.constants import ENV_VARS
+ from truefoundry.common.request_utils import (
+     augmented_raise_for_status,
+     cloud_storage_http_request,
+ )
+ from truefoundry.common.storage_provider_utils import (
+     MultiPartUpload,
+     _FileMultiPartInfo,
+     azure_multi_part_upload,
+     decide_file_parts,
+     s3_compatible_multipart_upload,
+ )
  from truefoundry.ml.autogen.client import (  # type: ignore[attr-defined]
      ApiClient,
      CreateMultiPartUploadForDatasetRequestDto,
@@ -56,34 +63,11 @@ from truefoundry.ml.autogen.client import (  # type: ignore[attr-defined]
      RunArtifactsApi,
      SignedURLDto,
  )
- from truefoundry.ml.clients.utils import (
-     augmented_raise_for_status,
-     cloud_storage_http_request,
- )
  from truefoundry.ml.exceptions import MlFoundryException
  from truefoundry.ml.logger import logger
  from truefoundry.ml.session import _get_api_client
  from truefoundry.pydantic_v1 import BaseModel, root_validator

- _MIN_BYTES_REQUIRED_FOR_MULTIPART = 100 * 1024 * 1024
- # GCP/S3 Maximum number of parts per upload 10,000
- # Maximum number of blocks in a block blob 50,000 blocks
- # TODO: This number is artificially limited now. Later
- # we will ask for parts signed URI in batches rather than in a single
- # API Calls:
- #   Create Multipart Upload (Returns maximum number of parts, size limit of
- #   a single part, upload id for s3 etc )
- #   Get me signed uris for first 500 parts
- #   Upload 500 parts
- #   Get me signed uris for the next 500 parts
- #   Upload 500 parts
- #   ...
- #   Finalize the Multipart upload using the finalize signed url returned
- #   by Create Multipart Upload or get a new one.
- _MAX_NUM_PARTS_FOR_MULTIPART = 1000
- # Azure Maximum size of a block in a block blob 4000 MiB
- # GCP/S3 Maximum size of an individual part in a multipart upload 5 GiB
- _MAX_PART_SIZE_BYTES_FOR_MULTIPART = 4 * 1024 * 1024 * 1000
  _LIST_FILES_PAGE_SIZE = 500
  _GENERATE_SIGNED_URL_BATCH_SIZE = 50
  DEFAULT_PRESIGNED_URL_EXPIRY_TIME = 3600
@@ -123,21 +107,6 @@ def relative_path_to_artifact_path(path):
      return unquote(pathname2url(path))


- def _align_part_size_with_mmap_allocation_granularity(part_size: int) -> int:
-     modulo = part_size % mmap.ALLOCATIONGRANULARITY
-     if modulo == 0:
-         return part_size
-
-     part_size += mmap.ALLOCATIONGRANULARITY - modulo
-     return part_size
-
-
- # Can not be less than 5 * 1024 * 1024
- _PART_SIZE_BYTES_FOR_MULTIPART = _align_part_size_with_mmap_allocation_granularity(
-     10 * 1024 * 1024
- )
-
-
  def bad_path_message(name):
      return (
          "Names may be treated as files in certain cases, and must not resolve to other names"
@@ -157,68 +126,6 @@ def verify_artifact_path(artifact_path):
      )


- class _PartNumberEtag(NamedTuple):
-     part_number: int
-     etag: str
-
-
- def _get_s3_compatible_completion_body(multi_parts: List[_PartNumberEtag]) -> str:
-     body = "<CompleteMultipartUpload>\n"
-     for part in multi_parts:
-         body += "  <Part>\n"
-         body += f"    <PartNumber>{part.part_number}</PartNumber>\n"
-         body += f"    <ETag>{part.etag}</ETag>\n"
-         body += "  </Part>\n"
-     body += "</CompleteMultipartUpload>"
-     return body
-
-
- def _get_azure_blob_completion_body(block_ids: List[str]) -> str:
-     body = "<BlockList>\n"
-     for block_id in block_ids:
-         body += f"<Uncommitted>{block_id}</Uncommitted> "
-     body += "</BlockList>"
-     return body
-
-
- class _FileMultiPartInfo(NamedTuple):
-     num_parts: int
-     part_size: int
-     file_size: int
-
-
- def _decide_file_parts(file_path: str) -> _FileMultiPartInfo:
-     file_size = os.path.getsize(file_path)
-     if (
-         file_size < _MIN_BYTES_REQUIRED_FOR_MULTIPART
-         or ENV_VARS.TFY_ARTIFACTS_DISABLE_MULTIPART_UPLOAD
-     ):
-         return _FileMultiPartInfo(1, part_size=file_size, file_size=file_size)
-
-     ideal_num_parts = math.ceil(file_size / _PART_SIZE_BYTES_FOR_MULTIPART)
-     if ideal_num_parts <= _MAX_NUM_PARTS_FOR_MULTIPART:
-         return _FileMultiPartInfo(
-             ideal_num_parts,
-             part_size=_PART_SIZE_BYTES_FOR_MULTIPART,
-             file_size=file_size,
-         )
-
-     part_size_when_using_max_parts = math.ceil(file_size / _MAX_NUM_PARTS_FOR_MULTIPART)
-     part_size_when_using_max_parts = _align_part_size_with_mmap_allocation_granularity(
-         part_size_when_using_max_parts
-     )
-     if part_size_when_using_max_parts > _MAX_PART_SIZE_BYTES_FOR_MULTIPART:
-         raise ValueError(
-             f"file {file_path!r} is too big for upload. Multipart chunk"
-             f" size {part_size_when_using_max_parts} is higher"
-             f" than {_MAX_PART_SIZE_BYTES_FOR_MULTIPART}"
-         )
-     num_parts = math.ceil(file_size / part_size_when_using_max_parts)
-     return _FileMultiPartInfo(
-         num_parts, part_size=part_size_when_using_max_parts, file_size=file_size
-     )
-
-
  def _signed_url_upload_file(
      signed_url: SignedURLDto,
      local_file: str,
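The sizing logic deleted above has a successor, `decide_file_parts`, imported from the new `truefoundry/common/storage_provider_utils.py`. Since the constants were also removed from this file, the arithmetic is worth restating. Below is a self-contained sketch of the same decision (omitting the `TFY_ARTIFACTS_DISABLE_MULTIPART_UPLOAD` escape hatch); names are illustrative, not the moved module's API:

```python
import math
import mmap
from typing import NamedTuple

_MIN_BYTES_FOR_MULTIPART = 100 * 1024 * 1024          # below this, one plain PUT
_MAX_NUM_PARTS = 1000                                  # artificial cap from the old code
_MAX_PART_SIZE = 4 * 1024 * 1024 * 1000                # Azure block ceiling (4000 MiB)


def _align(part_size: int) -> int:
    # mmap offsets must be multiples of the OS allocation granularity,
    # since each part is mmap'ed at its byte offset before upload.
    modulo = part_size % mmap.ALLOCATIONGRANULARITY
    return part_size if modulo == 0 else part_size + (mmap.ALLOCATIONGRANULARITY - modulo)


_DEFAULT_PART_SIZE = _align(10 * 1024 * 1024)          # ~10 MiB; S3 minimum is 5 MiB


class PartsPlan(NamedTuple):
    num_parts: int
    part_size: int


def plan_parts(file_size: int) -> PartsPlan:
    if file_size < _MIN_BYTES_FOR_MULTIPART:
        return PartsPlan(1, file_size)
    num_parts = math.ceil(file_size / _DEFAULT_PART_SIZE)
    if num_parts <= _MAX_NUM_PARTS:
        return PartsPlan(num_parts, _DEFAULT_PART_SIZE)
    # Large files: grow the part size so the whole file fits in 1000 parts.
    part_size = _align(math.ceil(file_size / _MAX_NUM_PARTS))
    if part_size > _MAX_PART_SIZE:
        raise ValueError("file too big for a single multipart upload")
    return PartsPlan(math.ceil(file_size / part_size), part_size)


# A 50 GiB file would need 5120 parts at 10 MiB, exceeding the 1000-part cap,
# so the part size is bumped to roughly 51.2 MiB (aligned), giving <= 1000 parts.
print(plan_parts(50 * 1024**3))
```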
@@ -227,9 +134,12 @@ ):
  ):
      if os.stat(local_file).st_size == 0:
          with cloud_storage_http_request(
-             method="put", url=signed_url.signed_url, data=""
+             method="put",
+             url=signed_url.signed_url,
+             data="",
+             exception_class=MlFoundryException,  # type: ignore
          ) as response:
-             augmented_raise_for_status(response)
+             augmented_raise_for_status(response, exception_class=MlFoundryException)  # type: ignore
          return

      task_progress_bar = progress_bar.add_task(
@@ -247,9 +157,12 @@ def _signed_url_upload_file(
      # NOTE: Azure Put Blob does not support Transfer Encoding header.
      wrapped_file = CallbackIOWrapper(callback, file, "read")
      with cloud_storage_http_request(
-         method="put", url=signed_url.signed_url, data=wrapped_file
+         method="put",
+         url=signed_url.signed_url,
+         data=wrapped_file,
+         exception_class=MlFoundryException,  # type: ignore
      ) as response:
-         augmented_raise_for_status(response)
+         augmented_raise_for_status(response, exception_class=MlFoundryException)  # type: ignore


  def _download_file_using_http_uri(
@@ -265,9 +178,12 @@ def _download_file_using_http_uri(
      providers.
      """
      with cloud_storage_http_request(
-         method="get", url=http_uri, stream=True
+         method="get",
+         url=http_uri,
+         stream=True,
+         exception_class=MlFoundryException,  # type: ignore
      ) as response:
-         augmented_raise_for_status(response)
+         augmented_raise_for_status(response, exception_class=MlFoundryException)  # type: ignore
          file_size = int(response.headers.get("Content-Length", 0))
          with open(download_path, "wb") as output_file:
              for chunk in response.iter_content(chunk_size=chunk_size):
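Every raw HTTP call now threads an `exception_class` through to the shared utilities in `truefoundry/common/request_utils.py`, which replace the deleted `truefoundry/ml/clients/utils.py`. The diff doesn't show the helper itself; a plausible sketch of the contract implied by these call sites (an assumption, not the actual implementation) would be:

```python
from typing import Type

import requests


def augmented_raise_for_status(
    response: requests.Response,
    exception_class: Type[Exception] = Exception,
) -> None:
    # Hypothetical sketch: raise the caller-supplied exception type instead of
    # a hard-coded one, attaching the response body for easier debugging. This
    # would let the ML client surface MlFoundryException while other modules
    # reuse the same helper with their own error types.
    try:
        response.raise_for_status()
    except requests.HTTPError as http_error:
        raise exception_class(
            f"{http_error}. Response text: {response.text}"
        ) from http_error
```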
@@ -281,188 +197,6 @@ def _download_file_using_http_uri(
            os.fsync(output_file.fileno())


- class _CallbackIOWrapperForMultiPartUpload(CallbackIOWrapper):
-     def __init__(self, callback, stream, method, length: int):
-         self.wrapper_setattr("_length", length)
-         super().__init__(callback, stream, method)
-
-     def __len__(self):
-         return self.wrapper_getattr("_length")
-
-
- def _file_part_upload(
-     url: str,
-     file_path: str,
-     seek: int,
-     length: int,
-     file_size: int,
-     abort_event: Optional[Event] = None,
-     method: str = "put",
- ):
-     def callback(*_, **__):
-         if abort_event and abort_event.is_set():
-             raise Exception("aborting upload")
-
-     with open(file_path, "rb") as file:
-         with mmap.mmap(
-             file.fileno(),
-             length=min(file_size - seek, length),
-             offset=seek,
-             access=mmap.ACCESS_READ,
-         ) as mapped_file:
-             wrapped_file = _CallbackIOWrapperForMultiPartUpload(
-                 callback, mapped_file, "read", len(mapped_file)
-             )
-             with cloud_storage_http_request(
-                 method=method,
-                 url=url,
-                 data=wrapped_file,
-             ) as response:
-                 augmented_raise_for_status(response)
-                 return response
-
-
- def _s3_compatible_multipart_upload(
-     multipart_upload: MultiPartUploadDto,
-     local_file: str,
-     multipart_info: _FileMultiPartInfo,
-     executor: ThreadPoolExecutor,
-     progress_bar: Progress,
-     abort_event: Optional[Event] = None,
- ):
-     abort_event = abort_event or Event()
-     parts = []
-
-     multi_part_upload_progress = progress_bar.add_task(
-         f"[green]Uploading {local_file}:", start=True
-     )
-
-     def upload(part_number: int, seek: int) -> None:
-         logger.debug(
-             "Uploading part %d/%d of %s",
-             part_number,
-             multipart_info.num_parts,
-             local_file,
-         )
-         response = _file_part_upload(
-             url=multipart_upload.part_signed_urls[part_number].signed_url,
-             file_path=local_file,
-             seek=seek,
-             length=multipart_info.part_size,
-             file_size=multipart_info.file_size,
-             abort_event=abort_event,
-         )
-         logger.debug(
-             "Uploaded part %d/%d of %s",
-             part_number,
-             multipart_info.num_parts,
-             local_file,
-         )
-         progress_bar.update(
-             multi_part_upload_progress,
-             advance=multipart_info.part_size,
-             total=multipart_info.file_size,
-         )
-         etag = response.headers["ETag"]
-         parts.append(_PartNumberEtag(etag=etag, part_number=part_number + 1))
-
-     futures: List[Future] = []
-     for part_number, seek in enumerate(
-         range(0, multipart_info.file_size, multipart_info.part_size)
-     ):
-         future = executor.submit(upload, part_number=part_number, seek=seek)
-         futures.append(future)
-
-     done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
-     if len(not_done) > 0:
-         abort_event.set()
-     for future in not_done:
-         future.cancel()
-     for future in done:
-         if future.exception() is not None:
-             raise future.exception()
-
-     logger.debug("Finalizing multipart upload of %s", local_file)
-     parts = sorted(parts, key=lambda part: part.part_number)
-     response = requests.post(
-         multipart_upload.finalize_signed_url.signed_url,
-         data=_get_s3_compatible_completion_body(parts),
-         timeout=2 * 60,
-     )
-     response.raise_for_status()
-     logger.debug("Multipart upload of %s completed", local_file)
-
-
- def _azure_multi_part_upload(
-     multipart_upload: MultiPartUploadDto,
-     local_file: str,
-     multipart_info: _FileMultiPartInfo,
-     executor: ThreadPoolExecutor,
-     progress_bar: Progress,
-     abort_event: Optional[Event] = None,
- ):
-     abort_event = abort_event or Event()
-
-     multi_part_upload_progress = progress_bar.add_task(
-         f"[green]Uploading {local_file}:", start=True
-     )
-
-     def upload(part_number: int, seek: int):
-         logger.debug(
-             "Uploading part %d/%d of %s",
-             part_number,
-             multipart_info.num_parts,
-             local_file,
-         )
-         _file_part_upload(
-             url=multipart_upload.part_signed_urls[part_number].signed_url,
-             file_path=local_file,
-             seek=seek,
-             length=multipart_info.part_size,
-             file_size=multipart_info.file_size,
-             abort_event=abort_event,
-         )
-         progress_bar.update(
-             multi_part_upload_progress,
-             advance=multipart_info.part_size,
-             total=multipart_info.file_size,
-         )
-         logger.debug(
-             "Uploaded part %d/%d of %s",
-             part_number,
-             multipart_info.num_parts,
-             local_file,
-         )
-
-     futures: List[Future] = []
-     for part_number, seek in enumerate(
-         range(0, multipart_info.file_size, multipart_info.part_size)
-     ):
-         future = executor.submit(upload, part_number=part_number, seek=seek)
-         futures.append(future)
-
-     done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
-     if len(not_done) > 0:
-         abort_event.set()
-     for future in not_done:
-         future.cancel()
-     for future in done:
-         if future.exception() is not None:
-             raise future.exception()
-
-     logger.debug("Finalizing multipart upload of %s", local_file)
-     if multipart_upload.azure_blob_block_ids:
-         response = requests.put(
-             multipart_upload.finalize_signed_url.signed_url,
-             data=_get_azure_blob_completion_body(
-                 block_ids=multipart_upload.azure_blob_block_ids
-             ),
-             timeout=2 * 60,
-         )
-         response.raise_for_status()
-     logger.debug("Multipart upload of %s completed", local_file)
-
-
  def _any_future_has_failed(futures) -> bool:
      return any(
          future.done() and not future.cancelled() and future.exception() is not None
@@ -638,25 +372,27 @@ class MlFoundryArtifactsRepository:
                  multipart_upload.storage_provider
                  is MultiPartUploadStorageProvider.S3_COMPATIBLE
              ):
-                 _s3_compatible_multipart_upload(
-                     multipart_upload=multipart_upload,
+                 s3_compatible_multipart_upload(
+                     multipart_upload=MultiPartUpload.parse_obj(multipart_upload.to_dict()),
                      local_file=local_file,
                      executor=executor,
                      multipart_info=multipart_info,
                      abort_event=abort_event,
                      progress_bar=progress_bar,
+                     exception_class=MlFoundryException,  # type: ignore
                  )
              elif (
                  multipart_upload.storage_provider
                  is MultiPartUploadStorageProvider.AZURE_BLOB
              ):
-                 _azure_multi_part_upload(
-                     multipart_upload=multipart_upload,
+                 azure_multi_part_upload(
+                     multipart_upload=MultiPartUpload.parse_obj(multipart_upload.to_dict()),
                      local_file=local_file,
                      executor=executor,
                      multipart_info=multipart_info,
                      abort_event=abort_event,
                      progress_bar=progress_bar,
+                     exception_class=MlFoundryException,  # type: ignore
                  )
              else:
                  raise NotImplementedError()
@@ -793,7 +529,7 @@ class MlFoundryArtifactsRepository:
              )
              upload_path = artifact_path
              upload_path = upload_path.lstrip(posixpath.sep)
-             multipart_info = _decide_file_parts(local_file)
+             multipart_info = decide_file_parts(local_file)
              if multipart_info.num_parts == 1:
                  files_for_normal_upload.append((upload_path, local_file, multipart_info))
              else:
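Note the bridge in the dispatch above: the OpenAPI-generated `MultiPartUploadDto` is converted to the shared `MultiPartUpload` pydantic model via `parse_obj(dto.to_dict())` before reaching the provider-agnostic helpers, keeping `storage_provider_utils` decoupled from the autogen client types. A toy illustration of the pattern (the field names are made up, not the real schema):

```python
from pydantic import BaseModel  # truefoundry ships a pydantic_v1 shim for this


class MultiPartUpload(BaseModel):
    storage_provider: str
    upload_id: str


class MultiPartUploadDto:
    """Stand-in for an OpenAPI-generated DTO exposing a to_dict() method."""

    def __init__(self, storage_provider: str, upload_id: str) -> None:
        self.storage_provider = storage_provider
        self.upload_id = upload_id

    def to_dict(self) -> dict:
        return {"storage_provider": self.storage_provider, "upload_id": self.upload_id}


dto = MultiPartUploadDto(storage_provider="s3-compatible", upload_id="abc123")
# parse_obj (pydantic v1 API) validates the plain dict against the shared model.
upload = MultiPartUpload.parse_obj(dto.to_dict())
```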