truefoundry 0.4.3__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truefoundry might be problematic. (The original page linked to more details here; that link was not preserved in this copy.)

Files changed (59)
  1. truefoundry/common/constants.py +36 -2
  2. truefoundry/common/credential_provider.py +4 -2
  3. truefoundry/common/request_utils.py +1 -1
  4. truefoundry/common/servicefoundry_client.py +4 -2
  5. truefoundry/common/tfy_signed_url_client.py +260 -0
  6. truefoundry/common/tfy_signed_url_fs.py +244 -0
  7. truefoundry/common/utils.py +18 -5
  8. truefoundry/deploy/auto_gen/models.py +39 -4
  9. truefoundry/deploy/builder/builders/tfy_notebook_buildpack/dockerfile_template.py +1 -1
  10. truefoundry/deploy/builder/builders/tfy_python_buildpack/dockerfile_template.py +2 -4
  11. truefoundry/deploy/lib/clients/servicefoundry_client.py +2 -1
  12. truefoundry/deploy/lib/model/entity.py +0 -4
  13. truefoundry/deploy/python_deploy_codegen.py +79 -7
  14. truefoundry/ml/artifact/truefoundry_artifact_repo.py +448 -424
  15. truefoundry/ml/autogen/client/__init__.py +24 -3
  16. truefoundry/ml/autogen/client/api/experiments_api.py +0 -137
  17. truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +2 -0
  18. truefoundry/ml/autogen/client/models/__init__.py +24 -3
  19. truefoundry/ml/autogen/client/models/artifact_dto.py +9 -0
  20. truefoundry/ml/autogen/client/models/artifact_version_dto.py +26 -0
  21. truefoundry/ml/autogen/client/models/artifact_version_serialization_format.py +34 -0
  22. truefoundry/ml/autogen/client/models/create_artifact_version_response_dto.py +8 -2
  23. truefoundry/ml/autogen/client/models/create_run_request_dto.py +1 -10
  24. truefoundry/ml/autogen/client/models/dataset_dto.py +9 -0
  25. truefoundry/ml/autogen/client/models/experiment_dto.py +14 -3
  26. truefoundry/ml/autogen/client/models/external_model_source.py +79 -0
  27. truefoundry/ml/autogen/client/models/finalize_artifact_version_request_dto.py +11 -0
  28. truefoundry/ml/autogen/client/models/framework.py +154 -0
  29. truefoundry/ml/autogen/client/models/library_name.py +35 -0
  30. truefoundry/ml/autogen/client/models/model_dto.py +9 -0
  31. truefoundry/ml/autogen/client/models/model_version_dto.py +26 -0
  32. truefoundry/ml/autogen/client/models/model_version_manifest.py +119 -0
  33. truefoundry/ml/autogen/client/models/run_info_dto.py +10 -1
  34. truefoundry/ml/autogen/client/models/source.py +177 -0
  35. truefoundry/ml/autogen/client/models/subject.py +79 -0
  36. truefoundry/ml/autogen/client/models/subject_type.py +34 -0
  37. truefoundry/ml/autogen/client/models/tensorflow_framework.py +74 -0
  38. truefoundry/ml/autogen/client/models/transformers_framework.py +90 -0
  39. truefoundry/ml/autogen/client/models/truefoundry_model_source.py +79 -0
  40. truefoundry/ml/autogen/client/models/update_model_version_request_dto.py +11 -0
  41. truefoundry/ml/autogen/client/models/upload_model_source.py +74 -0
  42. truefoundry/ml/autogen/client_README.md +12 -2
  43. truefoundry/ml/autogen/entities/artifacts.py +236 -4
  44. truefoundry/ml/log_types/artifacts/artifact.py +10 -11
  45. truefoundry/ml/log_types/artifacts/dataset.py +13 -10
  46. truefoundry/ml/log_types/artifacts/general_artifact.py +3 -1
  47. truefoundry/ml/log_types/artifacts/model.py +18 -35
  48. truefoundry/ml/log_types/artifacts/utils.py +42 -25
  49. truefoundry/ml/log_types/image/image.py +2 -0
  50. truefoundry/ml/log_types/plot.py +2 -0
  51. truefoundry/ml/mlfoundry_api.py +7 -3
  52. truefoundry/ml/session.py +3 -1
  53. truefoundry/workflow/__init__.py +10 -0
  54. {truefoundry-0.4.3.dist-info → truefoundry-0.4.4.dist-info}/METADATA +1 -1
  55. {truefoundry-0.4.3.dist-info → truefoundry-0.4.4.dist-info}/RECORD +57 -45
  56. truefoundry/ml/autogen/client/models/list_seed_experiments_response_dto.py +0 -81
  57. truefoundry/ml/env_vars.py +0 -9
  58. {truefoundry-0.4.3.dist-info → truefoundry-0.4.4.dist-info}/WHEEL +0 -0
  59. {truefoundry-0.4.3.dist-info → truefoundry-0.4.4.dist-info}/entry_points.txt +0 -0
@@ -16,6 +16,7 @@ from typing import (
16
16
  List,
17
17
  NamedTuple,
18
18
  Optional,
19
+ Sequence,
19
20
  Tuple,
20
21
  Union,
21
22
  )
@@ -34,6 +35,7 @@ from rich.progress import (
34
35
  )
35
36
  from tqdm.utils import CallbackIOWrapper
36
37
 
38
+ from truefoundry.common.constants import ENV_VARS
37
39
  from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
38
40
  ApiClient,
39
41
  CreateMultiPartUploadForDatasetRequestDto,
@@ -58,14 +60,12 @@ from truefoundry.ml.clients.utils import (
58
60
  augmented_raise_for_status,
59
61
  cloud_storage_http_request,
60
62
  )
61
- from truefoundry.ml.env_vars import DISABLE_MULTIPART_UPLOAD
62
63
  from truefoundry.ml.exceptions import MlFoundryException
63
64
  from truefoundry.ml.logger import logger
64
65
  from truefoundry.ml.session import _get_api_client
65
66
  from truefoundry.pydantic_v1 import BaseModel, root_validator
66
67
 
67
68
  _MIN_BYTES_REQUIRED_FOR_MULTIPART = 100 * 1024 * 1024
68
- _MULTIPART_DISABLED = os.getenv(DISABLE_MULTIPART_UPLOAD, "").lower() == "true"
69
69
  # GCP/S3 Maximum number of parts per upload 10,000
70
70
  # Maximum number of blocks in a block blob 50,000 blocks
71
71
  # TODO: This number is artificially limited now. Later
@@ -84,21 +84,11 @@ _MAX_NUM_PARTS_FOR_MULTIPART = 1000
84
84
  # Azure Maximum size of a block in a block blob 4000 MiB
85
85
  # GCP/S3 Maximum size of an individual part in a multipart upload 5 GiB
86
86
  _MAX_PART_SIZE_BYTES_FOR_MULTIPART = 4 * 1024 * 1024 * 1000
87
- _cpu_count = os.cpu_count() or 2
88
- _MAX_WORKERS_FOR_UPLOAD = max(min(32, _cpu_count * 2), 4)
89
- _MAX_WORKERS_FOR_DOWNLOAD = max(min(32, _cpu_count * 2), 4)
90
87
  _LIST_FILES_PAGE_SIZE = 500
91
88
  _GENERATE_SIGNED_URL_BATCH_SIZE = 50
92
89
  DEFAULT_PRESIGNED_URL_EXPIRY_TIME = 3600
93
90
 
94
91
 
95
- def _get_relpath_if_in_tempdir(path: str) -> str:
96
- tempdir = tempfile.gettempdir()
97
- if path.startswith(tempdir):
98
- return os.path.relpath(path, tempdir)
99
- return path
100
-
101
-
102
92
  def _can_display_progress(user_choice: Optional[bool] = None) -> bool:
103
93
  if user_choice is False:
104
94
  return False
@@ -199,7 +189,10 @@ class _FileMultiPartInfo(NamedTuple):
199
189
 
200
190
  def _decide_file_parts(file_path: str) -> _FileMultiPartInfo:
201
191
  file_size = os.path.getsize(file_path)
202
- if file_size < _MIN_BYTES_REQUIRED_FOR_MULTIPART or _MULTIPART_DISABLED:
192
+ if (
193
+ file_size < _MIN_BYTES_REQUIRED_FOR_MULTIPART
194
+ or ENV_VARS.TFY_ARTIFACTS_DISABLE_MULTIPART_UPLOAD
195
+ ):
203
196
  return _FileMultiPartInfo(1, part_size=file_size, file_size=file_size)
204
197
 
205
198
  ideal_num_parts = math.ceil(file_size / _PART_SIZE_BYTES_FOR_MULTIPART)
@@ -240,7 +233,7 @@ def _signed_url_upload_file(
240
233
  return
241
234
 
242
235
  task_progress_bar = progress_bar.add_task(
243
- f"[green]Uploading {_get_relpath_if_in_tempdir(local_file)}:", start=True
236
+ f"[green]Uploading {local_file}:", start=True
244
237
  )
245
238
 
246
239
  def callback(length):
@@ -262,7 +255,7 @@ def _signed_url_upload_file(
262
255
  def _download_file_using_http_uri(
263
256
  http_uri,
264
257
  download_path,
265
- chunk_size=100000000,
258
+ chunk_size=ENV_VARS.TFY_ARTIFACTS_DOWNLOAD_CHUNK_SIZE_BYTES,
266
259
  callback: Optional[Callable[[int, int], Any]] = None,
267
260
  ):
268
261
  """
@@ -283,6 +276,9 @@ def _download_file_using_http_uri(
283
276
  if not chunk:
284
277
  break
285
278
  output_file.write(chunk)
279
+ if ENV_VARS.TFY_ARTIFACTS_DOWNLOAD_FSYNC_CHUNKS:
280
+ output_file.flush()
281
+ os.fsync(output_file.fileno())
286
282
 
287
283
 
288
284
  class _CallbackIOWrapperForMultiPartUpload(CallbackIOWrapper):
@@ -338,7 +334,7 @@ def _s3_compatible_multipart_upload(
338
334
  parts = []
339
335
 
340
336
  multi_part_upload_progress = progress_bar.add_task(
341
- f"[green]Uploading {_get_relpath_if_in_tempdir(local_file)}:", start=True
337
+ f"[green]Uploading {local_file}:", start=True
342
338
  )
343
339
 
344
340
  def upload(part_number: int, seek: int) -> None:
@@ -408,7 +404,7 @@ def _azure_multi_part_upload(
408
404
  abort_event = abort_event or Event()
409
405
 
410
406
  multi_part_upload_progress = progress_bar.add_task(
411
- f"[green]Uploading {_get_relpath_if_in_tempdir(local_file)}:", start=True
407
+ f"[green]Uploading {local_file}:", start=True
412
408
  )
413
409
 
414
410
  def upload(part_number: int, seek: int):
@@ -508,164 +504,56 @@ class MlFoundryArtifactsRepository:
508
504
  api_client=self._api_client
509
505
  )
510
506
 
511
- def _create_download_destination(
512
- self, src_artifact_path, dst_local_dir_path=None
513
- ) -> str:
514
- """
515
- Creates a local filesystem location to be used as a destination for downloading the artifact
516
- specified by `src_artifact_path`. The destination location is a subdirectory of the
517
- specified `dst_local_dir_path`, which is determined according to the structure of
518
- `src_artifact_path`. For example, if `src_artifact_path` is `dir1/file1.txt`, then the
519
- resulting destination path is `<dst_local_dir_path>/dir1/file1.txt`. Local directories are
520
- created for the resulting destination location if they do not exist.
521
-
522
- :param src_artifact_path: A relative, POSIX-style path referring to an artifact stored
523
- within the repository's artifact root location.
524
- `src_artifact_path` should be specified relative to the
525
- repository's artifact root location.
526
- :param dst_local_dir_path: The absolute path to a local filesystem directory in which the
527
- local destination path will be contained. The local destination
528
- path may be contained in a subdirectory of `dst_root_dir` if
529
- `src_artifact_path` contains subdirectories.
530
- :return: The absolute path to a local filesystem location to be used as a destination
531
- for downloading the artifact specified by `src_artifact_path`.
532
- """
533
- src_artifact_path = src_artifact_path.rstrip(
534
- "/"
535
- ) # Ensure correct dirname for trailing '/'
536
- dirpath = posixpath.dirname(src_artifact_path)
537
- local_dir_path = os.path.join(dst_local_dir_path, dirpath)
538
- local_file_path = os.path.join(dst_local_dir_path, src_artifact_path)
539
- if not os.path.exists(local_dir_path):
540
- os.makedirs(local_dir_path, exist_ok=True)
541
- return local_file_path
542
-
543
- # these methods should be named list_files, log_directory, log_file, etc
544
- def list_artifacts(
545
- self, path=None, page_size=_LIST_FILES_PAGE_SIZE, **kwargs
546
- ) -> Iterator[FileInfoDto]:
547
- page_token = None
548
- started = False
549
- while not started or page_token is not None:
550
- started = True
551
- page = self.list_files(
552
- artifact_identifier=self.artifact_identifier,
553
- path=path,
554
- page_size=page_size,
555
- page_token=page_token,
507
+ # TODO (chiragjn): Refactor these methods - if else is very inconvenient
508
+ def get_signed_urls_for_read(
509
+ self,
510
+ artifact_identifier: ArtifactIdentifier,
511
+ paths,
512
+ ) -> List[SignedURLDto]:
513
+ if artifact_identifier.artifact_version_id:
514
+ signed_urls_response = self._mlfoundry_artifacts_api.get_signed_urls_for_read_post(
515
+ get_signed_urls_for_artifact_version_read_request_dto=GetSignedURLsForArtifactVersionReadRequestDto(
516
+ id=str(artifact_identifier.artifact_version_id), paths=paths
517
+ )
556
518
  )
557
- for file_info in page.files:
558
- yield file_info
559
- page_token = page.next_page_token
560
-
561
- def log_artifacts( # noqa: C901
562
- self, local_dir, artifact_path=None, progress=None
563
- ):
564
- show_progress = _can_display_progress(progress)
565
-
566
- dest_path = artifact_path or ""
567
- dest_path = dest_path.lstrip(posixpath.sep)
568
-
569
- files_for_normal_upload: List[Tuple[str, str, _FileMultiPartInfo]] = []
570
- files_for_multipart_upload: List[Tuple[str, str, _FileMultiPartInfo]] = []
571
-
572
- for root, _, file_names in os.walk(local_dir):
573
- upload_path = dest_path
574
- if root != local_dir:
575
- rel_path = os.path.relpath(root, local_dir)
576
- rel_path = relative_path_to_artifact_path(rel_path)
577
- upload_path = posixpath.join(dest_path, rel_path)
578
- for file_name in file_names:
579
- local_file = os.path.join(root, file_name)
580
- multipart_info = _decide_file_parts(local_file)
581
-
582
- final_upload_path = upload_path or ""
583
- final_upload_path = final_upload_path.lstrip(posixpath.sep)
584
- final_upload_path = posixpath.join(
585
- final_upload_path, os.path.basename(local_file)
519
+ signed_urls = signed_urls_response.signed_urls
520
+ elif artifact_identifier.dataset_fqn:
521
+ signed_urls_dataset_response = self._mlfoundry_artifacts_api.get_signed_urls_dataset_read_post(
522
+ get_signed_urls_for_dataset_read_request_dto=GetSignedURLsForDatasetReadRequestDto(
523
+ dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
586
524
  )
525
+ )
526
+ signed_urls = signed_urls_dataset_response.signed_urls
527
+ else:
528
+ raise ValueError(
529
+ "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
530
+ )
531
+ return signed_urls
587
532
 
588
- if multipart_info.num_parts == 1:
589
- files_for_normal_upload.append(
590
- (final_upload_path, local_file, multipart_info)
591
- )
592
- else:
593
- files_for_multipart_upload.append(
594
- (final_upload_path, local_file, multipart_info)
595
- )
596
-
597
- abort_event = Event()
598
-
599
- with Progress(
600
- "[progress.description]{task.description}",
601
- BarColumn(bar_width=None),
602
- "[progress.percentage]{task.percentage:>3.0f}%",
603
- DownloadColumn(),
604
- TransferSpeedColumn(),
605
- TimeRemainingColumn(),
606
- TimeElapsedColumn(),
607
- refresh_per_second=1,
608
- disable=not show_progress,
609
- expand=True,
610
- ) as progress_bar, ThreadPoolExecutor(
611
- max_workers=_MAX_WORKERS_FOR_UPLOAD
612
- ) as executor:
613
- futures: List[Future] = []
614
- # Note: While this batching is beneficial when there is a large number of files, there is also
615
- # a rare case risk of the signed url expiring before a request is made to it
616
- _batch_size = _GENERATE_SIGNED_URL_BATCH_SIZE
617
- for start_idx in range(0, len(files_for_normal_upload), _batch_size):
618
- end_idx = min(start_idx + _batch_size, len(files_for_normal_upload))
619
- if _any_future_has_failed(futures):
620
- break
621
- logger.debug("Generating write signed urls for a batch ...")
622
- remote_file_paths = [
623
- files_for_normal_upload[idx][0] for idx in range(start_idx, end_idx)
624
- ]
625
- signed_urls = self.get_signed_urls_for_write(
626
- artifact_identifier=self.artifact_identifier,
627
- paths=remote_file_paths,
533
+ def get_signed_urls_for_write(
534
+ self,
535
+ artifact_identifier: ArtifactIdentifier,
536
+ paths: List[str],
537
+ ) -> List[SignedURLDto]:
538
+ if artifact_identifier.artifact_version_id:
539
+ signed_urls_response = self._mlfoundry_artifacts_api.get_signed_urls_for_write_post(
540
+ get_signed_urls_for_artifact_version_write_request_dto=GetSignedURLsForArtifactVersionWriteRequestDto(
541
+ id=str(artifact_identifier.artifact_version_id), paths=paths
628
542
  )
629
- for idx, signed_url in zip(range(start_idx, end_idx), signed_urls):
630
- (
631
- upload_path,
632
- local_file,
633
- multipart_info,
634
- ) = files_for_normal_upload[idx]
635
- future = executor.submit(
636
- self._log_artifact,
637
- local_file=local_file,
638
- artifact_path=upload_path,
639
- multipart_info=multipart_info,
640
- signed_url=signed_url,
641
- abort_event=abort_event,
642
- executor_for_multipart_upload=None,
643
- progress_bar=progress_bar,
644
- )
645
- futures.append(future)
646
-
647
- done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
648
- if len(not_done) > 0:
649
- abort_event.set()
650
- for future in not_done:
651
- future.cancel()
652
- for future in done:
653
- if future.exception() is not None:
654
- raise future.exception()
655
-
656
- for (
657
- upload_path,
658
- local_file,
659
- multipart_info,
660
- ) in files_for_multipart_upload:
661
- self._log_artifact(
662
- local_file=local_file,
663
- artifact_path=upload_path,
664
- signed_url=None,
665
- multipart_info=multipart_info,
666
- executor_for_multipart_upload=executor,
667
- progress_bar=progress_bar,
543
+ )
544
+ signed_urls = signed_urls_response.signed_urls
545
+ elif artifact_identifier.dataset_fqn:
546
+ signed_urls_dataset_response = self._mlfoundry_artifacts_api.get_signed_urls_for_dataset_write_post(
547
+ get_signed_url_for_dataset_write_request_dto=GetSignedURLForDatasetWriteRequestDto(
548
+ dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
668
549
  )
550
+ )
551
+ signed_urls = signed_urls_dataset_response.signed_urls
552
+ else:
553
+ raise ValueError(
554
+ "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
555
+ )
556
+ return signed_urls
669
557
 
670
558
  def _normal_upload(
671
559
  self,
@@ -683,7 +571,7 @@ class MlFoundryArtifactsRepository:
683
571
  if progress_bar.disable:
684
572
  logger.info(
685
573
  "Uploading %s to %s",
686
- _get_relpath_if_in_tempdir(local_file),
574
+ local_file,
687
575
  artifact_path,
688
576
  )
689
577
 
@@ -695,6 +583,36 @@ class MlFoundryArtifactsRepository:
695
583
  )
696
584
  logger.debug("Uploaded %s to %s", local_file, artifact_path)
697
585
 
586
+ def _create_multipart_upload_for_identifier(
587
+ self,
588
+ artifact_identifier: ArtifactIdentifier,
589
+ path,
590
+ num_parts,
591
+ ) -> MultiPartUploadDto:
592
+ if artifact_identifier.artifact_version_id:
593
+ create_multipart_response: MultiPartUploadResponseDto = self._mlfoundry_artifacts_api.create_multi_part_upload_post(
594
+ create_multi_part_upload_request_dto=CreateMultiPartUploadRequestDto(
595
+ artifact_version_id=str(artifact_identifier.artifact_version_id),
596
+ path=path,
597
+ num_parts=num_parts,
598
+ )
599
+ )
600
+ multipart_upload = create_multipart_response.multipart_upload
601
+ elif artifact_identifier.dataset_fqn:
602
+ create_multipart_for_dataset_response = self._mlfoundry_artifacts_api.create_multipart_upload_for_dataset_post(
603
+ create_multi_part_upload_for_dataset_request_dto=CreateMultiPartUploadForDatasetRequestDto(
604
+ dataset_fqn=artifact_identifier.dataset_fqn,
605
+ path=path,
606
+ num_parts=num_parts,
607
+ )
608
+ )
609
+ multipart_upload = create_multipart_for_dataset_response.multipart_upload
610
+ else:
611
+ raise ValueError(
612
+ "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
613
+ )
614
+ return multipart_upload
615
+
698
616
  def _multipart_upload(
699
617
  self,
700
618
  local_file: str,
@@ -707,11 +625,11 @@ class MlFoundryArtifactsRepository:
707
625
  if progress_bar.disable:
708
626
  logger.info(
709
627
  "Uploading %s to %s using multipart upload",
710
- _get_relpath_if_in_tempdir(local_file),
628
+ local_file,
711
629
  artifact_path,
712
630
  )
713
631
 
714
- multipart_upload = self.create_multipart_upload_for_identifier(
632
+ multipart_upload = self._create_multipart_upload_for_identifier(
715
633
  artifact_identifier=self.artifact_identifier,
716
634
  path=artifact_path,
717
635
  num_parts=multipart_info.num_parts,
@@ -743,7 +661,7 @@ class MlFoundryArtifactsRepository:
743
661
  else:
744
662
  raise NotImplementedError()
745
663
 
746
- def _log_artifact(
664
+ def _upload_file(
747
665
  self,
748
666
  local_file: str,
749
667
  artifact_path: str,
@@ -763,7 +681,9 @@ class MlFoundryArtifactsRepository:
763
681
  )
764
682
 
765
683
  if not executor_for_multipart_upload:
766
- with ThreadPoolExecutor(max_workers=_MAX_WORKERS_FOR_UPLOAD) as executor:
684
+ with ThreadPoolExecutor(
685
+ max_workers=ENV_VARS.TFY_ARTIFACTS_UPLOAD_MAX_WORKERS
686
+ ) as executor:
767
687
  return self._multipart_upload(
768
688
  local_file=local_file,
769
689
  artifact_path=artifact_path,
@@ -780,10 +700,14 @@ class MlFoundryArtifactsRepository:
780
700
  progress_bar=progress_bar,
781
701
  )
782
702
 
783
- def log_artifact(self, local_file: str, artifact_path: Optional[str] = None):
784
- upload_path = artifact_path or ""
785
- upload_path = upload_path.lstrip(posixpath.sep)
786
- upload_path = posixpath.join(upload_path, os.path.basename(local_file))
703
+ def _upload(
704
+ self,
705
+ files_for_normal_upload: Sequence[Tuple[str, str, _FileMultiPartInfo]],
706
+ files_for_multipart_upload: Sequence[Tuple[str, str, _FileMultiPartInfo]],
707
+ progress: Optional[bool] = None,
708
+ ):
709
+ abort_event = Event()
710
+ show_progress = _can_display_progress(progress)
787
711
  with Progress(
788
712
  "[progress.description]{task.description}",
789
713
  BarColumn(bar_width=None),
@@ -793,170 +717,221 @@ class MlFoundryArtifactsRepository:
793
717
  TimeRemainingColumn(),
794
718
  TimeElapsedColumn(),
795
719
  refresh_per_second=1,
796
- disable=True,
720
+ disable=not show_progress,
797
721
  expand=True,
798
- ) as progress_bar:
799
- self._log_artifact(
800
- local_file=local_file,
801
- artifact_path=upload_path,
802
- multipart_info=_decide_file_parts(local_file),
803
- progress_bar=progress_bar,
804
- )
805
-
806
- def _is_directory(self, artifact_path):
807
- for _ in self.list_artifacts(artifact_path, page_size=3):
808
- return True
809
- return False
810
-
811
- def download_artifacts( # noqa: C901
812
- self,
813
- artifact_path: str,
814
- dst_path: Optional[str] = None,
815
- overwrite: bool = False,
816
- progress: Optional[bool] = None,
817
- ) -> str:
818
- """
819
- Download an artifact file or directory to a local directory if applicable, and return a
820
- local path for it. The caller is responsible for managing the lifecycle of the downloaded artifacts.
821
-
822
- Args:
823
- artifact_path: Relative source path to the desired artifacts.
824
- dst_path: Absolute path of the local filesystem destination directory to which to
825
- download the specified artifacts. This directory must already exist.
826
- If unspecified, the artifacts will either be downloaded to a new
827
- uniquely-named directory.
828
- overwrite: if to overwrite the files at/inside `dst_path` if they exist
829
- progress: Show or hide progress bar
830
-
831
- Returns:
832
- str: Absolute path of the local filesystem location containing the desired artifacts.
833
- """
834
-
835
- show_progress = _can_display_progress()
836
-
837
- is_dir_temp = False
838
- if dst_path is None:
839
- dst_path = tempfile.mkdtemp()
840
- is_dir_temp = True
722
+ ) as progress_bar, ThreadPoolExecutor(
723
+ max_workers=ENV_VARS.TFY_ARTIFACTS_UPLOAD_MAX_WORKERS
724
+ ) as executor:
725
+ futures: List[Future] = []
726
+ # Note: While this batching is beneficial when there is a large number of files, there is also
727
+ # a rare case risk of the signed url expiring before a request is made to it
728
+ _batch_size = _GENERATE_SIGNED_URL_BATCH_SIZE
729
+ for start_idx in range(0, len(files_for_normal_upload), _batch_size):
730
+ end_idx = min(start_idx + _batch_size, len(files_for_normal_upload))
731
+ if _any_future_has_failed(futures):
732
+ break
733
+ logger.debug("Generating write signed urls for a batch ...")
734
+ remote_file_paths = [
735
+ files_for_normal_upload[idx][0] for idx in range(start_idx, end_idx)
736
+ ]
737
+ signed_urls = self.get_signed_urls_for_write(
738
+ artifact_identifier=self.artifact_identifier,
739
+ paths=remote_file_paths,
740
+ )
741
+ for idx, signed_url in zip(range(start_idx, end_idx), signed_urls):
742
+ (
743
+ upload_path,
744
+ local_file,
745
+ multipart_info,
746
+ ) = files_for_normal_upload[idx]
747
+ future = executor.submit(
748
+ self._upload_file,
749
+ local_file=local_file,
750
+ artifact_path=upload_path,
751
+ multipart_info=multipart_info,
752
+ signed_url=signed_url,
753
+ abort_event=abort_event,
754
+ executor_for_multipart_upload=None,
755
+ progress_bar=progress_bar,
756
+ )
757
+ futures.append(future)
841
758
 
842
- dst_path = os.path.abspath(dst_path)
843
- if is_dir_temp:
844
- logger.info(
845
- f"Using temporary directory {dst_path} as the download directory"
846
- )
759
+ done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
760
+ if len(not_done) > 0:
761
+ abort_event.set()
762
+ for future in not_done:
763
+ future.cancel()
764
+ for future in done:
765
+ if future.exception() is not None:
766
+ raise future.exception()
847
767
 
848
- if not os.path.exists(dst_path):
768
+ for (
769
+ upload_path,
770
+ local_file,
771
+ multipart_info,
772
+ ) in files_for_multipart_upload:
773
+ self._upload_file(
774
+ local_file=local_file,
775
+ artifact_path=upload_path,
776
+ signed_url=None,
777
+ multipart_info=multipart_info,
778
+ executor_for_multipart_upload=executor,
779
+ progress_bar=progress_bar,
780
+ )
781
+
782
+ def _add_file_for_upload(
783
+ self,
784
+ local_file: str,
785
+ artifact_path: str,
786
+ files_for_normal_upload: List[Tuple[str, str, _FileMultiPartInfo]],
787
+ files_for_multipart_upload: List[Tuple[str, str, _FileMultiPartInfo]],
788
+ ):
789
+ local_file = os.path.realpath(local_file)
790
+ if os.path.isdir(local_file):
849
791
  raise MlFoundryException(
850
- message=(
851
- "The destination path for downloaded artifacts does not"
852
- " exist! Destination path: {dst_path}".format(dst_path=dst_path)
853
- ),
792
+ "Cannot log a directory as an artifact. Use `log_artifacts` instead"
854
793
  )
855
- elif not os.path.isdir(dst_path):
794
+ upload_path = artifact_path
795
+ upload_path = upload_path.lstrip(posixpath.sep)
796
+ multipart_info = _decide_file_parts(local_file)
797
+ if multipart_info.num_parts == 1:
798
+ files_for_normal_upload.append((upload_path, local_file, multipart_info))
799
+ else:
800
+ files_for_multipart_upload.append((upload_path, local_file, multipart_info))
801
+
802
+ def _add_dir_for_upload(
803
+ self,
804
+ local_dir: str,
805
+ artifact_path: str,
806
+ files_for_normal_upload: List[Tuple[str, str, _FileMultiPartInfo]],
807
+ files_for_multipart_upload: List[Tuple[str, str, _FileMultiPartInfo]],
808
+ ):
809
+ local_dir = os.path.realpath(local_dir)
810
+ if not os.path.isdir(local_dir):
856
811
  raise MlFoundryException(
857
- message=(
858
- "The destination path for downloaded artifacts must be a directory!"
859
- " Destination path: {dst_path}".format(dst_path=dst_path)
860
- ),
812
+ "Cannot log a file as a directory. Use `log_artifact` instead"
861
813
  )
814
+ dest_path = artifact_path
815
+ dest_path = dest_path.lstrip(posixpath.sep)
862
816
 
863
- progress_bar = Progress(
864
- "[progress.description]{task.description}",
865
- BarColumn(),
866
- "[progress.percentage]{task.percentage:>3.0f}%",
867
- DownloadColumn(),
868
- TransferSpeedColumn(),
869
- TimeRemainingColumn(),
870
- TimeElapsedColumn(),
871
- refresh_per_second=1,
872
- disable=not show_progress,
873
- expand=True,
874
- )
875
-
876
- try:
877
- progress_bar.start()
878
- # Check if the artifacts points to a directory
879
- if self._is_directory(artifact_path):
880
- futures: List[Future] = []
881
- file_paths: List[Tuple[str, str]] = []
882
- abort_event = Event()
817
+ for root, _, file_names in os.walk(local_dir):
818
+ upload_path = dest_path
819
+ if root != local_dir:
820
+ rel_path = os.path.relpath(root, local_dir)
821
+ rel_path = relative_path_to_artifact_path(rel_path)
822
+ upload_path = posixpath.join(dest_path, rel_path)
823
+ for file_name in file_names:
824
+ local_file = os.path.join(root, file_name)
825
+ self._add_file_for_upload(
826
+ local_file=local_file,
827
+ artifact_path=upload_path,
828
+ files_for_normal_upload=files_for_normal_upload,
829
+ files_for_multipart_upload=files_for_multipart_upload,
830
+ )
883
831
 
884
- # Check if any file is being overwritten before downloading them
885
- for file_path, download_dest_path in self._get_file_paths_recur(
886
- src_artifact_dir_path=artifact_path, dst_local_dir_path=dst_path
887
- ):
888
- final_file_path = os.path.join(download_dest_path, file_path)
832
+ def log_artifacts(
833
+ self,
834
+ src_dest_pairs: Sequence[Tuple[str, Optional[str]]],
835
+ progress: Optional[bool] = None,
836
+ ):
837
+ files_for_normal_upload: List[Tuple[str, str, _FileMultiPartInfo]] = []
838
+ files_for_multipart_upload: List[Tuple[str, str, _FileMultiPartInfo]] = []
839
+ for src, dest in src_dest_pairs:
840
+ if os.path.isdir(src):
841
+ self._add_dir_for_upload(
842
+ local_dir=src,
843
+ artifact_path=dest or "",
844
+ files_for_normal_upload=files_for_normal_upload,
845
+ files_for_multipart_upload=files_for_multipart_upload,
846
+ )
847
+ else:
848
+ self._add_file_for_upload(
849
+ local_file=src,
850
+ artifact_path=dest or "",
851
+ files_for_normal_upload=files_for_normal_upload,
852
+ files_for_multipart_upload=files_for_multipart_upload,
853
+ )
854
+ self._upload(
855
+ files_for_normal_upload=files_for_normal_upload,
856
+ files_for_multipart_upload=files_for_multipart_upload,
857
+ progress=progress,
858
+ )
889
859
 
890
- # There would be no overwrite if temp directory is being used
891
- if (
892
- not is_dir_temp
893
- and os.path.exists(final_file_path)
894
- and not overwrite
895
- ):
896
- raise MlFoundryException(
897
- f"File already exists at {final_file_path}, aborting download "
898
- f"(set `overwrite` flag to overwrite this and any subsequent files)"
899
- )
900
- file_paths.append((file_path, download_dest_path))
860
+ def _list_files(
861
+ self, artifact_identifier: ArtifactIdentifier, path, page_size, page_token
862
+ ) -> Union[ListFilesForDatasetResponseDto, ListFilesForArtifactVersionsResponseDto]:
863
+ if artifact_identifier.dataset_fqn:
864
+ return self._mlfoundry_artifacts_api.list_files_for_dataset_post(
865
+ list_files_for_dataset_request_dto=ListFilesForDatasetRequestDto(
866
+ dataset_fqn=artifact_identifier.dataset_fqn,
867
+ path=path,
868
+ max_results=page_size,
869
+ page_token=page_token,
870
+ )
871
+ )
872
+ else:
873
+ return self._mlfoundry_artifacts_api.list_files_for_artifact_version_post(
874
+ list_files_for_artifact_version_request_dto=ListFilesForArtifactVersionRequestDto(
875
+ id=str(artifact_identifier.artifact_version_id),
876
+ path=path,
877
+ max_results=page_size,
878
+ page_token=page_token,
879
+ )
880
+ )
901
881
 
902
- with ThreadPoolExecutor(_MAX_WORKERS_FOR_DOWNLOAD) as executor:
903
- # Note: While this batching is beneficial when there is a large number of files, there is also
904
- # a rare case risk of the signed url expiring before a request is made to it
905
- batch_size = _GENERATE_SIGNED_URL_BATCH_SIZE
906
- for start_idx in range(0, len(file_paths), batch_size):
907
- end_idx = min(start_idx + batch_size, len(file_paths))
908
- if _any_future_has_failed(futures):
909
- break
910
- logger.debug("Generating read signed urls for a batch ...")
911
- remote_file_paths = [
912
- file_paths[idx][0] for idx in range(start_idx, end_idx)
913
- ]
914
- signed_urls = self.get_signed_urls_for_read(
915
- artifact_identifier=self.artifact_identifier,
916
- paths=remote_file_paths,
917
- )
918
- for idx, signed_url in zip(
919
- range(start_idx, end_idx), signed_urls
920
- ):
921
- file_path, download_dest_path = file_paths[idx]
922
- future = executor.submit(
923
- self._download_artifact,
924
- src_artifact_path=file_path,
925
- dst_local_dir_path=download_dest_path,
926
- signed_url=signed_url,
927
- abort_event=abort_event,
928
- progress_bar=progress_bar,
929
- )
930
- futures.append(future)
882
+ def list_artifacts(
883
+ self, path=None, page_size=_LIST_FILES_PAGE_SIZE, **kwargs
884
+ ) -> Iterator[FileInfoDto]:
885
+ page_token = None
886
+ started = False
887
+ while not started or page_token is not None:
888
+ started = True
889
+ page = self._list_files(
890
+ artifact_identifier=self.artifact_identifier,
891
+ path=path,
892
+ page_size=page_size,
893
+ page_token=page_token,
894
+ )
895
+ for file_info in page.files:
896
+ yield file_info
897
+ page_token = page.next_page_token
931
898
 
932
- done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
933
- if len(not_done) > 0:
934
- abort_event.set()
935
- for future in not_done:
936
- future.cancel()
937
- for future in done:
938
- if future.exception() is not None:
939
- raise future.exception()
899
+ def _is_directory(self, artifact_path):
900
+ for _ in self.list_artifacts(artifact_path, page_size=3):
901
+ return True
902
+ return False
940
903
 
941
- output_dir = os.path.join(dst_path, artifact_path)
942
- return output_dir
943
- else:
944
- return self._download_artifact(
945
- src_artifact_path=artifact_path,
946
- dst_local_dir_path=dst_path,
947
- signed_url=None,
948
- progress_bar=progress_bar,
949
- )
950
- except Exception as err:
951
- if is_dir_temp:
952
- logger.info(
953
- f"Error encountered, removing temporary download directory at {dst_path}"
954
- )
955
- rmtree(dst_path) # remove temp directory alongside it's contents
956
- raise err
904
+ def _create_download_destination(
905
+ self, src_artifact_path: str, dst_local_dir_path: str
906
+ ) -> str:
907
+ """
908
+ Creates a local filesystem location to be used as a destination for downloading the artifact
909
+ specified by `src_artifact_path`. The destination location is a subdirectory of the
910
+ specified `dst_local_dir_path`, which is determined according to the structure of
911
+ `src_artifact_path`. For example, if `src_artifact_path` is `dir1/file1.txt`, then the
912
+ resulting destination path is `<dst_local_dir_path>/dir1/file1.txt`. Local directories are
913
+ created for the resulting destination location if they do not exist.
957
914
 
958
- finally:
959
- progress_bar.stop()
915
+ :param src_artifact_path: A relative, POSIX-style path referring to an artifact stored
916
+ within the repository's artifact root location.
917
+ `src_artifact_path` should be specified relative to the
918
+ repository's artifact root location.
919
+ :param dst_local_dir_path: The absolute path to a local filesystem directory in which the
920
+ local destination path will be contained. The local destination
921
+ path may be contained in a subdirectory of `dst_root_dir` if
922
+ `src_artifact_path` contains subdirectories.
923
+ :return: The absolute path to a local filesystem location to be used as a destination
924
+ for downloading the artifact specified by `src_artifact_path`.
925
+ """
926
+ src_artifact_path = src_artifact_path.rstrip(
927
+ "/"
928
+ ) # Ensure correct dirname for trailing '/'
929
+ dirpath = posixpath.dirname(src_artifact_path)
930
+ local_dir_path = os.path.join(dst_local_dir_path, dirpath)
931
+ local_file_path = os.path.join(dst_local_dir_path, src_artifact_path)
932
+ if not os.path.exists(local_dir_path):
933
+ os.makedirs(local_dir_path, exist_ok=True)
934
+ return local_file_path
960
935
 
961
936
  # noinspection PyMethodOverriding
962
937
  def _download_file(
@@ -1003,8 +978,8 @@ class MlFoundryArtifactsRepository:
1003
978
 
1004
979
  def _download_artifact(
1005
980
  self,
1006
- src_artifact_path,
1007
- dst_local_dir_path,
981
+ src_artifact_path: str,
982
+ dst_local_dir_path: str,
1008
983
  signed_url: Optional[SignedURLDto],
1009
984
  progress_bar: Optional[Progress] = None,
1010
985
  abort_event=None,
@@ -1057,105 +1032,154 @@ class MlFoundryArtifactsRepository:
1057
1032
  else:
1058
1033
  yield file_info.path, dst_local_dir_path
1059
1034
 
1060
- # TODO (chiragjn): Refactor these methods - if else is very inconvenient
1061
- def get_signed_urls_for_read(
1035
+ def download_artifacts( # noqa: C901
1062
1036
  self,
1063
- artifact_identifier: ArtifactIdentifier,
1064
- paths,
1065
- ) -> List[SignedURLDto]:
1066
- if artifact_identifier.artifact_version_id:
1067
- signed_urls_response = self._mlfoundry_artifacts_api.get_signed_urls_for_read_post(
1068
- get_signed_urls_for_artifact_version_read_request_dto=GetSignedURLsForArtifactVersionReadRequestDto(
1069
- id=str(artifact_identifier.artifact_version_id), paths=paths
1070
- )
1071
- )
1072
- signed_urls = signed_urls_response.signed_urls
1073
- elif artifact_identifier.dataset_fqn:
1074
- signed_urls_dataset_response = self._mlfoundry_artifacts_api.get_signed_urls_dataset_read_post(
1075
- get_signed_urls_for_dataset_read_request_dto=GetSignedURLsForDatasetReadRequestDto(
1076
- dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
1077
- )
1078
- )
1079
- signed_urls = signed_urls_dataset_response.signed_urls
1080
- else:
1081
- raise ValueError(
1082
- "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
1083
- )
1084
- return signed_urls
1037
+ artifact_path: str,
1038
+ dst_path: Optional[str] = None,
1039
+ overwrite: bool = False,
1040
+ progress: Optional[bool] = None,
1041
+ ) -> str:
1042
+ """
1043
+ Download an artifact file or directory to a local directory if applicable, and return a
1044
+ local path for it. The caller is responsible for managing the lifecycle of the downloaded artifacts.
1085
1045
 
1086
- def get_signed_urls_for_write(
1087
- self,
1088
- artifact_identifier: ArtifactIdentifier,
1089
- paths: List[str],
1090
- ) -> List[SignedURLDto]:
1091
- if artifact_identifier.artifact_version_id:
1092
- signed_urls_response = self._mlfoundry_artifacts_api.get_signed_urls_for_write_post(
1093
- get_signed_urls_for_artifact_version_write_request_dto=GetSignedURLsForArtifactVersionWriteRequestDto(
1094
- id=str(artifact_identifier.artifact_version_id), paths=paths
1095
- )
1096
- )
1097
- signed_urls = signed_urls_response.signed_urls
1098
- elif artifact_identifier.dataset_fqn:
1099
- signed_urls_dataset_response = self._mlfoundry_artifacts_api.get_signed_urls_for_dataset_write_post(
1100
- get_signed_url_for_dataset_write_request_dto=GetSignedURLForDatasetWriteRequestDto(
1101
- dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
1102
- )
1103
- )
1104
- signed_urls = signed_urls_dataset_response.signed_urls
1105
- else:
1106
- raise ValueError(
1107
- "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
1108
- )
1109
- return signed_urls
1046
+ Args:
1047
+ artifact_path: Relative source path to the desired artifacts.
1048
+ dst_path: Absolute path of the local filesystem destination directory to which to
1049
+ download the specified artifacts. This directory must already exist.
1050
+ If unspecified, the artifacts will either be downloaded to a new
1051
+ uniquely-named directory.
1052
+ overwrite: if to overwrite the files at/inside `dst_path` if they exist
1053
+ progress: Show or hide progress bar
1110
1054
 
1111
- def create_multipart_upload_for_identifier(
1112
- self,
1113
- artifact_identifier: ArtifactIdentifier,
1114
- path,
1115
- num_parts,
1116
- ) -> MultiPartUploadDto:
1117
- if artifact_identifier.artifact_version_id:
1118
- create_multipart_response: MultiPartUploadResponseDto = self._mlfoundry_artifacts_api.create_multi_part_upload_post(
1119
- create_multi_part_upload_request_dto=CreateMultiPartUploadRequestDto(
1120
- artifact_version_id=str(artifact_identifier.artifact_version_id),
1121
- path=path,
1122
- num_parts=num_parts,
1123
- )
1055
+ Returns:
1056
+ str: Absolute path of the local filesystem location containing the desired artifacts.
1057
+ """
1058
+
1059
+ show_progress = _can_display_progress(user_choice=progress)
1060
+
1061
+ is_dir_temp = False
1062
+ if dst_path is None:
1063
+ dst_path = tempfile.mkdtemp()
1064
+ is_dir_temp = True
1065
+
1066
+ dst_path = os.path.abspath(dst_path)
1067
+ if is_dir_temp:
1068
+ logger.info(
1069
+ f"Using temporary directory {dst_path} as the download directory"
1124
1070
  )
1125
- multipart_upload = create_multipart_response.multipart_upload
1126
- elif artifact_identifier.dataset_fqn:
1127
- create_multipart_for_dataset_response = self._mlfoundry_artifacts_api.create_multipart_upload_for_dataset_post(
1128
- create_multi_part_upload_for_dataset_request_dto=CreateMultiPartUploadForDatasetRequestDto(
1129
- dataset_fqn=artifact_identifier.dataset_fqn,
1130
- path=path,
1131
- num_parts=num_parts,
1132
- )
1071
+
1072
+ if not os.path.exists(dst_path):
1073
+ raise MlFoundryException(
1074
+ message=(
1075
+ "The destination path for downloaded artifacts does not"
1076
+ " exist! Destination path: {dst_path}".format(dst_path=dst_path)
1077
+ ),
1133
1078
  )
1134
- multipart_upload = create_multipart_for_dataset_response.multipart_upload
1135
- else:
1136
- raise ValueError(
1137
- "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
1079
+ elif not os.path.isdir(dst_path):
1080
+ raise MlFoundryException(
1081
+ message=(
1082
+ "The destination path for downloaded artifacts must be a directory!"
1083
+ " Destination path: {dst_path}".format(dst_path=dst_path)
1084
+ ),
1138
1085
  )
1139
- return multipart_upload
1140
1086
 
1141
- def list_files(
1142
- self, artifact_identifier: ArtifactIdentifier, path, page_size, page_token
1143
- ) -> Union[ListFilesForDatasetResponseDto, ListFilesForArtifactVersionsResponseDto]:
1144
- if artifact_identifier.dataset_fqn:
1145
- return self._mlfoundry_artifacts_api.list_files_for_dataset_post(
1146
- list_files_for_dataset_request_dto=ListFilesForDatasetRequestDto(
1147
- dataset_fqn=artifact_identifier.dataset_fqn,
1148
- path=path,
1149
- max_results=page_size,
1150
- page_token=page_token,
1087
+ progress_bar = Progress(
1088
+ "[progress.description]{task.description}",
1089
+ BarColumn(),
1090
+ "[progress.percentage]{task.percentage:>3.0f}%",
1091
+ DownloadColumn(),
1092
+ TransferSpeedColumn(),
1093
+ TimeRemainingColumn(),
1094
+ TimeElapsedColumn(),
1095
+ refresh_per_second=1,
1096
+ disable=not show_progress,
1097
+ expand=True,
1098
+ )
1099
+
1100
+ try:
1101
+ progress_bar.start()
1102
+ # Check if the artifacts points to a directory
1103
+ if self._is_directory(artifact_path):
1104
+ futures: List[Future] = []
1105
+ file_paths: List[Tuple[str, str]] = []
1106
+ abort_event = Event()
1107
+
1108
+ # Check if any file is being overwritten before downloading them
1109
+ for file_path, download_dest_path in self._get_file_paths_recur(
1110
+ src_artifact_dir_path=artifact_path, dst_local_dir_path=dst_path
1111
+ ):
1112
+ final_file_path = os.path.join(download_dest_path, file_path)
1113
+
1114
+ # There would be no overwrite if temp directory is being used
1115
+ if (
1116
+ not is_dir_temp
1117
+ and os.path.exists(final_file_path)
1118
+ and not overwrite
1119
+ ):
1120
+ raise MlFoundryException(
1121
+ f"File already exists at {final_file_path}, aborting download "
1122
+ f"(set `overwrite` flag to overwrite this and any subsequent files)"
1123
+ )
1124
+ file_paths.append((file_path, download_dest_path))
1125
+
1126
+ with ThreadPoolExecutor(
1127
+ max_workers=ENV_VARS.TFY_ARTIFACTS_DOWNLOAD_MAX_WORKERS
1128
+ ) as executor:
1129
+ # Note: While this batching is beneficial when there is a large number of files, there is also
1130
+ # a rare case risk of the signed url expiring before a request is made to it
1131
+ batch_size = _GENERATE_SIGNED_URL_BATCH_SIZE
1132
+ for start_idx in range(0, len(file_paths), batch_size):
1133
+ end_idx = min(start_idx + batch_size, len(file_paths))
1134
+ if _any_future_has_failed(futures):
1135
+ break
1136
+ logger.debug("Generating read signed urls for a batch ...")
1137
+ remote_file_paths = [
1138
+ file_paths[idx][0] for idx in range(start_idx, end_idx)
1139
+ ]
1140
+ signed_urls = self.get_signed_urls_for_read(
1141
+ artifact_identifier=self.artifact_identifier,
1142
+ paths=remote_file_paths,
1143
+ )
1144
+ for idx, signed_url in zip(
1145
+ range(start_idx, end_idx), signed_urls
1146
+ ):
1147
+ file_path, download_dest_path = file_paths[idx]
1148
+ future = executor.submit(
1149
+ self._download_artifact,
1150
+ src_artifact_path=file_path,
1151
+ dst_local_dir_path=download_dest_path,
1152
+ signed_url=signed_url,
1153
+ abort_event=abort_event,
1154
+ progress_bar=progress_bar,
1155
+ )
1156
+ futures.append(future)
1157
+
1158
+ done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
1159
+ if len(not_done) > 0:
1160
+ abort_event.set()
1161
+ for future in not_done:
1162
+ future.cancel()
1163
+ for future in done:
1164
+ if future.exception() is not None:
1165
+ raise future.exception()
1166
+
1167
+ output_dir = os.path.join(dst_path, artifact_path)
1168
+ return output_dir
1169
+ else:
1170
+ return self._download_artifact(
1171
+ src_artifact_path=artifact_path,
1172
+ dst_local_dir_path=dst_path,
1173
+ signed_url=None,
1174
+ progress_bar=progress_bar,
1151
1175
  )
1152
- )
1153
- else:
1154
- return self._mlfoundry_artifacts_api.list_files_for_artifact_version_post(
1155
- list_files_for_artifact_version_request_dto=ListFilesForArtifactVersionRequestDto(
1156
- id=str(artifact_identifier.artifact_version_id),
1157
- path=path,
1158
- max_results=page_size,
1159
- page_token=page_token,
1176
+ except Exception as err:
1177
+ if is_dir_temp:
1178
+ logger.info(
1179
+ f"Error encountered, removing temporary download directory at {dst_path}"
1160
1180
  )
1161
- )
1181
+ rmtree(dst_path) # remove temp directory alongside it's contents
1182
+ raise err
1183
+
1184
+ finally:
1185
+ progress_bar.stop()