truefoundry 0.4.4rc10__py3-none-any.whl → 0.4.4rc12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truefoundry might be problematic. Click here for more details.

@@ -16,7 +16,6 @@ from typing import (
16
16
  List,
17
17
  NamedTuple,
18
18
  Optional,
19
- Sequence,
20
19
  Tuple,
21
20
  Union,
22
21
  )
@@ -89,6 +88,13 @@ _GENERATE_SIGNED_URL_BATCH_SIZE = 50
89
88
  DEFAULT_PRESIGNED_URL_EXPIRY_TIME = 3600
90
89
 
91
90
 
91
+ def _get_relpath_if_in_tempdir(path: str) -> str:
92
+ tempdir = tempfile.gettempdir()
93
+ if path.startswith(tempdir):
94
+ return os.path.relpath(path, tempdir)
95
+ return path
96
+
97
+
92
98
  def _can_display_progress(user_choice: Optional[bool] = None) -> bool:
93
99
  if user_choice is False:
94
100
  return False
@@ -233,7 +239,7 @@ def _signed_url_upload_file(
233
239
  return
234
240
 
235
241
  task_progress_bar = progress_bar.add_task(
236
- f"[green]Uploading {local_file}:", start=True
242
+ f"[green]Uploading {_get_relpath_if_in_tempdir(local_file)}:", start=True
237
243
  )
238
244
 
239
245
  def callback(length):
@@ -334,7 +340,7 @@ def _s3_compatible_multipart_upload(
334
340
  parts = []
335
341
 
336
342
  multi_part_upload_progress = progress_bar.add_task(
337
- f"[green]Uploading {local_file}:", start=True
343
+ f"[green]Uploading {_get_relpath_if_in_tempdir(local_file)}:", start=True
338
344
  )
339
345
 
340
346
  def upload(part_number: int, seek: int) -> None:
@@ -404,7 +410,7 @@ def _azure_multi_part_upload(
404
410
  abort_event = abort_event or Event()
405
411
 
406
412
  multi_part_upload_progress = progress_bar.add_task(
407
- f"[green]Uploading {local_file}:", start=True
413
+ f"[green]Uploading {_get_relpath_if_in_tempdir(local_file)}:", start=True
408
414
  )
409
415
 
410
416
  def upload(part_number: int, seek: int):
@@ -504,56 +510,164 @@ class MlFoundryArtifactsRepository:
504
510
  api_client=self._api_client
505
511
  )
506
512
 
507
- # TODO (chiragjn): Refactor these methods - if else is very inconvenient
508
- def get_signed_urls_for_read(
509
- self,
510
- artifact_identifier: ArtifactIdentifier,
511
- paths,
512
- ) -> List[SignedURLDto]:
513
- if artifact_identifier.artifact_version_id:
514
- signed_urls_response = self._mlfoundry_artifacts_api.get_signed_urls_for_read_post(
515
- get_signed_urls_for_artifact_version_read_request_dto=GetSignedURLsForArtifactVersionReadRequestDto(
516
- id=str(artifact_identifier.artifact_version_id), paths=paths
517
- )
513
+ def _create_download_destination(
514
+ self, src_artifact_path, dst_local_dir_path=None
515
+ ) -> str:
516
+ """
517
+ Creates a local filesystem location to be used as a destination for downloading the artifact
518
+ specified by `src_artifact_path`. The destination location is a subdirectory of the
519
+ specified `dst_local_dir_path`, which is determined according to the structure of
520
+ `src_artifact_path`. For example, if `src_artifact_path` is `dir1/file1.txt`, then the
521
+ resulting destination path is `<dst_local_dir_path>/dir1/file1.txt`. Local directories are
522
+ created for the resulting destination location if they do not exist.
523
+
524
+ :param src_artifact_path: A relative, POSIX-style path referring to an artifact stored
525
+ within the repository's artifact root location.
526
+ `src_artifact_path` should be specified relative to the
527
+ repository's artifact root location.
528
+ :param dst_local_dir_path: The absolute path to a local filesystem directory in which the
529
+ local destination path will be contained. The local destination
530
+ path may be contained in a subdirectory of `dst_root_dir` if
531
+ `src_artifact_path` contains subdirectories.
532
+ :return: The absolute path to a local filesystem location to be used as a destination
533
+ for downloading the artifact specified by `src_artifact_path`.
534
+ """
535
+ src_artifact_path = src_artifact_path.rstrip(
536
+ "/"
537
+ ) # Ensure correct dirname for trailing '/'
538
+ dirpath = posixpath.dirname(src_artifact_path)
539
+ local_dir_path = os.path.join(dst_local_dir_path, dirpath)
540
+ local_file_path = os.path.join(dst_local_dir_path, src_artifact_path)
541
+ if not os.path.exists(local_dir_path):
542
+ os.makedirs(local_dir_path, exist_ok=True)
543
+ return local_file_path
544
+
545
+ # these methods should be named list_files, log_directory, log_file, etc
546
+ def list_artifacts(
547
+ self, path=None, page_size=_LIST_FILES_PAGE_SIZE, **kwargs
548
+ ) -> Iterator[FileInfoDto]:
549
+ page_token = None
550
+ started = False
551
+ while not started or page_token is not None:
552
+ started = True
553
+ page = self.list_files(
554
+ artifact_identifier=self.artifact_identifier,
555
+ path=path,
556
+ page_size=page_size,
557
+ page_token=page_token,
518
558
  )
519
- signed_urls = signed_urls_response.signed_urls
520
- elif artifact_identifier.dataset_fqn:
521
- signed_urls_dataset_response = self._mlfoundry_artifacts_api.get_signed_urls_dataset_read_post(
522
- get_signed_urls_for_dataset_read_request_dto=GetSignedURLsForDatasetReadRequestDto(
523
- dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
559
+ for file_info in page.files:
560
+ yield file_info
561
+ page_token = page.next_page_token
562
+
563
+ def log_artifacts( # noqa: C901
564
+ self, local_dir, artifact_path=None, progress=None
565
+ ):
566
+ show_progress = _can_display_progress(progress)
567
+
568
+ dest_path = artifact_path or ""
569
+ dest_path = dest_path.lstrip(posixpath.sep)
570
+
571
+ files_for_normal_upload: List[Tuple[str, str, _FileMultiPartInfo]] = []
572
+ files_for_multipart_upload: List[Tuple[str, str, _FileMultiPartInfo]] = []
573
+
574
+ for root, _, file_names in os.walk(local_dir):
575
+ upload_path = dest_path
576
+ if root != local_dir:
577
+ rel_path = os.path.relpath(root, local_dir)
578
+ rel_path = relative_path_to_artifact_path(rel_path)
579
+ upload_path = posixpath.join(dest_path, rel_path)
580
+ for file_name in file_names:
581
+ local_file = os.path.join(root, file_name)
582
+ multipart_info = _decide_file_parts(local_file)
583
+
584
+ final_upload_path = upload_path or ""
585
+ final_upload_path = final_upload_path.lstrip(posixpath.sep)
586
+ final_upload_path = posixpath.join(
587
+ final_upload_path, os.path.basename(local_file)
524
588
  )
525
- )
526
- signed_urls = signed_urls_dataset_response.signed_urls
527
- else:
528
- raise ValueError(
529
- "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
530
- )
531
- return signed_urls
532
589
 
533
- def get_signed_urls_for_write(
534
- self,
535
- artifact_identifier: ArtifactIdentifier,
536
- paths: List[str],
537
- ) -> List[SignedURLDto]:
538
- if artifact_identifier.artifact_version_id:
539
- signed_urls_response = self._mlfoundry_artifacts_api.get_signed_urls_for_write_post(
540
- get_signed_urls_for_artifact_version_write_request_dto=GetSignedURLsForArtifactVersionWriteRequestDto(
541
- id=str(artifact_identifier.artifact_version_id), paths=paths
590
+ if multipart_info.num_parts == 1:
591
+ files_for_normal_upload.append(
592
+ (final_upload_path, local_file, multipart_info)
593
+ )
594
+ else:
595
+ files_for_multipart_upload.append(
596
+ (final_upload_path, local_file, multipart_info)
597
+ )
598
+
599
+ abort_event = Event()
600
+
601
+ with Progress(
602
+ "[progress.description]{task.description}",
603
+ BarColumn(bar_width=None),
604
+ "[progress.percentage]{task.percentage:>3.0f}%",
605
+ DownloadColumn(),
606
+ TransferSpeedColumn(),
607
+ TimeRemainingColumn(),
608
+ TimeElapsedColumn(),
609
+ refresh_per_second=1,
610
+ disable=not show_progress,
611
+ expand=True,
612
+ ) as progress_bar, ThreadPoolExecutor(
613
+ max_workers=ENV_VARS.TFY_ARTIFACTS_UPLOAD_MAX_WORKERS
614
+ ) as executor:
615
+ futures: List[Future] = []
616
+ # Note: While this batching is beneficial when there is a large number of files, there is also
617
+ # a rare case risk of the signed url expiring before a request is made to it
618
+ _batch_size = _GENERATE_SIGNED_URL_BATCH_SIZE
619
+ for start_idx in range(0, len(files_for_normal_upload), _batch_size):
620
+ end_idx = min(start_idx + _batch_size, len(files_for_normal_upload))
621
+ if _any_future_has_failed(futures):
622
+ break
623
+ logger.debug("Generating write signed urls for a batch ...")
624
+ remote_file_paths = [
625
+ files_for_normal_upload[idx][0] for idx in range(start_idx, end_idx)
626
+ ]
627
+ signed_urls = self.get_signed_urls_for_write(
628
+ artifact_identifier=self.artifact_identifier,
629
+ paths=remote_file_paths,
542
630
  )
543
- )
544
- signed_urls = signed_urls_response.signed_urls
545
- elif artifact_identifier.dataset_fqn:
546
- signed_urls_dataset_response = self._mlfoundry_artifacts_api.get_signed_urls_for_dataset_write_post(
547
- get_signed_url_for_dataset_write_request_dto=GetSignedURLForDatasetWriteRequestDto(
548
- dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
631
+ for idx, signed_url in zip(range(start_idx, end_idx), signed_urls):
632
+ (
633
+ upload_path,
634
+ local_file,
635
+ multipart_info,
636
+ ) = files_for_normal_upload[idx]
637
+ future = executor.submit(
638
+ self._log_artifact,
639
+ local_file=local_file,
640
+ artifact_path=upload_path,
641
+ multipart_info=multipart_info,
642
+ signed_url=signed_url,
643
+ abort_event=abort_event,
644
+ executor_for_multipart_upload=None,
645
+ progress_bar=progress_bar,
646
+ )
647
+ futures.append(future)
648
+
649
+ done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
650
+ if len(not_done) > 0:
651
+ abort_event.set()
652
+ for future in not_done:
653
+ future.cancel()
654
+ for future in done:
655
+ if future.exception() is not None:
656
+ raise future.exception()
657
+
658
+ for (
659
+ upload_path,
660
+ local_file,
661
+ multipart_info,
662
+ ) in files_for_multipart_upload:
663
+ self._log_artifact(
664
+ local_file=local_file,
665
+ artifact_path=upload_path,
666
+ signed_url=None,
667
+ multipart_info=multipart_info,
668
+ executor_for_multipart_upload=executor,
669
+ progress_bar=progress_bar,
549
670
  )
550
- )
551
- signed_urls = signed_urls_dataset_response.signed_urls
552
- else:
553
- raise ValueError(
554
- "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
555
- )
556
- return signed_urls
557
671
 
558
672
  def _normal_upload(
559
673
  self,
@@ -571,7 +685,7 @@ class MlFoundryArtifactsRepository:
571
685
  if progress_bar.disable:
572
686
  logger.info(
573
687
  "Uploading %s to %s",
574
- local_file,
688
+ _get_relpath_if_in_tempdir(local_file),
575
689
  artifact_path,
576
690
  )
577
691
 
@@ -583,36 +697,6 @@ class MlFoundryArtifactsRepository:
583
697
  )
584
698
  logger.debug("Uploaded %s to %s", local_file, artifact_path)
585
699
 
586
- def _create_multipart_upload_for_identifier(
587
- self,
588
- artifact_identifier: ArtifactIdentifier,
589
- path,
590
- num_parts,
591
- ) -> MultiPartUploadDto:
592
- if artifact_identifier.artifact_version_id:
593
- create_multipart_response: MultiPartUploadResponseDto = self._mlfoundry_artifacts_api.create_multi_part_upload_post(
594
- create_multi_part_upload_request_dto=CreateMultiPartUploadRequestDto(
595
- artifact_version_id=str(artifact_identifier.artifact_version_id),
596
- path=path,
597
- num_parts=num_parts,
598
- )
599
- )
600
- multipart_upload = create_multipart_response.multipart_upload
601
- elif artifact_identifier.dataset_fqn:
602
- create_multipart_for_dataset_response = self._mlfoundry_artifacts_api.create_multipart_upload_for_dataset_post(
603
- create_multi_part_upload_for_dataset_request_dto=CreateMultiPartUploadForDatasetRequestDto(
604
- dataset_fqn=artifact_identifier.dataset_fqn,
605
- path=path,
606
- num_parts=num_parts,
607
- )
608
- )
609
- multipart_upload = create_multipart_for_dataset_response.multipart_upload
610
- else:
611
- raise ValueError(
612
- "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
613
- )
614
- return multipart_upload
615
-
616
700
  def _multipart_upload(
617
701
  self,
618
702
  local_file: str,
@@ -625,11 +709,11 @@ class MlFoundryArtifactsRepository:
625
709
  if progress_bar.disable:
626
710
  logger.info(
627
711
  "Uploading %s to %s using multipart upload",
628
- local_file,
712
+ _get_relpath_if_in_tempdir(local_file),
629
713
  artifact_path,
630
714
  )
631
715
 
632
- multipart_upload = self._create_multipart_upload_for_identifier(
716
+ multipart_upload = self.create_multipart_upload_for_identifier(
633
717
  artifact_identifier=self.artifact_identifier,
634
718
  path=artifact_path,
635
719
  num_parts=multipart_info.num_parts,
@@ -661,7 +745,7 @@ class MlFoundryArtifactsRepository:
661
745
  else:
662
746
  raise NotImplementedError()
663
747
 
664
- def _upload_file(
748
+ def _log_artifact(
665
749
  self,
666
750
  local_file: str,
667
751
  artifact_path: str,
@@ -700,14 +784,10 @@ class MlFoundryArtifactsRepository:
700
784
  progress_bar=progress_bar,
701
785
  )
702
786
 
703
- def _upload(
704
- self,
705
- files_for_normal_upload: Sequence[Tuple[str, str, _FileMultiPartInfo]],
706
- files_for_multipart_upload: Sequence[Tuple[str, str, _FileMultiPartInfo]],
707
- progress: Optional[bool] = None,
708
- ):
709
- abort_event = Event()
710
- show_progress = _can_display_progress(progress)
787
+ def log_artifact(self, local_file: str, artifact_path: Optional[str] = None):
788
+ upload_path = artifact_path or ""
789
+ upload_path = upload_path.lstrip(posixpath.sep)
790
+ upload_path = posixpath.join(upload_path, os.path.basename(local_file))
711
791
  with Progress(
712
792
  "[progress.description]{task.description}",
713
793
  BarColumn(bar_width=None),
@@ -717,321 +797,21 @@ class MlFoundryArtifactsRepository:
717
797
  TimeRemainingColumn(),
718
798
  TimeElapsedColumn(),
719
799
  refresh_per_second=1,
720
- disable=not show_progress,
800
+ disable=True,
721
801
  expand=True,
722
- ) as progress_bar, ThreadPoolExecutor(
723
- max_workers=ENV_VARS.TFY_ARTIFACTS_UPLOAD_MAX_WORKERS
724
- ) as executor:
725
- futures: List[Future] = []
726
- # Note: While this batching is beneficial when there is a large number of files, there is also
727
- # a rare case risk of the signed url expiring before a request is made to it
728
- _batch_size = _GENERATE_SIGNED_URL_BATCH_SIZE
729
- for start_idx in range(0, len(files_for_normal_upload), _batch_size):
730
- end_idx = min(start_idx + _batch_size, len(files_for_normal_upload))
731
- if _any_future_has_failed(futures):
732
- break
733
- logger.debug("Generating write signed urls for a batch ...")
734
- remote_file_paths = [
735
- files_for_normal_upload[idx][0] for idx in range(start_idx, end_idx)
736
- ]
737
- signed_urls = self.get_signed_urls_for_write(
738
- artifact_identifier=self.artifact_identifier,
739
- paths=remote_file_paths,
740
- )
741
- for idx, signed_url in zip(range(start_idx, end_idx), signed_urls):
742
- (
743
- upload_path,
744
- local_file,
745
- multipart_info,
746
- ) = files_for_normal_upload[idx]
747
- future = executor.submit(
748
- self._upload_file,
749
- local_file=local_file,
750
- artifact_path=upload_path,
751
- multipart_info=multipart_info,
752
- signed_url=signed_url,
753
- abort_event=abort_event,
754
- executor_for_multipart_upload=None,
755
- progress_bar=progress_bar,
756
- )
757
- futures.append(future)
758
-
759
- done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
760
- if len(not_done) > 0:
761
- abort_event.set()
762
- for future in not_done:
763
- future.cancel()
764
- for future in done:
765
- if future.exception() is not None:
766
- raise future.exception()
767
-
768
- for (
769
- upload_path,
770
- local_file,
771
- multipart_info,
772
- ) in files_for_multipart_upload:
773
- self._upload_file(
774
- local_file=local_file,
775
- artifact_path=upload_path,
776
- signed_url=None,
777
- multipart_info=multipart_info,
778
- executor_for_multipart_upload=executor,
779
- progress_bar=progress_bar,
780
- )
781
-
782
- def _add_file_for_upload(
783
- self,
784
- local_file: str,
785
- artifact_path: str,
786
- files_for_normal_upload: List[Tuple[str, str, _FileMultiPartInfo]],
787
- files_for_multipart_upload: List[Tuple[str, str, _FileMultiPartInfo]],
788
- ):
789
- local_file = os.path.realpath(local_file)
790
- if os.path.isdir(local_file):
791
- raise MlFoundryException(
792
- "Cannot log a directory as an artifact. Use `log_artifacts` instead"
793
- )
794
- upload_path = artifact_path
795
- upload_path = upload_path.lstrip(posixpath.sep)
796
- multipart_info = _decide_file_parts(local_file)
797
- if multipart_info.num_parts == 1:
798
- files_for_normal_upload.append((upload_path, local_file, multipart_info))
799
- else:
800
- files_for_multipart_upload.append((upload_path, local_file, multipart_info))
801
-
802
- def _add_dir_for_upload(
803
- self,
804
- local_dir: str,
805
- artifact_path: str,
806
- files_for_normal_upload: List[Tuple[str, str, _FileMultiPartInfo]],
807
- files_for_multipart_upload: List[Tuple[str, str, _FileMultiPartInfo]],
808
- ):
809
- local_dir = os.path.realpath(local_dir)
810
- if not os.path.isdir(local_dir):
811
- raise MlFoundryException(
812
- "Cannot log a file as a directory. Use `log_artifact` instead"
813
- )
814
- dest_path = artifact_path
815
- dest_path = dest_path.lstrip(posixpath.sep)
816
-
817
- for root, _, file_names in os.walk(local_dir):
818
- upload_path = dest_path
819
- if root != local_dir:
820
- rel_path = os.path.relpath(root, local_dir)
821
- rel_path = relative_path_to_artifact_path(rel_path)
822
- upload_path = posixpath.join(dest_path, rel_path)
823
- for file_name in file_names:
824
- local_file = os.path.join(root, file_name)
825
- self._add_file_for_upload(
826
- local_file=local_file,
827
- artifact_path=upload_path,
828
- files_for_normal_upload=files_for_normal_upload,
829
- files_for_multipart_upload=files_for_multipart_upload,
830
- )
831
-
832
- def log_artifacts(
833
- self,
834
- src_dest_pairs: Sequence[Tuple[str, Optional[str]]],
835
- progress: Optional[bool] = None,
836
- ):
837
- files_for_normal_upload: List[Tuple[str, str, _FileMultiPartInfo]] = []
838
- files_for_multipart_upload: List[Tuple[str, str, _FileMultiPartInfo]] = []
839
- for src, dest in src_dest_pairs:
840
- if os.path.isdir(src):
841
- self._add_dir_for_upload(
842
- local_dir=src,
843
- artifact_path=dest or "",
844
- files_for_normal_upload=files_for_normal_upload,
845
- files_for_multipart_upload=files_for_multipart_upload,
846
- )
847
- else:
848
- self._add_file_for_upload(
849
- local_file=src,
850
- artifact_path=dest or "",
851
- files_for_normal_upload=files_for_normal_upload,
852
- files_for_multipart_upload=files_for_multipart_upload,
853
- )
854
- self._upload(
855
- files_for_normal_upload=files_for_normal_upload,
856
- files_for_multipart_upload=files_for_multipart_upload,
857
- progress=progress,
858
- )
859
-
860
- def _list_files(
861
- self, artifact_identifier: ArtifactIdentifier, path, page_size, page_token
862
- ) -> Union[ListFilesForDatasetResponseDto, ListFilesForArtifactVersionsResponseDto]:
863
- if artifact_identifier.dataset_fqn:
864
- return self._mlfoundry_artifacts_api.list_files_for_dataset_post(
865
- list_files_for_dataset_request_dto=ListFilesForDatasetRequestDto(
866
- dataset_fqn=artifact_identifier.dataset_fqn,
867
- path=path,
868
- max_results=page_size,
869
- page_token=page_token,
870
- )
871
- )
872
- else:
873
- return self._mlfoundry_artifacts_api.list_files_for_artifact_version_post(
874
- list_files_for_artifact_version_request_dto=ListFilesForArtifactVersionRequestDto(
875
- id=str(artifact_identifier.artifact_version_id),
876
- path=path,
877
- max_results=page_size,
878
- page_token=page_token,
879
- )
880
- )
881
-
882
- def list_artifacts(
883
- self, path=None, page_size=_LIST_FILES_PAGE_SIZE, **kwargs
884
- ) -> Iterator[FileInfoDto]:
885
- page_token = None
886
- started = False
887
- while not started or page_token is not None:
888
- started = True
889
- page = self._list_files(
890
- artifact_identifier=self.artifact_identifier,
891
- path=path,
892
- page_size=page_size,
893
- page_token=page_token,
802
+ ) as progress_bar:
803
+ self._log_artifact(
804
+ local_file=local_file,
805
+ artifact_path=upload_path,
806
+ multipart_info=_decide_file_parts(local_file),
807
+ progress_bar=progress_bar,
894
808
  )
895
- for file_info in page.files:
896
- yield file_info
897
- page_token = page.next_page_token
898
809
 
899
810
  def _is_directory(self, artifact_path):
900
811
  for _ in self.list_artifacts(artifact_path, page_size=3):
901
812
  return True
902
813
  return False
903
814
 
904
- def _create_download_destination(
905
- self, src_artifact_path: str, dst_local_dir_path: str
906
- ) -> str:
907
- """
908
- Creates a local filesystem location to be used as a destination for downloading the artifact
909
- specified by `src_artifact_path`. The destination location is a subdirectory of the
910
- specified `dst_local_dir_path`, which is determined according to the structure of
911
- `src_artifact_path`. For example, if `src_artifact_path` is `dir1/file1.txt`, then the
912
- resulting destination path is `<dst_local_dir_path>/dir1/file1.txt`. Local directories are
913
- created for the resulting destination location if they do not exist.
914
-
915
- :param src_artifact_path: A relative, POSIX-style path referring to an artifact stored
916
- within the repository's artifact root location.
917
- `src_artifact_path` should be specified relative to the
918
- repository's artifact root location.
919
- :param dst_local_dir_path: The absolute path to a local filesystem directory in which the
920
- local destination path will be contained. The local destination
921
- path may be contained in a subdirectory of `dst_root_dir` if
922
- `src_artifact_path` contains subdirectories.
923
- :return: The absolute path to a local filesystem location to be used as a destination
924
- for downloading the artifact specified by `src_artifact_path`.
925
- """
926
- src_artifact_path = src_artifact_path.rstrip(
927
- "/"
928
- ) # Ensure correct dirname for trailing '/'
929
- dirpath = posixpath.dirname(src_artifact_path)
930
- local_dir_path = os.path.join(dst_local_dir_path, dirpath)
931
- local_file_path = os.path.join(dst_local_dir_path, src_artifact_path)
932
- if not os.path.exists(local_dir_path):
933
- os.makedirs(local_dir_path, exist_ok=True)
934
- return local_file_path
935
-
936
- # noinspection PyMethodOverriding
937
- def _download_file(
938
- self,
939
- remote_file_path: str,
940
- local_path: str,
941
- progress_bar: Optional[Progress],
942
- signed_url: Optional[SignedURLDto],
943
- abort_event: Optional[Event] = None,
944
- ):
945
- if not remote_file_path:
946
- raise MlFoundryException(
947
- f"remote_file_path cannot be None or empty str {remote_file_path}"
948
- )
949
- if not signed_url:
950
- signed_url = self.get_signed_urls_for_read(
951
- artifact_identifier=self.artifact_identifier, paths=[remote_file_path]
952
- )[0]
953
-
954
- if progress_bar is None or not progress_bar.disable:
955
- logger.info("Downloading %s to %s", remote_file_path, local_path)
956
-
957
- if progress_bar is not None:
958
- download_progress_bar = progress_bar.add_task(
959
- f"[green]Downloading to {remote_file_path}:", start=True
960
- )
961
-
962
- def callback(chunk, total_file_size):
963
- if progress_bar is not None:
964
- progress_bar.update(
965
- download_progress_bar,
966
- advance=chunk,
967
- total=total_file_size,
968
- )
969
- if abort_event and abort_event.is_set():
970
- raise Exception("aborting download")
971
-
972
- _download_file_using_http_uri(
973
- http_uri=signed_url.signed_url,
974
- download_path=local_path,
975
- callback=callback,
976
- )
977
- logger.debug("Downloaded %s to %s", remote_file_path, local_path)
978
-
979
- def _download_artifact(
980
- self,
981
- src_artifact_path: str,
982
- dst_local_dir_path: str,
983
- signed_url: Optional[SignedURLDto],
984
- progress_bar: Optional[Progress] = None,
985
- abort_event=None,
986
- ) -> str:
987
- """
988
- Download the file artifact specified by `src_artifact_path` to the local filesystem
989
- directory specified by `dst_local_dir_path`.
990
- :param src_artifact_path: A relative, POSIX-style path referring to a file artifact
991
- stored within the repository's artifact root location.
992
- `src_artifact_path` should be specified relative to the
993
- repository's artifact root location.
994
- :param dst_local_dir_path: Absolute path of the local filesystem destination directory
995
- to which to download the specified artifact. The downloaded
996
- artifact may be written to a subdirectory of
997
- `dst_local_dir_path` if `src_artifact_path` contains
998
- subdirectories.
999
- :param progress_bar: An instance of a Rich progress bar used to visually display the
1000
- progress of the file download.
1001
- :return: A local filesystem path referring to the downloaded file.
1002
- """
1003
- local_destination_file_path = self._create_download_destination(
1004
- src_artifact_path=src_artifact_path, dst_local_dir_path=dst_local_dir_path
1005
- )
1006
- self._download_file(
1007
- remote_file_path=src_artifact_path,
1008
- local_path=local_destination_file_path,
1009
- signed_url=signed_url,
1010
- abort_event=abort_event,
1011
- progress_bar=progress_bar,
1012
- )
1013
- return local_destination_file_path
1014
-
1015
- def _get_file_paths_recur(self, src_artifact_dir_path, dst_local_dir_path):
1016
- local_dir = os.path.join(dst_local_dir_path, src_artifact_dir_path)
1017
- dir_content = [ # prevent infinite loop, sometimes the dir is recursively included
1018
- file_info
1019
- for file_info in self.list_artifacts(src_artifact_dir_path)
1020
- if file_info.path != "." and file_info.path != src_artifact_dir_path
1021
- ]
1022
- if not dir_content: # empty dir
1023
- if not os.path.exists(local_dir):
1024
- os.makedirs(local_dir, exist_ok=True)
1025
- else:
1026
- for file_info in dir_content:
1027
- if file_info.is_dir:
1028
- yield from self._get_file_paths_recur(
1029
- src_artifact_dir_path=file_info.path,
1030
- dst_local_dir_path=dst_local_dir_path,
1031
- )
1032
- else:
1033
- yield file_info.path, dst_local_dir_path
1034
-
1035
815
  def download_artifacts( # noqa: C901
1036
816
  self,
1037
817
  artifact_path: str,
@@ -1056,7 +836,7 @@ class MlFoundryArtifactsRepository:
1056
836
  str: Absolute path of the local filesystem location containing the desired artifacts.
1057
837
  """
1058
838
 
1059
- show_progress = _can_display_progress(user_choice=progress)
839
+ show_progress = _can_display_progress()
1060
840
 
1061
841
  is_dir_temp = False
1062
842
  if dst_path is None:
@@ -1183,3 +963,205 @@ class MlFoundryArtifactsRepository:
1183
963
 
1184
964
  finally:
1185
965
  progress_bar.stop()
966
+
967
+ # noinspection PyMethodOverriding
968
+ def _download_file(
969
+ self,
970
+ remote_file_path: str,
971
+ local_path: str,
972
+ progress_bar: Optional[Progress],
973
+ signed_url: Optional[SignedURLDto],
974
+ abort_event: Optional[Event] = None,
975
+ ):
976
+ if not remote_file_path:
977
+ raise MlFoundryException(
978
+ f"remote_file_path cannot be None or empty str {remote_file_path}"
979
+ )
980
+ if not signed_url:
981
+ signed_url = self.get_signed_urls_for_read(
982
+ artifact_identifier=self.artifact_identifier, paths=[remote_file_path]
983
+ )[0]
984
+
985
+ if progress_bar is None or not progress_bar.disable:
986
+ logger.info("Downloading %s to %s", remote_file_path, local_path)
987
+
988
+ if progress_bar is not None:
989
+ download_progress_bar = progress_bar.add_task(
990
+ f"[green]Downloading to {remote_file_path}:", start=True
991
+ )
992
+
993
+ def callback(chunk, total_file_size):
994
+ if progress_bar is not None:
995
+ progress_bar.update(
996
+ download_progress_bar,
997
+ advance=chunk,
998
+ total=total_file_size,
999
+ )
1000
+ if abort_event and abort_event.is_set():
1001
+ raise Exception("aborting download")
1002
+
1003
+ _download_file_using_http_uri(
1004
+ http_uri=signed_url.signed_url,
1005
+ download_path=local_path,
1006
+ callback=callback,
1007
+ )
1008
+ logger.debug("Downloaded %s to %s", remote_file_path, local_path)
1009
+
1010
+ def _download_artifact(
1011
+ self,
1012
+ src_artifact_path,
1013
+ dst_local_dir_path,
1014
+ signed_url: Optional[SignedURLDto],
1015
+ progress_bar: Optional[Progress] = None,
1016
+ abort_event=None,
1017
+ ) -> str:
1018
+ """
1019
+ Download the file artifact specified by `src_artifact_path` to the local filesystem
1020
+ directory specified by `dst_local_dir_path`.
1021
+ :param src_artifact_path: A relative, POSIX-style path referring to a file artifact
1022
+ stored within the repository's artifact root location.
1023
+ `src_artifact_path` should be specified relative to the
1024
+ repository's artifact root location.
1025
+ :param dst_local_dir_path: Absolute path of the local filesystem destination directory
1026
+ to which to download the specified artifact. The downloaded
1027
+ artifact may be written to a subdirectory of
1028
+ `dst_local_dir_path` if `src_artifact_path` contains
1029
+ subdirectories.
1030
+ :param progress_bar: An instance of a Rich progress bar used to visually display the
1031
+ progress of the file download.
1032
+ :return: A local filesystem path referring to the downloaded file.
1033
+ """
1034
+ local_destination_file_path = self._create_download_destination(
1035
+ src_artifact_path=src_artifact_path, dst_local_dir_path=dst_local_dir_path
1036
+ )
1037
+ self._download_file(
1038
+ remote_file_path=src_artifact_path,
1039
+ local_path=local_destination_file_path,
1040
+ signed_url=signed_url,
1041
+ abort_event=abort_event,
1042
+ progress_bar=progress_bar,
1043
+ )
1044
+ return local_destination_file_path
1045
+
1046
+ def _get_file_paths_recur(self, src_artifact_dir_path, dst_local_dir_path):
1047
+ local_dir = os.path.join(dst_local_dir_path, src_artifact_dir_path)
1048
+ dir_content = [ # prevent infinite loop, sometimes the dir is recursively included
1049
+ file_info
1050
+ for file_info in self.list_artifacts(src_artifact_dir_path)
1051
+ if file_info.path != "." and file_info.path != src_artifact_dir_path
1052
+ ]
1053
+ if not dir_content: # empty dir
1054
+ if not os.path.exists(local_dir):
1055
+ os.makedirs(local_dir, exist_ok=True)
1056
+ else:
1057
+ for file_info in dir_content:
1058
+ if file_info.is_dir:
1059
+ yield from self._get_file_paths_recur(
1060
+ src_artifact_dir_path=file_info.path,
1061
+ dst_local_dir_path=dst_local_dir_path,
1062
+ )
1063
+ else:
1064
+ yield file_info.path, dst_local_dir_path
1065
+
1066
+ # TODO (chiragjn): Refactor these methods - if else is very inconvenient
1067
+ def get_signed_urls_for_read(
1068
+ self,
1069
+ artifact_identifier: ArtifactIdentifier,
1070
+ paths,
1071
+ ) -> List[SignedURLDto]:
1072
+ if artifact_identifier.artifact_version_id:
1073
+ signed_urls_response = self._mlfoundry_artifacts_api.get_signed_urls_for_read_post(
1074
+ get_signed_urls_for_artifact_version_read_request_dto=GetSignedURLsForArtifactVersionReadRequestDto(
1075
+ id=str(artifact_identifier.artifact_version_id), paths=paths
1076
+ )
1077
+ )
1078
+ signed_urls = signed_urls_response.signed_urls
1079
+ elif artifact_identifier.dataset_fqn:
1080
+ signed_urls_dataset_response = self._mlfoundry_artifacts_api.get_signed_urls_dataset_read_post(
1081
+ get_signed_urls_for_dataset_read_request_dto=GetSignedURLsForDatasetReadRequestDto(
1082
+ dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
1083
+ )
1084
+ )
1085
+ signed_urls = signed_urls_dataset_response.signed_urls
1086
+ else:
1087
+ raise ValueError(
1088
+ "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
1089
+ )
1090
+ return signed_urls
1091
+
1092
+ def get_signed_urls_for_write(
1093
+ self,
1094
+ artifact_identifier: ArtifactIdentifier,
1095
+ paths: List[str],
1096
+ ) -> List[SignedURLDto]:
1097
+ if artifact_identifier.artifact_version_id:
1098
+ signed_urls_response = self._mlfoundry_artifacts_api.get_signed_urls_for_write_post(
1099
+ get_signed_urls_for_artifact_version_write_request_dto=GetSignedURLsForArtifactVersionWriteRequestDto(
1100
+ id=str(artifact_identifier.artifact_version_id), paths=paths
1101
+ )
1102
+ )
1103
+ signed_urls = signed_urls_response.signed_urls
1104
+ elif artifact_identifier.dataset_fqn:
1105
+ signed_urls_dataset_response = self._mlfoundry_artifacts_api.get_signed_urls_for_dataset_write_post(
1106
+ get_signed_url_for_dataset_write_request_dto=GetSignedURLForDatasetWriteRequestDto(
1107
+ dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
1108
+ )
1109
+ )
1110
+ signed_urls = signed_urls_dataset_response.signed_urls
1111
+ else:
1112
+ raise ValueError(
1113
+ "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
1114
+ )
1115
+ return signed_urls
1116
+
1117
+ def create_multipart_upload_for_identifier(
1118
+ self,
1119
+ artifact_identifier: ArtifactIdentifier,
1120
+ path,
1121
+ num_parts,
1122
+ ) -> MultiPartUploadDto:
1123
+ if artifact_identifier.artifact_version_id:
1124
+ create_multipart_response: MultiPartUploadResponseDto = self._mlfoundry_artifacts_api.create_multi_part_upload_post(
1125
+ create_multi_part_upload_request_dto=CreateMultiPartUploadRequestDto(
1126
+ artifact_version_id=str(artifact_identifier.artifact_version_id),
1127
+ path=path,
1128
+ num_parts=num_parts,
1129
+ )
1130
+ )
1131
+ multipart_upload = create_multipart_response.multipart_upload
1132
+ elif artifact_identifier.dataset_fqn:
1133
+ create_multipart_for_dataset_response = self._mlfoundry_artifacts_api.create_multipart_upload_for_dataset_post(
1134
+ create_multi_part_upload_for_dataset_request_dto=CreateMultiPartUploadForDatasetRequestDto(
1135
+ dataset_fqn=artifact_identifier.dataset_fqn,
1136
+ path=path,
1137
+ num_parts=num_parts,
1138
+ )
1139
+ )
1140
+ multipart_upload = create_multipart_for_dataset_response.multipart_upload
1141
+ else:
1142
+ raise ValueError(
1143
+ "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
1144
+ )
1145
+ return multipart_upload
1146
+
1147
+ def list_files(
1148
+ self, artifact_identifier: ArtifactIdentifier, path, page_size, page_token
1149
+ ) -> Union[ListFilesForDatasetResponseDto, ListFilesForArtifactVersionsResponseDto]:
1150
+ if artifact_identifier.dataset_fqn:
1151
+ return self._mlfoundry_artifacts_api.list_files_for_dataset_post(
1152
+ list_files_for_dataset_request_dto=ListFilesForDatasetRequestDto(
1153
+ dataset_fqn=artifact_identifier.dataset_fqn,
1154
+ path=path,
1155
+ max_results=page_size,
1156
+ page_token=page_token,
1157
+ )
1158
+ )
1159
+ else:
1160
+ return self._mlfoundry_artifacts_api.list_files_for_artifact_version_post(
1161
+ list_files_for_artifact_version_request_dto=ListFilesForArtifactVersionRequestDto(
1162
+ id=str(artifact_identifier.artifact_version_id),
1163
+ path=path,
1164
+ max_results=page_size,
1165
+ page_token=page_token,
1166
+ )
1167
+ )