truefoundry 0.4.4rc9__py3-none-any.whl → 0.4.4rc10__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truefoundry might be problematic. Click here for more details.

Files changed (43)
  1. truefoundry/ml/artifact/truefoundry_artifact_repo.py +433 -415
  2. truefoundry/ml/autogen/client/__init__.py +24 -3
  3. truefoundry/ml/autogen/client/api/experiments_api.py +0 -137
  4. truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +2 -0
  5. truefoundry/ml/autogen/client/models/__init__.py +24 -3
  6. truefoundry/ml/autogen/client/models/artifact_dto.py +9 -0
  7. truefoundry/ml/autogen/client/models/artifact_version_dto.py +26 -0
  8. truefoundry/ml/autogen/client/models/artifact_version_serialization_format.py +34 -0
  9. truefoundry/ml/autogen/client/models/create_artifact_version_response_dto.py +8 -2
  10. truefoundry/ml/autogen/client/models/create_run_request_dto.py +1 -10
  11. truefoundry/ml/autogen/client/models/dataset_dto.py +9 -0
  12. truefoundry/ml/autogen/client/models/experiment_dto.py +14 -3
  13. truefoundry/ml/autogen/client/models/external_model_source.py +79 -0
  14. truefoundry/ml/autogen/client/models/finalize_artifact_version_request_dto.py +11 -0
  15. truefoundry/ml/autogen/client/models/framework.py +154 -0
  16. truefoundry/ml/autogen/client/models/library_name.py +35 -0
  17. truefoundry/ml/autogen/client/models/model_dto.py +9 -0
  18. truefoundry/ml/autogen/client/models/model_version_dto.py +26 -0
  19. truefoundry/ml/autogen/client/models/model_version_manifest.py +119 -0
  20. truefoundry/ml/autogen/client/models/run_info_dto.py +10 -1
  21. truefoundry/ml/autogen/client/models/source.py +177 -0
  22. truefoundry/ml/autogen/client/models/subject.py +79 -0
  23. truefoundry/ml/autogen/client/models/subject_type.py +34 -0
  24. truefoundry/ml/autogen/client/models/tensorflow_framework.py +74 -0
  25. truefoundry/ml/autogen/client/models/transformers_framework.py +90 -0
  26. truefoundry/ml/autogen/client/models/truefoundry_model_source.py +79 -0
  27. truefoundry/ml/autogen/client/models/update_model_version_request_dto.py +11 -0
  28. truefoundry/ml/autogen/client/models/upload_model_source.py +74 -0
  29. truefoundry/ml/autogen/client_README.md +12 -2
  30. truefoundry/ml/autogen/entities/artifacts.py +236 -4
  31. truefoundry/ml/log_types/artifacts/artifact.py +10 -6
  32. truefoundry/ml/log_types/artifacts/dataset.py +13 -5
  33. truefoundry/ml/log_types/artifacts/general_artifact.py +3 -1
  34. truefoundry/ml/log_types/artifacts/model.py +18 -30
  35. truefoundry/ml/log_types/artifacts/utils.py +42 -25
  36. truefoundry/ml/log_types/image/image.py +2 -0
  37. truefoundry/ml/log_types/plot.py +2 -0
  38. truefoundry/ml/mlfoundry_api.py +0 -1
  39. {truefoundry-0.4.4rc9.dist-info → truefoundry-0.4.4rc10.dist-info}/METADATA +1 -1
  40. {truefoundry-0.4.4rc9.dist-info → truefoundry-0.4.4rc10.dist-info}/RECORD +42 -31
  41. truefoundry/ml/autogen/client/models/list_seed_experiments_response_dto.py +0 -81
  42. {truefoundry-0.4.4rc9.dist-info → truefoundry-0.4.4rc10.dist-info}/WHEEL +0 -0
  43. {truefoundry-0.4.4rc9.dist-info → truefoundry-0.4.4rc10.dist-info}/entry_points.txt +0 -0
@@ -16,6 +16,7 @@ from typing import (
16
16
  List,
17
17
  NamedTuple,
18
18
  Optional,
19
+ Sequence,
19
20
  Tuple,
20
21
  Union,
21
22
  )
@@ -88,13 +89,6 @@ _GENERATE_SIGNED_URL_BATCH_SIZE = 50
88
89
  DEFAULT_PRESIGNED_URL_EXPIRY_TIME = 3600
89
90
 
90
91
 
91
- def _get_relpath_if_in_tempdir(path: str) -> str:
92
- tempdir = tempfile.gettempdir()
93
- if path.startswith(tempdir):
94
- return os.path.relpath(path, tempdir)
95
- return path
96
-
97
-
98
92
  def _can_display_progress(user_choice: Optional[bool] = None) -> bool:
99
93
  if user_choice is False:
100
94
  return False
@@ -239,7 +233,7 @@ def _signed_url_upload_file(
239
233
  return
240
234
 
241
235
  task_progress_bar = progress_bar.add_task(
242
- f"[green]Uploading {_get_relpath_if_in_tempdir(local_file)}:", start=True
236
+ f"[green]Uploading {local_file}:", start=True
243
237
  )
244
238
 
245
239
  def callback(length):
@@ -340,7 +334,7 @@ def _s3_compatible_multipart_upload(
340
334
  parts = []
341
335
 
342
336
  multi_part_upload_progress = progress_bar.add_task(
343
- f"[green]Uploading {_get_relpath_if_in_tempdir(local_file)}:", start=True
337
+ f"[green]Uploading {local_file}:", start=True
344
338
  )
345
339
 
346
340
  def upload(part_number: int, seek: int) -> None:
@@ -410,7 +404,7 @@ def _azure_multi_part_upload(
410
404
  abort_event = abort_event or Event()
411
405
 
412
406
  multi_part_upload_progress = progress_bar.add_task(
413
- f"[green]Uploading {_get_relpath_if_in_tempdir(local_file)}:", start=True
407
+ f"[green]Uploading {local_file}:", start=True
414
408
  )
415
409
 
416
410
  def upload(part_number: int, seek: int):
@@ -510,164 +504,56 @@ class MlFoundryArtifactsRepository:
510
504
  api_client=self._api_client
511
505
  )
512
506
 
513
- def _create_download_destination(
514
- self, src_artifact_path, dst_local_dir_path=None
515
- ) -> str:
516
- """
517
- Creates a local filesystem location to be used as a destination for downloading the artifact
518
- specified by `src_artifact_path`. The destination location is a subdirectory of the
519
- specified `dst_local_dir_path`, which is determined according to the structure of
520
- `src_artifact_path`. For example, if `src_artifact_path` is `dir1/file1.txt`, then the
521
- resulting destination path is `<dst_local_dir_path>/dir1/file1.txt`. Local directories are
522
- created for the resulting destination location if they do not exist.
523
-
524
- :param src_artifact_path: A relative, POSIX-style path referring to an artifact stored
525
- within the repository's artifact root location.
526
- `src_artifact_path` should be specified relative to the
527
- repository's artifact root location.
528
- :param dst_local_dir_path: The absolute path to a local filesystem directory in which the
529
- local destination path will be contained. The local destination
530
- path may be contained in a subdirectory of `dst_root_dir` if
531
- `src_artifact_path` contains subdirectories.
532
- :return: The absolute path to a local filesystem location to be used as a destination
533
- for downloading the artifact specified by `src_artifact_path`.
534
- """
535
- src_artifact_path = src_artifact_path.rstrip(
536
- "/"
537
- ) # Ensure correct dirname for trailing '/'
538
- dirpath = posixpath.dirname(src_artifact_path)
539
- local_dir_path = os.path.join(dst_local_dir_path, dirpath)
540
- local_file_path = os.path.join(dst_local_dir_path, src_artifact_path)
541
- if not os.path.exists(local_dir_path):
542
- os.makedirs(local_dir_path, exist_ok=True)
543
- return local_file_path
544
-
545
- # these methods should be named list_files, log_directory, log_file, etc
546
- def list_artifacts(
547
- self, path=None, page_size=_LIST_FILES_PAGE_SIZE, **kwargs
548
- ) -> Iterator[FileInfoDto]:
549
- page_token = None
550
- started = False
551
- while not started or page_token is not None:
552
- started = True
553
- page = self.list_files(
554
- artifact_identifier=self.artifact_identifier,
555
- path=path,
556
- page_size=page_size,
557
- page_token=page_token,
507
+ # TODO (chiragjn): Refactor these methods - if else is very inconvenient
508
+ def get_signed_urls_for_read(
509
+ self,
510
+ artifact_identifier: ArtifactIdentifier,
511
+ paths,
512
+ ) -> List[SignedURLDto]:
513
+ if artifact_identifier.artifact_version_id:
514
+ signed_urls_response = self._mlfoundry_artifacts_api.get_signed_urls_for_read_post(
515
+ get_signed_urls_for_artifact_version_read_request_dto=GetSignedURLsForArtifactVersionReadRequestDto(
516
+ id=str(artifact_identifier.artifact_version_id), paths=paths
517
+ )
558
518
  )
559
- for file_info in page.files:
560
- yield file_info
561
- page_token = page.next_page_token
562
-
563
- def log_artifacts( # noqa: C901
564
- self, local_dir, artifact_path=None, progress=None
565
- ):
566
- show_progress = _can_display_progress(progress)
567
-
568
- dest_path = artifact_path or ""
569
- dest_path = dest_path.lstrip(posixpath.sep)
570
-
571
- files_for_normal_upload: List[Tuple[str, str, _FileMultiPartInfo]] = []
572
- files_for_multipart_upload: List[Tuple[str, str, _FileMultiPartInfo]] = []
573
-
574
- for root, _, file_names in os.walk(local_dir):
575
- upload_path = dest_path
576
- if root != local_dir:
577
- rel_path = os.path.relpath(root, local_dir)
578
- rel_path = relative_path_to_artifact_path(rel_path)
579
- upload_path = posixpath.join(dest_path, rel_path)
580
- for file_name in file_names:
581
- local_file = os.path.join(root, file_name)
582
- multipart_info = _decide_file_parts(local_file)
583
-
584
- final_upload_path = upload_path or ""
585
- final_upload_path = final_upload_path.lstrip(posixpath.sep)
586
- final_upload_path = posixpath.join(
587
- final_upload_path, os.path.basename(local_file)
519
+ signed_urls = signed_urls_response.signed_urls
520
+ elif artifact_identifier.dataset_fqn:
521
+ signed_urls_dataset_response = self._mlfoundry_artifacts_api.get_signed_urls_dataset_read_post(
522
+ get_signed_urls_for_dataset_read_request_dto=GetSignedURLsForDatasetReadRequestDto(
523
+ dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
588
524
  )
525
+ )
526
+ signed_urls = signed_urls_dataset_response.signed_urls
527
+ else:
528
+ raise ValueError(
529
+ "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
530
+ )
531
+ return signed_urls
589
532
 
590
- if multipart_info.num_parts == 1:
591
- files_for_normal_upload.append(
592
- (final_upload_path, local_file, multipart_info)
593
- )
594
- else:
595
- files_for_multipart_upload.append(
596
- (final_upload_path, local_file, multipart_info)
597
- )
598
-
599
- abort_event = Event()
600
-
601
- with Progress(
602
- "[progress.description]{task.description}",
603
- BarColumn(bar_width=None),
604
- "[progress.percentage]{task.percentage:>3.0f}%",
605
- DownloadColumn(),
606
- TransferSpeedColumn(),
607
- TimeRemainingColumn(),
608
- TimeElapsedColumn(),
609
- refresh_per_second=1,
610
- disable=not show_progress,
611
- expand=True,
612
- ) as progress_bar, ThreadPoolExecutor(
613
- max_workers=ENV_VARS.TFY_ARTIFACTS_UPLOAD_MAX_WORKERS
614
- ) as executor:
615
- futures: List[Future] = []
616
- # Note: While this batching is beneficial when there is a large number of files, there is also
617
- # a rare case risk of the signed url expiring before a request is made to it
618
- _batch_size = _GENERATE_SIGNED_URL_BATCH_SIZE
619
- for start_idx in range(0, len(files_for_normal_upload), _batch_size):
620
- end_idx = min(start_idx + _batch_size, len(files_for_normal_upload))
621
- if _any_future_has_failed(futures):
622
- break
623
- logger.debug("Generating write signed urls for a batch ...")
624
- remote_file_paths = [
625
- files_for_normal_upload[idx][0] for idx in range(start_idx, end_idx)
626
- ]
627
- signed_urls = self.get_signed_urls_for_write(
628
- artifact_identifier=self.artifact_identifier,
629
- paths=remote_file_paths,
533
+ def get_signed_urls_for_write(
534
+ self,
535
+ artifact_identifier: ArtifactIdentifier,
536
+ paths: List[str],
537
+ ) -> List[SignedURLDto]:
538
+ if artifact_identifier.artifact_version_id:
539
+ signed_urls_response = self._mlfoundry_artifacts_api.get_signed_urls_for_write_post(
540
+ get_signed_urls_for_artifact_version_write_request_dto=GetSignedURLsForArtifactVersionWriteRequestDto(
541
+ id=str(artifact_identifier.artifact_version_id), paths=paths
630
542
  )
631
- for idx, signed_url in zip(range(start_idx, end_idx), signed_urls):
632
- (
633
- upload_path,
634
- local_file,
635
- multipart_info,
636
- ) = files_for_normal_upload[idx]
637
- future = executor.submit(
638
- self._log_artifact,
639
- local_file=local_file,
640
- artifact_path=upload_path,
641
- multipart_info=multipart_info,
642
- signed_url=signed_url,
643
- abort_event=abort_event,
644
- executor_for_multipart_upload=None,
645
- progress_bar=progress_bar,
646
- )
647
- futures.append(future)
648
-
649
- done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
650
- if len(not_done) > 0:
651
- abort_event.set()
652
- for future in not_done:
653
- future.cancel()
654
- for future in done:
655
- if future.exception() is not None:
656
- raise future.exception()
657
-
658
- for (
659
- upload_path,
660
- local_file,
661
- multipart_info,
662
- ) in files_for_multipart_upload:
663
- self._log_artifact(
664
- local_file=local_file,
665
- artifact_path=upload_path,
666
- signed_url=None,
667
- multipart_info=multipart_info,
668
- executor_for_multipart_upload=executor,
669
- progress_bar=progress_bar,
543
+ )
544
+ signed_urls = signed_urls_response.signed_urls
545
+ elif artifact_identifier.dataset_fqn:
546
+ signed_urls_dataset_response = self._mlfoundry_artifacts_api.get_signed_urls_for_dataset_write_post(
547
+ get_signed_url_for_dataset_write_request_dto=GetSignedURLForDatasetWriteRequestDto(
548
+ dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
670
549
  )
550
+ )
551
+ signed_urls = signed_urls_dataset_response.signed_urls
552
+ else:
553
+ raise ValueError(
554
+ "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
555
+ )
556
+ return signed_urls
671
557
 
672
558
  def _normal_upload(
673
559
  self,
@@ -685,7 +571,7 @@ class MlFoundryArtifactsRepository:
685
571
  if progress_bar.disable:
686
572
  logger.info(
687
573
  "Uploading %s to %s",
688
- _get_relpath_if_in_tempdir(local_file),
574
+ local_file,
689
575
  artifact_path,
690
576
  )
691
577
 
@@ -697,6 +583,36 @@ class MlFoundryArtifactsRepository:
697
583
  )
698
584
  logger.debug("Uploaded %s to %s", local_file, artifact_path)
699
585
 
586
+ def _create_multipart_upload_for_identifier(
587
+ self,
588
+ artifact_identifier: ArtifactIdentifier,
589
+ path,
590
+ num_parts,
591
+ ) -> MultiPartUploadDto:
592
+ if artifact_identifier.artifact_version_id:
593
+ create_multipart_response: MultiPartUploadResponseDto = self._mlfoundry_artifacts_api.create_multi_part_upload_post(
594
+ create_multi_part_upload_request_dto=CreateMultiPartUploadRequestDto(
595
+ artifact_version_id=str(artifact_identifier.artifact_version_id),
596
+ path=path,
597
+ num_parts=num_parts,
598
+ )
599
+ )
600
+ multipart_upload = create_multipart_response.multipart_upload
601
+ elif artifact_identifier.dataset_fqn:
602
+ create_multipart_for_dataset_response = self._mlfoundry_artifacts_api.create_multipart_upload_for_dataset_post(
603
+ create_multi_part_upload_for_dataset_request_dto=CreateMultiPartUploadForDatasetRequestDto(
604
+ dataset_fqn=artifact_identifier.dataset_fqn,
605
+ path=path,
606
+ num_parts=num_parts,
607
+ )
608
+ )
609
+ multipart_upload = create_multipart_for_dataset_response.multipart_upload
610
+ else:
611
+ raise ValueError(
612
+ "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
613
+ )
614
+ return multipart_upload
615
+
700
616
  def _multipart_upload(
701
617
  self,
702
618
  local_file: str,
@@ -709,11 +625,11 @@ class MlFoundryArtifactsRepository:
709
625
  if progress_bar.disable:
710
626
  logger.info(
711
627
  "Uploading %s to %s using multipart upload",
712
- _get_relpath_if_in_tempdir(local_file),
628
+ local_file,
713
629
  artifact_path,
714
630
  )
715
631
 
716
- multipart_upload = self.create_multipart_upload_for_identifier(
632
+ multipart_upload = self._create_multipart_upload_for_identifier(
717
633
  artifact_identifier=self.artifact_identifier,
718
634
  path=artifact_path,
719
635
  num_parts=multipart_info.num_parts,
@@ -745,7 +661,7 @@ class MlFoundryArtifactsRepository:
745
661
  else:
746
662
  raise NotImplementedError()
747
663
 
748
- def _log_artifact(
664
+ def _upload_file(
749
665
  self,
750
666
  local_file: str,
751
667
  artifact_path: str,
@@ -784,10 +700,14 @@ class MlFoundryArtifactsRepository:
784
700
  progress_bar=progress_bar,
785
701
  )
786
702
 
787
- def log_artifact(self, local_file: str, artifact_path: Optional[str] = None):
788
- upload_path = artifact_path or ""
789
- upload_path = upload_path.lstrip(posixpath.sep)
790
- upload_path = posixpath.join(upload_path, os.path.basename(local_file))
703
+ def _upload(
704
+ self,
705
+ files_for_normal_upload: Sequence[Tuple[str, str, _FileMultiPartInfo]],
706
+ files_for_multipart_upload: Sequence[Tuple[str, str, _FileMultiPartInfo]],
707
+ progress: Optional[bool] = None,
708
+ ):
709
+ abort_event = Event()
710
+ show_progress = _can_display_progress(progress)
791
711
  with Progress(
792
712
  "[progress.description]{task.description}",
793
713
  BarColumn(bar_width=None),
@@ -797,49 +717,349 @@ class MlFoundryArtifactsRepository:
797
717
  TimeRemainingColumn(),
798
718
  TimeElapsedColumn(),
799
719
  refresh_per_second=1,
800
- disable=True,
720
+ disable=not show_progress,
801
721
  expand=True,
802
- ) as progress_bar:
803
- self._log_artifact(
804
- local_file=local_file,
805
- artifact_path=upload_path,
806
- multipart_info=_decide_file_parts(local_file),
807
- progress_bar=progress_bar,
808
- )
809
-
810
- def _is_directory(self, artifact_path):
811
- for _ in self.list_artifacts(artifact_path, page_size=3):
812
- return True
813
- return False
814
-
815
- def download_artifacts( # noqa: C901
816
- self,
817
- artifact_path: str,
818
- dst_path: Optional[str] = None,
819
- overwrite: bool = False,
820
- progress: Optional[bool] = None,
821
- ) -> str:
822
- """
823
- Download an artifact file or directory to a local directory if applicable, and return a
824
- local path for it. The caller is responsible for managing the lifecycle of the downloaded artifacts.
825
-
826
- Args:
827
- artifact_path: Relative source path to the desired artifacts.
828
- dst_path: Absolute path of the local filesystem destination directory to which to
829
- download the specified artifacts. This directory must already exist.
830
- If unspecified, the artifacts will either be downloaded to a new
831
- uniquely-named directory.
832
- overwrite: if to overwrite the files at/inside `dst_path` if they exist
833
- progress: Show or hide progress bar
834
-
835
- Returns:
836
- str: Absolute path of the local filesystem location containing the desired artifacts.
837
- """
838
-
839
- show_progress = _can_display_progress()
840
-
841
- is_dir_temp = False
842
- if dst_path is None:
722
+ ) as progress_bar, ThreadPoolExecutor(
723
+ max_workers=ENV_VARS.TFY_ARTIFACTS_UPLOAD_MAX_WORKERS
724
+ ) as executor:
725
+ futures: List[Future] = []
726
+ # Note: While this batching is beneficial when there is a large number of files, there is also
727
+ # a rare case risk of the signed url expiring before a request is made to it
728
+ _batch_size = _GENERATE_SIGNED_URL_BATCH_SIZE
729
+ for start_idx in range(0, len(files_for_normal_upload), _batch_size):
730
+ end_idx = min(start_idx + _batch_size, len(files_for_normal_upload))
731
+ if _any_future_has_failed(futures):
732
+ break
733
+ logger.debug("Generating write signed urls for a batch ...")
734
+ remote_file_paths = [
735
+ files_for_normal_upload[idx][0] for idx in range(start_idx, end_idx)
736
+ ]
737
+ signed_urls = self.get_signed_urls_for_write(
738
+ artifact_identifier=self.artifact_identifier,
739
+ paths=remote_file_paths,
740
+ )
741
+ for idx, signed_url in zip(range(start_idx, end_idx), signed_urls):
742
+ (
743
+ upload_path,
744
+ local_file,
745
+ multipart_info,
746
+ ) = files_for_normal_upload[idx]
747
+ future = executor.submit(
748
+ self._upload_file,
749
+ local_file=local_file,
750
+ artifact_path=upload_path,
751
+ multipart_info=multipart_info,
752
+ signed_url=signed_url,
753
+ abort_event=abort_event,
754
+ executor_for_multipart_upload=None,
755
+ progress_bar=progress_bar,
756
+ )
757
+ futures.append(future)
758
+
759
+ done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
760
+ if len(not_done) > 0:
761
+ abort_event.set()
762
+ for future in not_done:
763
+ future.cancel()
764
+ for future in done:
765
+ if future.exception() is not None:
766
+ raise future.exception()
767
+
768
+ for (
769
+ upload_path,
770
+ local_file,
771
+ multipart_info,
772
+ ) in files_for_multipart_upload:
773
+ self._upload_file(
774
+ local_file=local_file,
775
+ artifact_path=upload_path,
776
+ signed_url=None,
777
+ multipart_info=multipart_info,
778
+ executor_for_multipart_upload=executor,
779
+ progress_bar=progress_bar,
780
+ )
781
+
782
+ def _add_file_for_upload(
783
+ self,
784
+ local_file: str,
785
+ artifact_path: str,
786
+ files_for_normal_upload: List[Tuple[str, str, _FileMultiPartInfo]],
787
+ files_for_multipart_upload: List[Tuple[str, str, _FileMultiPartInfo]],
788
+ ):
789
+ local_file = os.path.realpath(local_file)
790
+ if os.path.isdir(local_file):
791
+ raise MlFoundryException(
792
+ "Cannot log a directory as an artifact. Use `log_artifacts` instead"
793
+ )
794
+ upload_path = artifact_path
795
+ upload_path = upload_path.lstrip(posixpath.sep)
796
+ multipart_info = _decide_file_parts(local_file)
797
+ if multipart_info.num_parts == 1:
798
+ files_for_normal_upload.append((upload_path, local_file, multipart_info))
799
+ else:
800
+ files_for_multipart_upload.append((upload_path, local_file, multipart_info))
801
+
802
+ def _add_dir_for_upload(
803
+ self,
804
+ local_dir: str,
805
+ artifact_path: str,
806
+ files_for_normal_upload: List[Tuple[str, str, _FileMultiPartInfo]],
807
+ files_for_multipart_upload: List[Tuple[str, str, _FileMultiPartInfo]],
808
+ ):
809
+ local_dir = os.path.realpath(local_dir)
810
+ if not os.path.isdir(local_dir):
811
+ raise MlFoundryException(
812
+ "Cannot log a file as a directory. Use `log_artifact` instead"
813
+ )
814
+ dest_path = artifact_path
815
+ dest_path = dest_path.lstrip(posixpath.sep)
816
+
817
+ for root, _, file_names in os.walk(local_dir):
818
+ upload_path = dest_path
819
+ if root != local_dir:
820
+ rel_path = os.path.relpath(root, local_dir)
821
+ rel_path = relative_path_to_artifact_path(rel_path)
822
+ upload_path = posixpath.join(dest_path, rel_path)
823
+ for file_name in file_names:
824
+ local_file = os.path.join(root, file_name)
825
+ self._add_file_for_upload(
826
+ local_file=local_file,
827
+ artifact_path=upload_path,
828
+ files_for_normal_upload=files_for_normal_upload,
829
+ files_for_multipart_upload=files_for_multipart_upload,
830
+ )
831
+
832
+ def log_artifacts(
833
+ self,
834
+ src_dest_pairs: Sequence[Tuple[str, Optional[str]]],
835
+ progress: Optional[bool] = None,
836
+ ):
837
+ files_for_normal_upload: List[Tuple[str, str, _FileMultiPartInfo]] = []
838
+ files_for_multipart_upload: List[Tuple[str, str, _FileMultiPartInfo]] = []
839
+ for src, dest in src_dest_pairs:
840
+ if os.path.isdir(src):
841
+ self._add_dir_for_upload(
842
+ local_dir=src,
843
+ artifact_path=dest or "",
844
+ files_for_normal_upload=files_for_normal_upload,
845
+ files_for_multipart_upload=files_for_multipart_upload,
846
+ )
847
+ else:
848
+ self._add_file_for_upload(
849
+ local_file=src,
850
+ artifact_path=dest or "",
851
+ files_for_normal_upload=files_for_normal_upload,
852
+ files_for_multipart_upload=files_for_multipart_upload,
853
+ )
854
+ self._upload(
855
+ files_for_normal_upload=files_for_normal_upload,
856
+ files_for_multipart_upload=files_for_multipart_upload,
857
+ progress=progress,
858
+ )
859
+
860
+ def _list_files(
861
+ self, artifact_identifier: ArtifactIdentifier, path, page_size, page_token
862
+ ) -> Union[ListFilesForDatasetResponseDto, ListFilesForArtifactVersionsResponseDto]:
863
+ if artifact_identifier.dataset_fqn:
864
+ return self._mlfoundry_artifacts_api.list_files_for_dataset_post(
865
+ list_files_for_dataset_request_dto=ListFilesForDatasetRequestDto(
866
+ dataset_fqn=artifact_identifier.dataset_fqn,
867
+ path=path,
868
+ max_results=page_size,
869
+ page_token=page_token,
870
+ )
871
+ )
872
+ else:
873
+ return self._mlfoundry_artifacts_api.list_files_for_artifact_version_post(
874
+ list_files_for_artifact_version_request_dto=ListFilesForArtifactVersionRequestDto(
875
+ id=str(artifact_identifier.artifact_version_id),
876
+ path=path,
877
+ max_results=page_size,
878
+ page_token=page_token,
879
+ )
880
+ )
881
+
882
+ def list_artifacts(
883
+ self, path=None, page_size=_LIST_FILES_PAGE_SIZE, **kwargs
884
+ ) -> Iterator[FileInfoDto]:
885
+ page_token = None
886
+ started = False
887
+ while not started or page_token is not None:
888
+ started = True
889
+ page = self._list_files(
890
+ artifact_identifier=self.artifact_identifier,
891
+ path=path,
892
+ page_size=page_size,
893
+ page_token=page_token,
894
+ )
895
+ for file_info in page.files:
896
+ yield file_info
897
+ page_token = page.next_page_token
898
+
899
+ def _is_directory(self, artifact_path):
900
+ for _ in self.list_artifacts(artifact_path, page_size=3):
901
+ return True
902
+ return False
903
+
904
+ def _create_download_destination(
905
+ self, src_artifact_path: str, dst_local_dir_path: str
906
+ ) -> str:
907
+ """
908
+ Creates a local filesystem location to be used as a destination for downloading the artifact
909
+ specified by `src_artifact_path`. The destination location is a subdirectory of the
910
+ specified `dst_local_dir_path`, which is determined according to the structure of
911
+ `src_artifact_path`. For example, if `src_artifact_path` is `dir1/file1.txt`, then the
912
+ resulting destination path is `<dst_local_dir_path>/dir1/file1.txt`. Local directories are
913
+ created for the resulting destination location if they do not exist.
914
+
915
+ :param src_artifact_path: A relative, POSIX-style path referring to an artifact stored
916
+ within the repository's artifact root location.
917
+ `src_artifact_path` should be specified relative to the
918
+ repository's artifact root location.
919
+ :param dst_local_dir_path: The absolute path to a local filesystem directory in which the
920
+ local destination path will be contained. The local destination
921
+ path may be contained in a subdirectory of `dst_root_dir` if
922
+ `src_artifact_path` contains subdirectories.
923
+ :return: The absolute path to a local filesystem location to be used as a destination
924
+ for downloading the artifact specified by `src_artifact_path`.
925
+ """
926
+ src_artifact_path = src_artifact_path.rstrip(
927
+ "/"
928
+ ) # Ensure correct dirname for trailing '/'
929
+ dirpath = posixpath.dirname(src_artifact_path)
930
+ local_dir_path = os.path.join(dst_local_dir_path, dirpath)
931
+ local_file_path = os.path.join(dst_local_dir_path, src_artifact_path)
932
+ if not os.path.exists(local_dir_path):
933
+ os.makedirs(local_dir_path, exist_ok=True)
934
+ return local_file_path
935
+
936
+ # noinspection PyMethodOverriding
937
+ def _download_file(
938
+ self,
939
+ remote_file_path: str,
940
+ local_path: str,
941
+ progress_bar: Optional[Progress],
942
+ signed_url: Optional[SignedURLDto],
943
+ abort_event: Optional[Event] = None,
944
+ ):
945
+ if not remote_file_path:
946
+ raise MlFoundryException(
947
+ f"remote_file_path cannot be None or empty str {remote_file_path}"
948
+ )
949
+ if not signed_url:
950
+ signed_url = self.get_signed_urls_for_read(
951
+ artifact_identifier=self.artifact_identifier, paths=[remote_file_path]
952
+ )[0]
953
+
954
+ if progress_bar is None or not progress_bar.disable:
955
+ logger.info("Downloading %s to %s", remote_file_path, local_path)
956
+
957
+ if progress_bar is not None:
958
+ download_progress_bar = progress_bar.add_task(
959
+ f"[green]Downloading to {remote_file_path}:", start=True
960
+ )
961
+
962
+ def callback(chunk, total_file_size):
963
+ if progress_bar is not None:
964
+ progress_bar.update(
965
+ download_progress_bar,
966
+ advance=chunk,
967
+ total=total_file_size,
968
+ )
969
+ if abort_event and abort_event.is_set():
970
+ raise Exception("aborting download")
971
+
972
+ _download_file_using_http_uri(
973
+ http_uri=signed_url.signed_url,
974
+ download_path=local_path,
975
+ callback=callback,
976
+ )
977
+ logger.debug("Downloaded %s to %s", remote_file_path, local_path)
978
+
979
+ def _download_artifact(
980
+ self,
981
+ src_artifact_path: str,
982
+ dst_local_dir_path: str,
983
+ signed_url: Optional[SignedURLDto],
984
+ progress_bar: Optional[Progress] = None,
985
+ abort_event=None,
986
+ ) -> str:
987
+ """
988
+ Download the file artifact specified by `src_artifact_path` to the local filesystem
989
+ directory specified by `dst_local_dir_path`.
990
+ :param src_artifact_path: A relative, POSIX-style path referring to a file artifact
991
+ stored within the repository's artifact root location.
992
+ `src_artifact_path` should be specified relative to the
993
+ repository's artifact root location.
994
+ :param dst_local_dir_path: Absolute path of the local filesystem destination directory
995
+ to which to download the specified artifact. The downloaded
996
+ artifact may be written to a subdirectory of
997
+ `dst_local_dir_path` if `src_artifact_path` contains
998
+ subdirectories.
999
+ :param progress_bar: An instance of a Rich progress bar used to visually display the
1000
+ progress of the file download.
1001
+ :return: A local filesystem path referring to the downloaded file.
1002
+ """
1003
+ local_destination_file_path = self._create_download_destination(
1004
+ src_artifact_path=src_artifact_path, dst_local_dir_path=dst_local_dir_path
1005
+ )
1006
+ self._download_file(
1007
+ remote_file_path=src_artifact_path,
1008
+ local_path=local_destination_file_path,
1009
+ signed_url=signed_url,
1010
+ abort_event=abort_event,
1011
+ progress_bar=progress_bar,
1012
+ )
1013
+ return local_destination_file_path
1014
+
1015
+ def _get_file_paths_recur(self, src_artifact_dir_path, dst_local_dir_path):
1016
+ local_dir = os.path.join(dst_local_dir_path, src_artifact_dir_path)
1017
+ dir_content = [ # prevent infinite loop, sometimes the dir is recursively included
1018
+ file_info
1019
+ for file_info in self.list_artifacts(src_artifact_dir_path)
1020
+ if file_info.path != "." and file_info.path != src_artifact_dir_path
1021
+ ]
1022
+ if not dir_content: # empty dir
1023
+ if not os.path.exists(local_dir):
1024
+ os.makedirs(local_dir, exist_ok=True)
1025
+ else:
1026
+ for file_info in dir_content:
1027
+ if file_info.is_dir:
1028
+ yield from self._get_file_paths_recur(
1029
+ src_artifact_dir_path=file_info.path,
1030
+ dst_local_dir_path=dst_local_dir_path,
1031
+ )
1032
+ else:
1033
+ yield file_info.path, dst_local_dir_path
1034
+
1035
+ def download_artifacts( # noqa: C901
1036
+ self,
1037
+ artifact_path: str,
1038
+ dst_path: Optional[str] = None,
1039
+ overwrite: bool = False,
1040
+ progress: Optional[bool] = None,
1041
+ ) -> str:
1042
+ """
1043
+ Download an artifact file or directory to a local directory if applicable, and return a
1044
+ local path for it. The caller is responsible for managing the lifecycle of the downloaded artifacts.
1045
+
1046
+ Args:
1047
+ artifact_path: Relative source path to the desired artifacts.
1048
+ dst_path: Absolute path of the local filesystem destination directory to which to
1049
+ download the specified artifacts. This directory must already exist.
1050
+ If unspecified, the artifacts will either be downloaded to a new
1051
+ uniquely-named directory.
1052
+ overwrite: if to overwrite the files at/inside `dst_path` if they exist
1053
+ progress: Show or hide progress bar
1054
+
1055
+ Returns:
1056
+ str: Absolute path of the local filesystem location containing the desired artifacts.
1057
+ """
1058
+
1059
+ show_progress = _can_display_progress(user_choice=progress)
1060
+
1061
+ is_dir_temp = False
1062
+ if dst_path is None:
843
1063
  dst_path = tempfile.mkdtemp()
844
1064
  is_dir_temp = True
845
1065
 
@@ -963,205 +1183,3 @@ class MlFoundryArtifactsRepository:
963
1183
 
964
1184
  finally:
965
1185
  progress_bar.stop()
966
-
967
- # noinspection PyMethodOverriding
968
def _download_file(
    self,
    remote_file_path: str,
    local_path: str,
    progress_bar: Optional[Progress],
    signed_url: Optional[SignedURLDto],
    abort_event: Optional[Event] = None,
):
    """
    Download one remote file to `local_path` through a signed HTTP URL.

    Args:
        remote_file_path: Path of the file to fetch; must be a non-empty string.
        local_path: Destination path handed to the HTTP download helper.
        progress_bar: Optional progress display. When given, a task is added
            for this file and updated from the download callback.
        signed_url: Signed URL to read from. When falsy, one is requested for
            `remote_file_path` via `get_signed_urls_for_read`.
        abort_event: Optional event used to cancel the download. It is checked
            once before any network work and again on every callback invocation.

    Raises:
        MlFoundryException: If `remote_file_path` is None or empty.
        Exception: With message "aborting download" when `abort_event` is set.
    """
    if not remote_file_path:
        raise MlFoundryException(
            f"remote_file_path cannot be None or empty str {remote_file_path}"
        )

    # Honour a cancellation that happened before this download started.
    # Without this check the signed URL would still be requested and the
    # transfer opened, and the abort would only be noticed if and when the
    # callback fires.
    if abort_event and abort_event.is_set():
        raise Exception("aborting download")

    if not signed_url:
        signed_url = self.get_signed_urls_for_read(
            artifact_identifier=self.artifact_identifier, paths=[remote_file_path]
        )[0]

    # Logged when no progress bar is supplied, or when the supplied one does
    # not have its `disable` attribute set.
    if progress_bar is None or not progress_bar.disable:
        logger.info("Downloading %s to %s", remote_file_path, local_path)

    if progress_bar is not None:
        download_progress_bar = progress_bar.add_task(
            f"[green]Downloading to {remote_file_path}:", start=True
        )

    def callback(chunk, total_file_size):
        # `chunk` is forwarded as the progress advance amount (presumably a
        # byte count -- confirm against `_download_file_using_http_uri`).
        if progress_bar is not None:
            progress_bar.update(
                download_progress_bar,
                advance=chunk,
                total=total_file_size,
            )
        if abort_event and abort_event.is_set():
            raise Exception("aborting download")

    _download_file_using_http_uri(
        http_uri=signed_url.signed_url,
        download_path=local_path,
        callback=callback,
    )
    logger.debug("Downloaded %s to %s", remote_file_path, local_path)
1009
-
1010
def _download_artifact(
    self,
    src_artifact_path,
    dst_local_dir_path,
    signed_url: Optional[SignedURLDto],
    progress_bar: Optional[Progress] = None,
    abort_event=None,
) -> str:
    """
    Download a single file artifact into a local destination directory.

    The local file path is obtained from `_create_download_destination` and
    the transfer itself is delegated to `_download_file`.

    Args:
        src_artifact_path: Relative, POSIX-style path of the file artifact,
            given relative to the repository's artifact root location.
        dst_local_dir_path: Absolute path of the local directory to download into.
        signed_url: Signed URL for the file, forwarded unchanged to `_download_file`.
        progress_bar: Optional progress display, forwarded to `_download_file`.
        abort_event: Optional cancellation event, forwarded to `_download_file`.

    Returns:
        str: Local filesystem path the file was downloaded to.
    """
    target_path = self._create_download_destination(
        src_artifact_path=src_artifact_path,
        dst_local_dir_path=dst_local_dir_path,
    )
    download_kwargs = {
        "remote_file_path": src_artifact_path,
        "local_path": target_path,
        "signed_url": signed_url,
        "abort_event": abort_event,
        "progress_bar": progress_bar,
    }
    self._download_file(**download_kwargs)
    return target_path
1045
-
1046
def _get_file_paths_recur(self, src_artifact_dir_path, dst_local_dir_path):
    """
    Walk a remote artifact directory and yield one pair per file found.

    Each yielded item is `(remote_file_path, dst_local_dir_path)`. Nested
    directories are walked recursively. A directory with no entries yields
    nothing; instead the corresponding local directory is created under
    `dst_local_dir_path` if it does not exist yet.

    Args:
        src_artifact_dir_path: Remote directory path to list via `list_artifacts`.
        dst_local_dir_path: Local destination root, passed through unchanged
            in every yielded pair and to every recursive call.
    """
    empty_dir_target = os.path.join(dst_local_dir_path, src_artifact_dir_path)

    # The listing sometimes contains the directory itself (or "."); dropping
    # those entries prevents infinite recursion.
    self_references = (".", src_artifact_dir_path)
    entries = []
    for entry in self.list_artifacts(src_artifact_dir_path):
        if entry.path not in self_references:
            entries.append(entry)

    if not entries:
        # Empty remote directory: mirror it locally so it is not lost.
        if not os.path.exists(empty_dir_target):
            os.makedirs(empty_dir_target, exist_ok=True)
        return

    for entry in entries:
        if not entry.is_dir:
            yield entry.path, dst_local_dir_path
            continue
        yield from self._get_file_paths_recur(
            src_artifact_dir_path=entry.path,
            dst_local_dir_path=dst_local_dir_path,
        )
1065
-
1066
- # TODO (chiragjn): Refactor these methods - if else is very inconvenient
1067
def get_signed_urls_for_read(
    self,
    artifact_identifier: ArtifactIdentifier,
    paths,
) -> List[SignedURLDto]:
    """
    Request signed read URLs for `paths` of an artifact version or a dataset.

    `artifact_version_id` takes precedence over `dataset_fqn` when both are set.

    Args:
        artifact_identifier: Identifies the target by `artifact_version_id`
            or by `dataset_fqn`.
        paths: Paths to request signed URLs for.

    Returns:
        List[SignedURLDto]: The `signed_urls` field of the API response.

    Raises:
        ValueError: If neither `artifact_version_id` nor `dataset_fqn` is set.
    """
    if artifact_identifier.artifact_version_id:
        fetch_for_version = self._mlfoundry_artifacts_api.get_signed_urls_for_read_post
        version_request = GetSignedURLsForArtifactVersionReadRequestDto(
            id=str(artifact_identifier.artifact_version_id), paths=paths
        )
        return fetch_for_version(
            get_signed_urls_for_artifact_version_read_request_dto=version_request
        ).signed_urls

    if artifact_identifier.dataset_fqn:
        fetch_for_dataset = (
            self._mlfoundry_artifacts_api.get_signed_urls_dataset_read_post
        )
        dataset_request = GetSignedURLsForDatasetReadRequestDto(
            dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
        )
        return fetch_for_dataset(
            get_signed_urls_for_dataset_read_request_dto=dataset_request
        ).signed_urls

    raise ValueError(
        "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
    )
1091
-
1092
def get_signed_urls_for_write(
    self,
    artifact_identifier: ArtifactIdentifier,
    paths: List[str],
) -> List[SignedURLDto]:
    """
    Request signed write URLs for `paths` of an artifact version or a dataset.

    `artifact_version_id` takes precedence over `dataset_fqn` when both are set.

    Args:
        artifact_identifier: Identifies the target by `artifact_version_id`
            or by `dataset_fqn`.
        paths: Paths to request signed URLs for.

    Returns:
        List[SignedURLDto]: The `signed_urls` field of the API response.

    Raises:
        ValueError: If neither `artifact_version_id` nor `dataset_fqn` is set.
    """
    if artifact_identifier.artifact_version_id:
        fetch_for_version = self._mlfoundry_artifacts_api.get_signed_urls_for_write_post
        version_request = GetSignedURLsForArtifactVersionWriteRequestDto(
            id=str(artifact_identifier.artifact_version_id), paths=paths
        )
        return fetch_for_version(
            get_signed_urls_for_artifact_version_write_request_dto=version_request
        ).signed_urls

    if artifact_identifier.dataset_fqn:
        fetch_for_dataset = (
            self._mlfoundry_artifacts_api.get_signed_urls_for_dataset_write_post
        )
        dataset_request = GetSignedURLForDatasetWriteRequestDto(
            dataset_fqn=artifact_identifier.dataset_fqn, paths=paths
        )
        return fetch_for_dataset(
            get_signed_url_for_dataset_write_request_dto=dataset_request
        ).signed_urls

    raise ValueError(
        "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
    )
1116
-
1117
def create_multipart_upload_for_identifier(
    self,
    artifact_identifier: ArtifactIdentifier,
    path,
    num_parts,
) -> MultiPartUploadDto:
    """
    Start a multipart upload for `path` on an artifact version or a dataset.

    `artifact_version_id` takes precedence over `dataset_fqn` when both are set.

    Args:
        artifact_identifier: Identifies the target by `artifact_version_id`
            or by `dataset_fqn`.
        path: Path of the file to upload.
        num_parts: Number of parts, forwarded as `num_parts` in the request.

    Returns:
        MultiPartUploadDto: The `multipart_upload` field of the API response.

    Raises:
        ValueError: If neither `artifact_version_id` nor `dataset_fqn` is set.
    """
    if artifact_identifier.artifact_version_id:
        create_for_version = self._mlfoundry_artifacts_api.create_multi_part_upload_post
        version_request = CreateMultiPartUploadRequestDto(
            artifact_version_id=str(artifact_identifier.artifact_version_id),
            path=path,
            num_parts=num_parts,
        )
        return create_for_version(
            create_multi_part_upload_request_dto=version_request
        ).multipart_upload

    if artifact_identifier.dataset_fqn:
        create_for_dataset = (
            self._mlfoundry_artifacts_api.create_multipart_upload_for_dataset_post
        )
        dataset_request = CreateMultiPartUploadForDatasetRequestDto(
            dataset_fqn=artifact_identifier.dataset_fqn,
            path=path,
            num_parts=num_parts,
        )
        return create_for_dataset(
            create_multi_part_upload_for_dataset_request_dto=dataset_request
        ).multipart_upload

    raise ValueError(
        "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
    )
1146
-
1147
def list_files(
    self, artifact_identifier: ArtifactIdentifier, path, page_size, page_token
) -> Union[ListFilesForDatasetResponseDto, ListFilesForArtifactVersionsResponseDto]:
    """
    List one page of files under `path` for a dataset or an artifact version.

    `dataset_fqn` takes precedence over `artifact_version_id` when both are set.

    Args:
        artifact_identifier: Identifies the target by `dataset_fqn` or by
            `artifact_version_id`.
        path: Path to list files under.
        page_size: Forwarded as `max_results` in the request.
        page_token: Forwarded as `page_token` in the request.

    Returns:
        The API response for the dataset or artifact-version listing call.

    Raises:
        ValueError: If neither `dataset_fqn` nor `artifact_version_id` is set.
    """
    if artifact_identifier.dataset_fqn:
        return self._mlfoundry_artifacts_api.list_files_for_dataset_post(
            list_files_for_dataset_request_dto=ListFilesForDatasetRequestDto(
                dataset_fqn=artifact_identifier.dataset_fqn,
                path=path,
                max_results=page_size,
                page_token=page_token,
            )
        )
    if artifact_identifier.artifact_version_id:
        return self._mlfoundry_artifacts_api.list_files_for_artifact_version_post(
            list_files_for_artifact_version_request_dto=ListFilesForArtifactVersionRequestDto(
                id=str(artifact_identifier.artifact_version_id),
                path=path,
                max_results=page_size,
                page_token=page_token,
            )
        )
    # Without this guard a missing identifier would be sent to the server as
    # the literal id "None". Fail fast instead, as the other methods that
    # dispatch on the identifier do.
    raise ValueError(
        "Invalid artifact type - both `artifact_version_id` and `dataset_fqn` both are None"
    )