truefoundry 0.5.3rc3__py3-none-any.whl → 0.5.3rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of truefoundry might be problematic.

Files changed (43)
  1. truefoundry/__init__.py +10 -1
  2. truefoundry/autodeploy/cli.py +2 -2
  3. truefoundry/cli/__main__.py +0 -4
  4. truefoundry/cli/util.py +12 -3
  5. truefoundry/common/auth_service_client.py +7 -4
  6. truefoundry/common/constants.py +3 -0
  7. truefoundry/common/credential_provider.py +7 -8
  8. truefoundry/common/exceptions.py +11 -7
  9. truefoundry/common/request_utils.py +96 -14
  10. truefoundry/common/servicefoundry_client.py +31 -29
  11. truefoundry/common/session.py +93 -0
  12. truefoundry/common/storage_provider_utils.py +331 -0
  13. truefoundry/common/utils.py +9 -9
  14. truefoundry/common/warnings.py +21 -0
  15. truefoundry/deploy/builder/builders/tfy_python_buildpack/dockerfile_template.py +8 -20
  16. truefoundry/deploy/cli/commands/deploy_command.py +4 -4
  17. truefoundry/deploy/lib/clients/servicefoundry_client.py +13 -14
  18. truefoundry/deploy/lib/dao/application.py +2 -2
  19. truefoundry/deploy/lib/dao/workspace.py +1 -1
  20. truefoundry/deploy/lib/session.py +1 -1
  21. truefoundry/deploy/v2/lib/deploy.py +2 -2
  22. truefoundry/deploy/v2/lib/deploy_workflow.py +1 -1
  23. truefoundry/deploy/v2/lib/patched_models.py +70 -4
  24. truefoundry/deploy/v2/lib/source.py +2 -1
  25. truefoundry/ml/artifact/truefoundry_artifact_repo.py +33 -297
  26. truefoundry/ml/clients/servicefoundry_client.py +36 -15
  27. truefoundry/ml/exceptions.py +2 -1
  28. truefoundry/ml/log_types/artifacts/artifact.py +3 -2
  29. truefoundry/ml/log_types/artifacts/model.py +6 -5
  30. truefoundry/ml/log_types/artifacts/utils.py +2 -2
  31. truefoundry/ml/mlfoundry_api.py +6 -38
  32. truefoundry/ml/mlfoundry_run.py +6 -15
  33. truefoundry/ml/model_framework.py +2 -1
  34. truefoundry/ml/session.py +69 -97
  35. truefoundry/workflow/remote_filesystem/tfy_signed_url_client.py +42 -9
  36. truefoundry/workflow/remote_filesystem/tfy_signed_url_fs.py +126 -7
  37. {truefoundry-0.5.3rc3.dist-info → truefoundry-0.5.3rc5.dist-info}/METADATA +1 -1
  38. {truefoundry-0.5.3rc3.dist-info → truefoundry-0.5.3rc5.dist-info}/RECORD +40 -40
  39. truefoundry/deploy/lib/auth/servicefoundry_session.py +0 -61
  40. truefoundry/ml/clients/entities.py +0 -8
  41. truefoundry/ml/clients/utils.py +0 -122
  42. {truefoundry-0.5.3rc3.dist-info → truefoundry-0.5.3rc5.dist-info}/WHEEL +0 -0
  43. {truefoundry-0.5.3rc3.dist-info → truefoundry-0.5.3rc5.dist-info}/entry_points.txt +0 -0
truefoundry/deploy/v2/lib/patched_models.py

@@ -1,9 +1,11 @@
  import enum
+ import os
  import re
  import warnings
  from collections.abc import Mapping
  from typing import Literal, Optional, Union

+ from truefoundry.common.warnings import TrueFoundryDeprecationWarning
  from truefoundry.deploy.auto_gen import models
  from truefoundry.pydantic_v1 import (
      BaseModel,
@@ -71,6 +73,66 @@ resources:
  """


+ AUTO_DISCOVERED_REQUIREMENTS_TXT_WARNING_MESSAGE_TEMPLATE = """\
+ Using automatically discovered {requirements_txt_path} as the requirements file.
+ This auto discovery behavior is deprecated and will be removed in a future release.
+ Please specify the relative path of requirements file explicitly as
+
+ ```python
+ build=Build(
+     ...
+     build_spec=PythonBuild(
+         ...
+         requirements_path={requirements_txt_path!r}
+     )
+ )
+ ```
+
+ OR
+
+ ```yaml
+ build:
+   type: build
+   build_spec:
+     type: tfy-python-buildpack
+     requirements_path: {requirements_txt_path!r}
+     ...
+   ...
+ ```
+
+ or set it to None if you don't want use any requirements file.
+ """
+
+
+ def _resolve_requirements_path(
+     build_context_path: str,
+     requirements_path: Optional[str],
+ ) -> Optional[str]:
+     if requirements_path:
+         return requirements_path
+
+     # TODO: Deprecated behavior, phase out auto discovery in future release
+     possible_requirements_txt_filename = "requirements.txt"
+     possible_requirements_txt_path = os.path.join(
+         build_context_path, possible_requirements_txt_filename
+     )
+
+     if os.path.isfile(possible_requirements_txt_path):
+         requirements_txt_path = os.path.relpath(
+             possible_requirements_txt_path, start=build_context_path
+         )
+         warnings.warn(
+             AUTO_DISCOVERED_REQUIREMENTS_TXT_WARNING_MESSAGE_TEMPLATE.format(
+                 requirements_txt_path=requirements_txt_path
+             ),
+             category=TrueFoundryDeprecationWarning,
+             stacklevel=2,
+         )
+         return requirements_txt_path
+
+     return None
+
+
  class CUDAVersion(str, enum.Enum):
      CUDA_11_0_CUDNN8 = "11.0-cudnn8"
      CUDA_11_1_CUDNN8 = "11.1-cudnn8"
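
To make the new behavior concrete, here is a minimal sketch of how the resolution added above plays out (illustrative only: `_resolve_requirements_path` is a private helper, and the module path is inferred from the file list above):

```python
import os
import tempfile
import warnings

# Private helper introduced in this diff; imported here purely for illustration.
from truefoundry.deploy.v2.lib.patched_models import _resolve_requirements_path

with tempfile.TemporaryDirectory() as ctx:
    open(os.path.join(ctx, "requirements.txt"), "w").close()

    # An explicitly provided path is returned unchanged, with no warning.
    assert _resolve_requirements_path(ctx, "reqs/prod.txt") == "reqs/prod.txt"

    # An omitted path falls back to auto discovery and emits the deprecation warning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert _resolve_requirements_path(ctx, None) == "requirements.txt"
    assert any("deprecated" in str(w.message) for w in caught)
```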
@@ -139,7 +201,7 @@ class PythonBuild(models.PythonBuild, PatchedModelBase):
      type: Literal["tfy-python-buildpack"] = "tfy-python-buildpack"

      @root_validator
-     def validate_python_version_when_cuda_version(cls, values):
+     def validate_values(cls, values):
          if values.get("cuda_version"):
              python_version = values.get("python_version")
              if python_version and not re.match(r"^3\.\d+$", python_version):
@@ -148,6 +210,10 @@ class PythonBuild(models.PythonBuild, PatchedModelBase):
                      f"provided but got {python_version!r}. If you are adding a "
                      f'patch version, please remove it (e.g. "3.9.2" should be "3.9")'
                  )
+         _resolve_requirements_path(
+             build_context_path=values.get("build_context_path") or "./",
+             requirements_path=values.get("requirements_path"),
+         )
          return values


@@ -251,7 +317,7 @@ class Resources(models.Resources, PatchedModelBase):
                      gpu_type=gpu_type,
                      gpu_count=gpu_count,
                  ),
-                 category=FutureWarning,
+                 category=TrueFoundryDeprecationWarning,
                  stacklevel=2,
              )
          elif gpu_count:
@@ -259,7 +325,7 @@
                  LEGACY_GPU_COUNT_WARNING_MESSAGE_TEMPLATE.format(
                      gpu_count=gpu_count,
                  ),
-                 category=FutureWarning,
+                 category=TrueFoundryDeprecationWarning,
                  stacklevel=2,
              )
          return values
@@ -295,7 +361,7 @@ class Autoscaling(ServiceAutoscaling):
              "`truefoundry.deploy.Autoscaling` is deprecated and will be removed in a future version. "
              "Please use `truefoundry.deploy.ServiceAutoscaling` instead. "
              "You can rename `Autoscaling` to `ServiceAutoscaling` in your script.",
-             category=DeprecationWarning,
+             category=TrueFoundryDeprecationWarning,
              stacklevel=2,
          )
          super().__init__(**kwargs)
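
Several warnings in this release (here and below) switch category from `FutureWarning`/`DeprecationWarning` to the SDK-specific `TrueFoundryDeprecationWarning` from the new `truefoundry/common/warnings.py` module, so downstream filters keyed on the old categories may need updating. A minimal sketch, assuming only the import path shown in the diff:

```python
import warnings

from truefoundry.common.warnings import TrueFoundryDeprecationWarning

# Example: hide only the SDK's own deprecation notices (e.g. in CI logs)
# without touching other warning categories.
warnings.filterwarnings("ignore", category=TrueFoundryDeprecationWarning)
```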
truefoundry/deploy/v2/lib/source.py

@@ -8,6 +8,7 @@ from typing import Callable, List, Optional
  import gitignorefile
  from tqdm import tqdm

+ from truefoundry.common.warnings import TrueFoundryDeprecationWarning
  from truefoundry.deploy import builder
  from truefoundry.deploy.auto_gen import models
  from truefoundry.deploy.builder.docker_service import (
@@ -102,7 +103,7 @@ def _get_callback_handler_to_ignore_file_path(
          warnings.warn(
              "`.sfyignore` is deprecated and will be ignored in future versions. "
              "Please rename the file to `.tfyignore`",
-             category=DeprecationWarning,
+             category=TrueFoundryDeprecationWarning,
              stacklevel=2,
          )
          return gitignorefile.parse(path=ignorefile_path, base_path=source_dir)
truefoundry/ml/artifact/truefoundry_artifact_repo.py

@@ -1,5 +1,3 @@
- import math
- import mmap
  import os
  import posixpath
  import sys
@@ -14,7 +12,6 @@ from typing import (
      Dict,
      Iterator,
      List,
-     NamedTuple,
      Optional,
      Sequence,
      Tuple,
@@ -23,7 +20,6 @@ from typing import (
  from urllib.parse import unquote
  from urllib.request import pathname2url

- import requests
  from rich.console import _is_jupyter
  from rich.progress import (
      BarColumn,
@@ -36,6 +32,17 @@ from rich.progress import (
  from tqdm.utils import CallbackIOWrapper

  from truefoundry.common.constants import ENV_VARS
+ from truefoundry.common.request_utils import (
+     augmented_raise_for_status,
+     cloud_storage_http_request,
+ )
+ from truefoundry.common.storage_provider_utils import (
+     MultiPartUpload,
+     _FileMultiPartInfo,
+     azure_multi_part_upload,
+     decide_file_parts,
+     s3_compatible_multipart_upload,
+ )
  from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
      ApiClient,
      CreateMultiPartUploadForDatasetRequestDto,
@@ -56,34 +63,11 @@ from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
      RunArtifactsApi,
      SignedURLDto,
  )
- from truefoundry.ml.clients.utils import (
-     augmented_raise_for_status,
-     cloud_storage_http_request,
- )
  from truefoundry.ml.exceptions import MlFoundryException
  from truefoundry.ml.logger import logger
  from truefoundry.ml.session import _get_api_client
  from truefoundry.pydantic_v1 import BaseModel, root_validator

- _MIN_BYTES_REQUIRED_FOR_MULTIPART = 100 * 1024 * 1024
- # GCP/S3 Maximum number of parts per upload 10,000
- # Maximum number of blocks in a block blob 50,000 blocks
- # TODO: This number is artificially limited now. Later
- # we will ask for parts signed URI in batches rather than in a single
- # API Calls:
- # Create Multipart Upload (Returns maximum number of parts, size limit of
- # a single part, upload id for s3 etc )
- # Get me signed uris for first 500 parts
- # Upload 500 parts
- # Get me signed uris for the next 500 parts
- # Upload 500 parts
- # ...
- # Finalize the Multipart upload using the finalize signed url returned
- # by Create Multipart Upload or get a new one.
- _MAX_NUM_PARTS_FOR_MULTIPART = 1000
- # Azure Maximum size of a block in a block blob 4000 MiB
- # GCP/S3 Maximum size of an individual part in a multipart upload 5 GiB
- _MAX_PART_SIZE_BYTES_FOR_MULTIPART = 4 * 1024 * 1024 * 1000
  _LIST_FILES_PAGE_SIZE = 500
  _GENERATE_SIGNED_URL_BATCH_SIZE = 50
  DEFAULT_PRESIGNED_URL_EXPIRY_TIME = 3600
@@ -123,21 +107,6 @@ def relative_path_to_artifact_path(path):
      return unquote(pathname2url(path))


- def _align_part_size_with_mmap_allocation_granularity(part_size: int) -> int:
-     modulo = part_size % mmap.ALLOCATIONGRANULARITY
-     if modulo == 0:
-         return part_size
-
-     part_size += mmap.ALLOCATIONGRANULARITY - modulo
-     return part_size
-
-
- # Can not be less than 5 * 1024 * 1024
- _PART_SIZE_BYTES_FOR_MULTIPART = _align_part_size_with_mmap_allocation_granularity(
-     10 * 1024 * 1024
- )
-
-
  def bad_path_message(name):
      return (
          "Names may be treated as files in certain cases, and must not resolve to other names"
@@ -157,68 +126,6 @@ def verify_artifact_path(artifact_path):
      )


- class _PartNumberEtag(NamedTuple):
-     part_number: int
-     etag: str
-
-
- def _get_s3_compatible_completion_body(multi_parts: List[_PartNumberEtag]) -> str:
-     body = "<CompleteMultipartUpload>\n"
-     for part in multi_parts:
-         body += " <Part>\n"
-         body += f" <PartNumber>{part.part_number}</PartNumber>\n"
-         body += f" <ETag>{part.etag}</ETag>\n"
-         body += " </Part>\n"
-     body += "</CompleteMultipartUpload>"
-     return body
-
-
- def _get_azure_blob_completion_body(block_ids: List[str]) -> str:
-     body = "<BlockList>\n"
-     for block_id in block_ids:
-         body += f"<Uncommitted>{block_id}</Uncommitted> "
-     body += "</BlockList>"
-     return body
-
-
- class _FileMultiPartInfo(NamedTuple):
-     num_parts: int
-     part_size: int
-     file_size: int
-
-
- def _decide_file_parts(file_path: str) -> _FileMultiPartInfo:
-     file_size = os.path.getsize(file_path)
-     if (
-         file_size < _MIN_BYTES_REQUIRED_FOR_MULTIPART
-         or ENV_VARS.TFY_ARTIFACTS_DISABLE_MULTIPART_UPLOAD
-     ):
-         return _FileMultiPartInfo(1, part_size=file_size, file_size=file_size)
-
-     ideal_num_parts = math.ceil(file_size / _PART_SIZE_BYTES_FOR_MULTIPART)
-     if ideal_num_parts <= _MAX_NUM_PARTS_FOR_MULTIPART:
-         return _FileMultiPartInfo(
-             ideal_num_parts,
-             part_size=_PART_SIZE_BYTES_FOR_MULTIPART,
-             file_size=file_size,
-         )
-
-     part_size_when_using_max_parts = math.ceil(file_size / _MAX_NUM_PARTS_FOR_MULTIPART)
-     part_size_when_using_max_parts = _align_part_size_with_mmap_allocation_granularity(
-         part_size_when_using_max_parts
-     )
-     if part_size_when_using_max_parts > _MAX_PART_SIZE_BYTES_FOR_MULTIPART:
-         raise ValueError(
-             f"file {file_path!r} is too big for upload. Multipart chunk"
-             f" size {part_size_when_using_max_parts} is higher"
-             f" than {_MAX_PART_SIZE_BYTES_FOR_MULTIPART}"
-         )
-     num_parts = math.ceil(file_size / part_size_when_using_max_parts)
-     return _FileMultiPartInfo(
-         num_parts, part_size=part_size_when_using_max_parts, file_size=file_size
-     )
-
-
  def _signed_url_upload_file(
      signed_url: SignedURLDto,
      local_file: str,
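
The part-sizing logic removed above moves to `truefoundry/common/storage_provider_utils.py` (imported earlier in this file as `decide_file_parts`). For reference, the rule the deleted code implemented can be sketched as follows (constants mirror the deleted module-level values; the mmap alignment step is omitted for brevity):

```python
import math

_MIN_BYTES_REQUIRED_FOR_MULTIPART = 100 * 1024 * 1024  # below this, a single PUT is used
_PART_SIZE_BYTES_FOR_MULTIPART = 10 * 1024 * 1024      # default part size
_MAX_NUM_PARTS_FOR_MULTIPART = 1000                    # artificially limited, per the removed TODO

def sketch_num_parts(file_size: int) -> int:
    """Approximate the part count the removed _decide_file_parts computed."""
    if file_size < _MIN_BYTES_REQUIRED_FOR_MULTIPART:
        return 1
    ideal = math.ceil(file_size / _PART_SIZE_BYTES_FOR_MULTIPART)
    if ideal <= _MAX_NUM_PARTS_FOR_MULTIPART:
        return ideal
    # Beyond 1000 parts, the part size grows instead of the part count.
    return math.ceil(file_size / math.ceil(file_size / _MAX_NUM_PARTS_FOR_MULTIPART))

assert sketch_num_parts(10 * 1024 * 1024) == 1      # 10 MiB: single upload
assert sketch_num_parts(500 * 1024 * 1024) == 50    # 500 MiB: 50 x 10 MiB parts
assert sketch_num_parts(20 * 1024**3) == 1000       # 20 GiB: capped at 1000 parts
```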
@@ -227,9 +134,12 @@
  ):
      if os.stat(local_file).st_size == 0:
          with cloud_storage_http_request(
-             method="put", url=signed_url.signed_url, data=""
+             method="put",
+             url=signed_url.signed_url,
+             data="",
+             exception_class=MlFoundryException, # type: ignore
          ) as response:
-             augmented_raise_for_status(response)
+             augmented_raise_for_status(response, exception_class=MlFoundryException) # type: ignore
          return

      task_progress_bar = progress_bar.add_task(
@@ -247,9 +157,12 @@
          # NOTE: Azure Put Blob does not support Transfer Encoding header.
          wrapped_file = CallbackIOWrapper(callback, file, "read")
          with cloud_storage_http_request(
-             method="put", url=signed_url.signed_url, data=wrapped_file
+             method="put",
+             url=signed_url.signed_url,
+             data=wrapped_file,
+             exception_class=MlFoundryException, # type: ignore
          ) as response:
-             augmented_raise_for_status(response)
+             augmented_raise_for_status(response, exception_class=MlFoundryException) # type: ignore


  def _download_file_using_http_uri(
@@ -265,9 +178,12 @@
      providers.
      """
      with cloud_storage_http_request(
-         method="get", url=http_uri, stream=True
+         method="get",
+         url=http_uri,
+         stream=True,
+         exception_class=MlFoundryException, # type: ignore
      ) as response:
-         augmented_raise_for_status(response)
+         augmented_raise_for_status(response, exception_class=MlFoundryException) # type: ignore
          file_size = int(response.headers.get("Content-Length", 0))
          with open(download_path, "wb") as output_file:
              for chunk in response.iter_content(chunk_size=chunk_size):
@@ -281,188 +197,6 @@
              os.fsync(output_file.fileno())


- class _CallbackIOWrapperForMultiPartUpload(CallbackIOWrapper):
-     def __init__(self, callback, stream, method, length: int):
-         self.wrapper_setattr("_length", length)
-         super().__init__(callback, stream, method)
-
-     def __len__(self):
-         return self.wrapper_getattr("_length")
-
-
- def _file_part_upload(
-     url: str,
-     file_path: str,
-     seek: int,
-     length: int,
-     file_size: int,
-     abort_event: Optional[Event] = None,
-     method: str = "put",
- ):
-     def callback(*_, **__):
-         if abort_event and abort_event.is_set():
-             raise Exception("aborting upload")
-
-     with open(file_path, "rb") as file:
-         with mmap.mmap(
-             file.fileno(),
-             length=min(file_size - seek, length),
-             offset=seek,
-             access=mmap.ACCESS_READ,
-         ) as mapped_file:
-             wrapped_file = _CallbackIOWrapperForMultiPartUpload(
-                 callback, mapped_file, "read", len(mapped_file)
-             )
-             with cloud_storage_http_request(
-                 method=method,
-                 url=url,
-                 data=wrapped_file,
-             ) as response:
-                 augmented_raise_for_status(response)
-                 return response
-
-
- def _s3_compatible_multipart_upload(
-     multipart_upload: MultiPartUploadDto,
-     local_file: str,
-     multipart_info: _FileMultiPartInfo,
-     executor: ThreadPoolExecutor,
-     progress_bar: Progress,
-     abort_event: Optional[Event] = None,
- ):
-     abort_event = abort_event or Event()
-     parts = []
-
-     multi_part_upload_progress = progress_bar.add_task(
-         f"[green]Uploading {local_file}:", start=True
-     )
-
-     def upload(part_number: int, seek: int) -> None:
-         logger.debug(
-             "Uploading part %d/%d of %s",
-             part_number,
-             multipart_info.num_parts,
-             local_file,
-         )
-         response = _file_part_upload(
-             url=multipart_upload.part_signed_urls[part_number].signed_url,
-             file_path=local_file,
-             seek=seek,
-             length=multipart_info.part_size,
-             file_size=multipart_info.file_size,
-             abort_event=abort_event,
-         )
-         logger.debug(
-             "Uploaded part %d/%d of %s",
-             part_number,
-             multipart_info.num_parts,
-             local_file,
-         )
-         progress_bar.update(
-             multi_part_upload_progress,
-             advance=multipart_info.part_size,
-             total=multipart_info.file_size,
-         )
-         etag = response.headers["ETag"]
-         parts.append(_PartNumberEtag(etag=etag, part_number=part_number + 1))
-
-     futures: List[Future] = []
-     for part_number, seek in enumerate(
-         range(0, multipart_info.file_size, multipart_info.part_size)
-     ):
-         future = executor.submit(upload, part_number=part_number, seek=seek)
-         futures.append(future)
-
-     done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
-     if len(not_done) > 0:
-         abort_event.set()
-         for future in not_done:
-             future.cancel()
-     for future in done:
-         if future.exception() is not None:
-             raise future.exception()
-
-     logger.debug("Finalizing multipart upload of %s", local_file)
-     parts = sorted(parts, key=lambda part: part.part_number)
-     response = requests.post(
-         multipart_upload.finalize_signed_url.signed_url,
-         data=_get_s3_compatible_completion_body(parts),
-         timeout=2 * 60,
-     )
-     response.raise_for_status()
-     logger.debug("Multipart upload of %s completed", local_file)
-
-
- def _azure_multi_part_upload(
-     multipart_upload: MultiPartUploadDto,
-     local_file: str,
-     multipart_info: _FileMultiPartInfo,
-     executor: ThreadPoolExecutor,
-     progress_bar: Progress,
-     abort_event: Optional[Event] = None,
- ):
-     abort_event = abort_event or Event()
-
-     multi_part_upload_progress = progress_bar.add_task(
-         f"[green]Uploading {local_file}:", start=True
-     )
-
-     def upload(part_number: int, seek: int):
-         logger.debug(
-             "Uploading part %d/%d of %s",
-             part_number,
-             multipart_info.num_parts,
-             local_file,
-         )
-         _file_part_upload(
-             url=multipart_upload.part_signed_urls[part_number].signed_url,
-             file_path=local_file,
-             seek=seek,
-             length=multipart_info.part_size,
-             file_size=multipart_info.file_size,
-             abort_event=abort_event,
-         )
-         progress_bar.update(
-             multi_part_upload_progress,
-             advance=multipart_info.part_size,
-             total=multipart_info.file_size,
-         )
-         logger.debug(
-             "Uploaded part %d/%d of %s",
-             part_number,
-             multipart_info.num_parts,
-             local_file,
-         )
-
-     futures: List[Future] = []
-     for part_number, seek in enumerate(
-         range(0, multipart_info.file_size, multipart_info.part_size)
-     ):
-         future = executor.submit(upload, part_number=part_number, seek=seek)
-         futures.append(future)
-
-     done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
-     if len(not_done) > 0:
-         abort_event.set()
-         for future in not_done:
-             future.cancel()
-     for future in done:
-         if future.exception() is not None:
-             raise future.exception()
-
-     logger.debug("Finalizing multipart upload of %s", local_file)
-     if multipart_upload.azure_blob_block_ids:
-         response = requests.put(
-             multipart_upload.finalize_signed_url.signed_url,
-             data=_get_azure_blob_completion_body(
-                 block_ids=multipart_upload.azure_blob_block_ids
-             ),
-             timeout=2 * 60,
-         )
-         response.raise_for_status()
-     logger.debug("Multipart upload of %s completed", local_file)
-
-
  def _any_future_has_failed(futures) -> bool:
      return any(
          future.done() and not future.cancelled() and future.exception() is not None
@@ -638,25 +372,27 @@ class MlFoundryArtifactsRepository:
              multipart_upload.storage_provider
              is MultiPartUploadStorageProvider.S3_COMPATIBLE
          ):
-             _s3_compatible_multipart_upload(
-                 multipart_upload=multipart_upload,
+             s3_compatible_multipart_upload(
+                 multipart_upload=MultiPartUpload.parse_obj(multipart_upload.to_dict()),
                  local_file=local_file,
                  executor=executor,
                  multipart_info=multipart_info,
                  abort_event=abort_event,
                  progress_bar=progress_bar,
+                 exception_class=MlFoundryException, # type: ignore
              )
          elif (
              multipart_upload.storage_provider
              is MultiPartUploadStorageProvider.AZURE_BLOB
          ):
-             _azure_multi_part_upload(
-                 multipart_upload=multipart_upload,
+             azure_multi_part_upload(
+                 multipart_upload=MultiPartUpload.parse_obj(multipart_upload.to_dict()),
                  local_file=local_file,
                  executor=executor,
                  multipart_info=multipart_info,
                  abort_event=abort_event,
                  progress_bar=progress_bar,
+                 exception_class=MlFoundryException, # type: ignore
              )
          else:
              raise NotImplementedError()
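
Note the handoff above: the autogen client's `MultiPartUploadDto` is re-parsed into the `MultiPartUpload` model from `truefoundry.common.storage_provider_utils`, which keeps the shared upload helpers decoupled from the ml autogen package. A hedged sketch of that conversion in isolation:

```python
from truefoundry.common.storage_provider_utils import MultiPartUpload

def to_common_model(multipart_upload_dto) -> MultiPartUpload:
    # .to_dict() comes from the generated client models; parse_obj is the
    # pydantic v1 constructor (the SDK pins v1 via truefoundry.pydantic_v1).
    return MultiPartUpload.parse_obj(multipart_upload_dto.to_dict())
```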
@@ -793,7 +529,7 @@
          )
          upload_path = artifact_path
          upload_path = upload_path.lstrip(posixpath.sep)
-         multipart_info = _decide_file_parts(local_file)
+         multipart_info = decide_file_parts(local_file)
          if multipart_info.num_parts == 1:
              files_for_normal_upload.append((upload_path, local_file, multipart_info))
          else:
truefoundry/ml/clients/servicefoundry_client.py

@@ -1,36 +1,57 @@
- from typing import Optional
+ import functools

  from truefoundry.common.constants import (
      SERVICEFOUNDRY_CLIENT_MAX_RETRIES,
      VERSION_PREFIX,
  )
+ from truefoundry.common.exceptions import HttpRequestException
+ from truefoundry.common.request_utils import (
+     http_request,
+     request_handling,
+     requests_retry_session,
+ )
  from truefoundry.common.servicefoundry_client import (
      ServiceFoundryServiceClient as BaseServiceFoundryServiceClient,
  )
- from truefoundry.ml.clients.entities import (
-     HostCreds,
- )
- from truefoundry.ml.clients.utils import http_request_safe
  from truefoundry.ml.exceptions import MlFoundryException


  class ServiceFoundryServiceClient(BaseServiceFoundryServiceClient):
-     # TODO (chiragjn): Rename tracking_uri to tfy_host
-     def __init__(self, tracking_uri: str, token: Optional[str] = None):
-         super().__init__(base_url=tracking_uri)
-         self.host_creds = HostCreds(host=self._api_server_url, token=token)
+     def __init__(self, tfy_host: str, token: str):
+         super().__init__(tfy_host=tfy_host)
+         self._token = token
+
+     @functools.cached_property
+     def _min_cli_version_required(self) -> str:
+         # TODO (chiragjn): read the mlfoundry min cli version from the config?
+         return self.python_sdk_config.truefoundry_cli_min_version

      def get_integration_from_id(self, integration_id: str):
          integration_id = integration_id or ""
-         response = http_request_safe(
-             host_creds=self.host_creds,
-             endpoint=f"{VERSION_PREFIX}/provider-accounts/provider-integrations",
-             params={"id": integration_id, "type": "blob-storage"},
+         session = requests_retry_session(retries=SERVICEFOUNDRY_CLIENT_MAX_RETRIES)
+         response = http_request(
              method="get",
+             url=f"{self._api_server_url}/{VERSION_PREFIX}/provider-accounts/provider-integrations",
+             token=self._token,
              timeout=3,
-             max_retries=SERVICEFOUNDRY_CLIENT_MAX_RETRIES,
+             params={"id": integration_id, "type": "blob-storage"},
+             session=session,
          )
-         data = response.json()
+
+         try:
+             data = request_handling(response)
+             assert isinstance(data, dict)
+         except HttpRequestException as he:
+             raise MlFoundryException(
+                 f"Failed to get storage integration from id: {integration_id}. Error: {he.message}",
+                 status_code=he.status_code,
+             ) from None
+         except Exception as e:
+             raise MlFoundryException(
+                 f"Failed to get storage integration from id: {integration_id}. Error: {str(e)}"
+             ) from None
+
+         # TODO (chiragjn): Parse this using Pydantic
          if (
              data.get("providerIntegrations")
              and len(data["providerIntegrations"]) > 0
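
The rewrite above drops `http_request_safe`/`HostCreds` in favor of the shared request utilities. A minimal sketch of the same pattern in isolation (all names appear in this diff; exact signatures are assumed from their usage here):

```python
from truefoundry.common.constants import SERVICEFOUNDRY_CLIENT_MAX_RETRIES
from truefoundry.common.exceptions import HttpRequestException
from truefoundry.common.request_utils import (
    http_request,
    request_handling,
    requests_retry_session,
)

def fetch_json(url: str, token: str):
    # A session with retries baked in, reused for the request.
    session = requests_retry_session(retries=SERVICEFOUNDRY_CLIENT_MAX_RETRIES)
    response = http_request(
        method="get", url=url, token=token, timeout=3, session=session
    )
    try:
        # request_handling() returns the parsed body on success and raises
        # HttpRequestException on HTTP errors, mirroring the diff above.
        return request_handling(response)
    except HttpRequestException as he:
        # Callers re-wrap into their own domain exception, as the ml client does.
        raise RuntimeError(f"Request failed: {he.message}") from None
```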
truefoundry/ml/exceptions.py

@@ -1,8 +1,9 @@
  from typing import Optional


+ # TODO (chiragjn): We need to establish uniform exception handling across codebase
  class MlFoundryException(Exception):
-     def __init__(self, message, status_code: Optional[int] = None):
+     def __init__(self, message: str, status_code: Optional[int] = None):
          self.message = str(message)
          self.status_code = status_code
          super().__init__(message)
truefoundry/ml/log_types/artifacts/artifact.py

@@ -9,6 +9,7 @@ import warnings
  from pathlib import Path
  from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union

+ from truefoundry.common.warnings import TrueFoundryDeprecationWarning
  from truefoundry.ml.artifact.truefoundry_artifact_repo import (
      ArtifactIdentifier,
      MlFoundryArtifactsRepository,
@@ -217,7 +218,7 @@ class ArtifactVersion:
          if not self._artifact_version.manifest:
              warnings.warn(
                  message="This model version was created using an older serialization format. tags do not exist, returning empty list",
-                 category=DeprecationWarning,
+                 category=TrueFoundryDeprecationWarning,
                  stacklevel=2,
              )
              return self._tags
@@ -230,7 +231,7 @@
          if not self._artifact_version.manifest:
              warnings.warn(
                  message="This model version was created using an older serialization format. Tags will not be updated",
-                 category=DeprecationWarning,
+                 category=TrueFoundryDeprecationWarning,
                  stacklevel=2,
              )
              return