databricks-sdk 0.67.0__py3-none-any.whl → 0.69.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of databricks-sdk might be problematic.

Files changed (49)
  1. databricks/sdk/__init__.py +14 -10
  2. databricks/sdk/_base_client.py +4 -1
  3. databricks/sdk/common/lro.py +17 -0
  4. databricks/sdk/common/types/__init__.py +0 -0
  5. databricks/sdk/common/types/fieldmask.py +39 -0
  6. databricks/sdk/config.py +62 -14
  7. databricks/sdk/credentials_provider.py +61 -12
  8. databricks/sdk/dbutils.py +5 -1
  9. databricks/sdk/errors/parser.py +8 -3
  10. databricks/sdk/mixins/files.py +1156 -111
  11. databricks/sdk/mixins/files_utils.py +293 -0
  12. databricks/sdk/oidc_token_supplier.py +80 -0
  13. databricks/sdk/retries.py +102 -2
  14. databricks/sdk/service/_internal.py +93 -1
  15. databricks/sdk/service/agentbricks.py +1 -1
  16. databricks/sdk/service/apps.py +264 -1
  17. databricks/sdk/service/billing.py +2 -3
  18. databricks/sdk/service/catalog.py +1026 -540
  19. databricks/sdk/service/cleanrooms.py +3 -3
  20. databricks/sdk/service/compute.py +21 -33
  21. databricks/sdk/service/dashboards.py +7 -3
  22. databricks/sdk/service/database.py +3 -2
  23. databricks/sdk/service/dataquality.py +1145 -0
  24. databricks/sdk/service/files.py +2 -1
  25. databricks/sdk/service/iam.py +2 -1
  26. databricks/sdk/service/iamv2.py +1 -1
  27. databricks/sdk/service/jobs.py +6 -9
  28. databricks/sdk/service/marketplace.py +3 -1
  29. databricks/sdk/service/ml.py +3 -1
  30. databricks/sdk/service/oauth2.py +1 -1
  31. databricks/sdk/service/pipelines.py +5 -6
  32. databricks/sdk/service/provisioning.py +544 -655
  33. databricks/sdk/service/qualitymonitorv2.py +1 -1
  34. databricks/sdk/service/serving.py +3 -1
  35. databricks/sdk/service/settings.py +5 -2
  36. databricks/sdk/service/settingsv2.py +1 -1
  37. databricks/sdk/service/sharing.py +12 -3
  38. databricks/sdk/service/sql.py +305 -70
  39. databricks/sdk/service/tags.py +1 -1
  40. databricks/sdk/service/vectorsearch.py +3 -1
  41. databricks/sdk/service/workspace.py +70 -17
  42. databricks/sdk/version.py +1 -1
  43. {databricks_sdk-0.67.0.dist-info → databricks_sdk-0.69.0.dist-info}/METADATA +4 -2
  44. databricks_sdk-0.69.0.dist-info/RECORD +84 -0
  45. databricks_sdk-0.67.0.dist-info/RECORD +0 -79
  46. {databricks_sdk-0.67.0.dist-info → databricks_sdk-0.69.0.dist-info}/WHEEL +0 -0
  47. {databricks_sdk-0.67.0.dist-info → databricks_sdk-0.69.0.dist-info}/licenses/LICENSE +0 -0
  48. {databricks_sdk-0.67.0.dist-info → databricks_sdk-0.69.0.dist-info}/licenses/NOTICE +0 -0
  49. {databricks_sdk-0.67.0.dist-info → databricks_sdk-0.69.0.dist-info}/top_level.txt +0 -0
@@ -3,6 +3,7 @@ from __future__ import annotations
  import base64
  import datetime
  import logging
+ import math
  import os
  import pathlib
  import platform
@@ -13,8 +14,13 @@ import xml.etree.ElementTree as ET
  from abc import ABC, abstractmethod
  from collections import deque
  from collections.abc import Iterator
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from dataclasses import dataclass
  from datetime import timedelta
  from io import BytesIO
+ from queue import Empty, Full, Queue
+ from tempfile import mkstemp
+ from threading import Event, Thread
  from types import TracebackType
  from typing import (TYPE_CHECKING, AnyStr, BinaryIO, Callable, Generator,
                      Iterable, Optional, Type, Union)
@@ -27,12 +33,14 @@ from requests import RequestException
  from .._base_client import _BaseClient, _RawResponse, _StreamingResponse
  from .._property import _cached_property
  from ..config import Config
- from ..errors import AlreadyExists, NotFound
+ from ..errors import AlreadyExists, NotFound, PermissionDenied
  from ..errors.mapper import _error_mapper
  from ..retries import retried
  from ..service import files
  from ..service._internal import _escape_multi_segment_path_parameter
  from ..service.files import DownloadResponse
+ from .files_utils import (CreateDownloadUrlResponse, _ConcatenatedInputStream,
+                           _PresignedUrlDistributor)
 
  if TYPE_CHECKING:
      from _typeshed import Self
@@ -710,18 +718,70 @@ class DbfsExt(files.DbfsAPI):
          p.delete(recursive=recursive)
 
 
+ class FallbackToUploadUsingFilesApi(Exception):
+     """Custom exception that signals to fallback to FilesAPI for upload"""
+
+     def __init__(self, buffer, message):
+         super().__init__(message)
+         self.buffer = buffer
+
+
+ class FallbackToDownloadUsingFilesApi(Exception):
+     """Custom exception that signals to fallback to FilesAPI for download"""
+
+     def __init__(self, message):
+         super().__init__(message)
+
+
+ @dataclass
+ class UploadStreamResult:
+     """Result of an upload from stream operation. Currently empty, but can be extended in the future."""
+
+
+ @dataclass
+ class UploadFileResult:
+     """Result of an upload from file operation. Currently empty, but can be extended in the future."""
+
+
+ @dataclass
+ class DownloadFileResult:
+     """Result of a download to file operation. Currently empty, but can be extended in the future."""
+
+
  class FilesExt(files.FilesAPI):
      __doc__ = files.FilesAPI.__doc__
 
      # note that these error codes are retryable only for idempotent operations
-     _RETRYABLE_STATUS_CODES = [408, 429, 500, 502, 503, 504]
+     _RETRYABLE_STATUS_CODES: list[int] = [408, 429, 500, 502, 503, 504]
+
+     @dataclass(frozen=True)
+     class _UploadContext:
+         target_path: str
+         """The absolute remote path of the target file, e.g. /Volumes/path/to/your/file."""
+         overwrite: Optional[bool]
+         """If true, an existing file will be overwritten. When unspecified, default behavior of the cloud storage provider is performed."""
+         part_size: int
+         """The size of each part in bytes for multipart upload."""
+         batch_size: int
+         """The number of urls to request in a single batch."""
+         content_length: Optional[int] = None
+         """The total size of the content being uploaded, if known."""
+         source_file_path: Optional[str] = None
+         """The local path of the file being uploaded, if applicable."""
+         use_parallel: Optional[bool] = None
+         """If true, the upload will be performed using multiple threads."""
+         parallelism: Optional[int] = None
+         """The number of threads to use for parallel upload, if applicable."""
 
      def __init__(self, api_client, config: Config):
          super().__init__(api_client)
          self._config = config.copy()
          self._multipart_upload_read_ahead_bytes = 1
 
-     def download(self, file_path: str) -> DownloadResponse:
+     def download(
+         self,
+         file_path: str,
+     ) -> DownloadResponse:
          """Download a file.
 
          Downloads a file of any size. The file contents are the response body.
@@ -736,48 +796,462 @@ class FilesExt(files.FilesAPI):
736
796
 
737
797
  :returns: :class:`DownloadResponse`
738
798
  """
799
+ if self._config.disable_experimental_files_api_client:
800
+ _LOG.info("Experimental files API client is disabled; using the original download method.")
801
+ return super().download(file_path)
739
802
 
740
803
  initial_response: DownloadResponse = self._open_download_stream(
741
- file_path=file_path,
742
- start_byte_offset=0,
743
- if_unmodified_since_timestamp=None,
804
+ file_path=file_path, start_byte_offset=0, if_unmodified_since_timestamp=None
744
805
  )
745
806
 
746
807
  wrapped_response = self._wrap_stream(file_path, initial_response)
747
808
  initial_response.contents._response = wrapped_response
748
809
  return initial_response
749
810
 
750
- def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool] = None):
751
- """Upload a file.
811
+ def download_to(
812
+ self,
813
+ file_path: str,
814
+ destination: str,
815
+ *,
816
+ overwrite: bool = True,
817
+ use_parallel: bool = False,
818
+ parallelism: Optional[int] = None,
819
+ ) -> DownloadFileResult:
820
+ """Download a file to a local path.
821
+
822
+ :param file_path: str
823
+ The remote path of the file, e.g. /Volumes/path/to/your/file
824
+ :param destination: str
825
+ The local path where the file will be saved.
826
+ :param overwrite: bool
827
+ If true, an existing file will be overwritten. When not specified, assumed True.
828
+ :param use_parallel: bool
829
+ If true, the download will be performed using multiple threads.
830
+ :param parallelism: int
831
+ The number of parallel threads to use for downloading. If not specified, defaults to the number of CPU cores.
832
+
833
+ :returns: :class:`DownloadFileResult`
834
+ """
835
+ if self._config.disable_experimental_files_api_client:
836
+ raise NotImplementedError(
837
+ "Experimental files API features are disabled, download_to is not supported. Please use download instead."
838
+ )
839
+
840
+ # The existence of the target file is checked before starting the download. This is a best-effort check
841
+ # to avoid overwriting an existing file. However, there is nothing preventing a file from being created
842
+ # at the destination path after this check and before the file is written, and no way to prevent other
843
+ # actor from writing to the destination path concurrently.
844
+ if not overwrite and os.path.exists(destination):
845
+ raise FileExistsError(destination)
846
+ if use_parallel:
847
+ # Parallel download is not supported for Windows due to the limit of only one open file handle
848
+ # for writing. If parallel download is requested on Windows, fall back to sequential download with
849
+ # a warning.
850
+ if platform.system() == "Windows":
851
+ _LOG.warning("Parallel download is not supported on Windows. Falling back to sequential download.")
852
+ self._sequential_download_to_file(destination, remote_path=file_path)
853
+ return DownloadFileResult()
854
+ if parallelism is None:
855
+ parallelism = self._config.files_ext_parallel_download_default_parallelism
856
+ if parallelism < 1 or parallelism > 64:
857
+ raise ValueError("parallelism must be between 1 and 64")
858
+ self._parallel_download_with_fallback(file_path, destination, parallelism=parallelism)
859
+ else:
860
+ self._sequential_download_to_file(destination, remote_path=file_path)
861
+ return DownloadFileResult()
862
+
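For orientation, a minimal usage sketch of the download_to API added above. This is an editor's illustration, not part of the package diff; it assumes the FilesExt mixin is reachable as WorkspaceClient().files (as in prior SDK releases) and uses made-up paths.

from databricks.sdk import WorkspaceClient

w = WorkspaceClient()

# Sequential download (the default): streams the remote file straight to disk.
w.files.download_to(
    file_path="/Volumes/main/default/vol/report.csv",
    destination="/tmp/report.csv",
    overwrite=True,
)

# Parallel download: fetches byte ranges on multiple threads and assembles
# them in a temp file; falls back to sequential on Windows or small files.
w.files.download_to(
    file_path="/Volumes/main/default/vol/big.parquet",
    destination="/tmp/big.parquet",
    use_parallel=True,
    parallelism=8,  # validated to be between 1 and 64
)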
863
+ def _parallel_download_with_fallback(self, remote_path: str, destination: str, parallelism: int) -> None:
864
+ """Download a file in parallel to a local path.
865
+ This method first tries to use the Presigned URL for parallel download. If it fails due to permission issues,
866
+ it falls back to using Files API.
867
+
868
+ :param remote_path: str
869
+ The remote path of the file, e.g. /Volumes/path/to/your/file
870
+ :param destination: str
871
+ The local path where the file will be saved.
872
+ :param parallelism: int
873
+ The number of parallel threads to use for downloading.
874
+
875
+ :returns: None
876
+ """
877
+ try:
878
+ self._parallel_download_presigned_url(remote_path, destination, parallelism)
879
+ except FallbackToDownloadUsingFilesApi as e:
880
+ _LOG.info("Falling back to Files API download due to permission issues with Presigned URL: %s", e)
881
+ self._parallel_download_files_api(remote_path, destination, parallelism)
882
+
883
+ def _sequential_download_to_file(
884
+ self, destination: str, remote_path: str, last_modified: Optional[str] = None
885
+ ) -> None:
886
+ with open(destination, "wb") as f:
887
+ response = self._open_download_stream(
888
+ file_path=remote_path,
889
+ start_byte_offset=0,
890
+ if_unmodified_since_timestamp=last_modified,
891
+ )
892
+ wrapped_response = self._wrap_stream(remote_path, response, 0)
893
+ response.contents._response = wrapped_response
894
+ shutil.copyfileobj(response.contents, f)
895
+
896
+ def _do_parallel_download(
897
+ self, remote_path: str, destination: str, parallelism: int, download_chunk: Callable
898
+ ) -> None:
899
+
900
+ file_info = self.get_metadata(remote_path)
901
+ file_size = file_info.content_length
902
+ last_modified = file_info.last_modified
903
+ # If the file is smaller than the threshold, do not use parallel download.
904
+ if file_size <= self._config.files_ext_parallel_download_min_file_size:
905
+ self._sequential_download_to_file(destination, remote_path, last_modified)
906
+ return
907
+ part_size = self._config.files_ext_parallel_download_default_part_size
908
+ part_count = int(math.ceil(file_size / part_size))
909
+
910
+ fd, temp_file = mkstemp()
911
+ # We are preallocate the file size to the same as the remote file to avoid seeking beyond the file size.
912
+ os.truncate(temp_file, file_size)
913
+ os.close(fd)
914
+ try:
915
+ aborted = Event()
916
+
917
+ def wrapped_download_chunk(start: int, end: int, last_modified: Optional[str], temp_file: str) -> None:
918
+ if aborted.is_set():
919
+ return
920
+ additional_headers = {
921
+ "Range": f"bytes={start}-{end}",
922
+ "If-Unmodified-Since": last_modified,
923
+ }
924
+ try:
925
+ contents = download_chunk(additional_headers)
926
+ with open(temp_file, "r+b") as f:
927
+ f.seek(start)
928
+ shutil.copyfileobj(contents, f)
929
+ except Exception as e:
930
+ aborted.set()
931
+ raise e
932
+
933
+ with ThreadPoolExecutor(max_workers=parallelism) as executor:
934
+ futures = []
935
+ # Start the threads to download parts of the file.
936
+ for i in range(part_count):
937
+ start = i * part_size
938
+ end = min(start + part_size - 1, file_size - 1)
939
+ futures.append(executor.submit(wrapped_download_chunk, start, end, last_modified, temp_file))
940
+
941
+ # Wait for all threads to complete and check for exceptions.
942
+ for future in as_completed(futures):
943
+ exception = future.exception()
944
+ if exception:
945
+ raise exception
946
+ # Finally, move the temp file to the destination.
947
+ shutil.move(temp_file, destination)
948
+ finally:
949
+ if os.path.exists(temp_file):
950
+ os.remove(temp_file)
951
+
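To make the range arithmetic in _do_parallel_download concrete, a small worked sketch (the file and part sizes below are assumed numbers, not values from the diff):

import math

file_size = 10 * 1024 * 1024 + 1  # hypothetical remote file: 10 MiB + 1 byte
part_size = 4 * 1024 * 1024       # hypothetical files_ext_parallel_download_default_part_size
part_count = int(math.ceil(file_size / part_size))  # -> 3 parts

for i in range(part_count):
    start = i * part_size
    end = min(start + part_size - 1, file_size - 1)
    # Each worker requests "Range: bytes=<start>-<end>" and writes the chunk at
    # offset <start> of the preallocated temp file, as wrapped_download_chunk does above.
    print(f"part {i}: bytes={start}-{end}")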
952
+ def _parallel_download_presigned_url(self, remote_path: str, destination: str, parallelism: int) -> None:
953
+ """Download a file in parallel to a local path using presigned URLs.
954
+
955
+ :param remote_path: str
956
+ The remote path of the file, e.g. /Volumes/path/to/your/file
957
+ :param destination: str
958
+ The local path where the file will be saved.
959
+ :param parallelism: int
960
+ The number of parallel threads to use for downloading.
961
+
962
+ :returns: None
963
+ """
964
+
965
+ cloud_session = self._create_cloud_provider_session()
966
+ url_distributor = _PresignedUrlDistributor(lambda: self._create_download_url(remote_path))
752
967
 
753
- Uploads a file. The file contents should be sent as the request body as raw bytes (an
754
- octet stream); do not encode or otherwise modify the bytes before sending. The contents of the
755
- resulting file will be exactly the bytes sent in the request body. If the request is successful, there
756
- is no response body.
968
+ def download_chunk(additional_headers: dict[str, str]) -> BinaryIO:
969
+ retry_count = 0
970
+ while retry_count < self._config.files_ext_parallel_download_max_retries:
971
+ url_and_header, version = url_distributor.get_url()
972
+
973
+ headers = {**url_and_header.headers, **additional_headers}
974
+
975
+ def get_content() -> requests.Response:
976
+ return cloud_session.get(url_and_header.url, headers=headers)
977
+
978
+ raw_resp = self._retry_cloud_idempotent_operation(get_content)
979
+
980
+ if FilesExt._is_url_expired_response(raw_resp):
981
+ _LOG.info("Presigned URL expired, fetching a new one.")
982
+ url_distributor.invalidate_url(version)
983
+ retry_count += 1
984
+ continue
985
+ elif raw_resp.status_code == 403:
986
+ raise FallbackToDownloadUsingFilesApi("Received 403 Forbidden from presigned URL")
987
+
988
+ raw_resp.raise_for_status()
989
+ return BytesIO(raw_resp.content)
990
+ raise ValueError("Exceeded maximum retries for downloading with presigned URL: URL expired too many times")
991
+
992
+ self._do_parallel_download(remote_path, destination, parallelism, download_chunk)
993
+
994
+ def _parallel_download_files_api(self, remote_path: str, destination: str, parallelism: int) -> None:
995
+ """Download a file in parallel to a local path using the Files API.
996
+
997
+ :param remote_path: str
998
+ The remote path of the file, e.g. /Volumes/path/to/your/file
999
+ :param destination: str
1000
+ The local path where the file will be saved.
1001
+ :param parallelism: int
1002
+ The number of parallel threads to use for downloading.
1003
+
1004
+ :returns: None
1005
+ """
1006
+
1007
+ def download_chunk(additional_headers: dict[str, str]) -> BinaryIO:
1008
+ raw_response: dict = self._api.do(
1009
+ method="GET",
1010
+ path=f"/api/2.0/fs/files{remote_path}",
1011
+ headers=additional_headers,
1012
+ raw=True,
1013
+ )
1014
+ return raw_response["contents"]
1015
+
1016
+ self._do_parallel_download(remote_path, destination, parallelism, download_chunk)
1017
+
1018
+ def _get_optimized_performance_parameters_for_upload(
1019
+ self, content_length: Optional[int], part_size_overwrite: Optional[int]
1020
+ ) -> (int, int):
1021
+ """Get optimized part size and batch size for upload based on content length and provided part size.
1022
+
1023
+ Returns tuple of (part_size, batch_size).
1024
+ """
1025
+ chosen_part_size = None
1026
+
1027
+ # 1. decide on the part size
1028
+ if part_size_overwrite is not None: # If a part size is provided, we use it directly after validation.
1029
+ if part_size_overwrite > self._config.files_ext_multipart_upload_max_part_size:
1030
+ raise ValueError(
1031
+ f"Part size {part_size_overwrite} exceeds maximum allowed size {self._config.files_ext_multipart_upload_max_part_size} bytes."
1032
+ )
1033
+ chosen_part_size = part_size_overwrite
1034
+ _LOG.debug(f"Using provided part size: {chosen_part_size} bytes")
1035
+ else: # If no part size is provided, we will optimize based on the content length.
1036
+ if content_length is not None:
1037
+ # Choosing the smallest part size that allows for a maximum of 100 parts.
1038
+ for part_size in self._config.files_ext_multipart_upload_part_size_options:
1039
+ part_num = (content_length + part_size - 1) // part_size
1040
+ if part_num <= 100:
1041
+ chosen_part_size = part_size
1042
+ _LOG.debug(
1043
+ f"Optimized part size for upload: {chosen_part_size} bytes for content length {content_length} bytes"
1044
+ )
1045
+ break
1046
+ if chosen_part_size is None: # If no part size was chosen, we default to the maximum allowed part size.
1047
+ chosen_part_size = self._config.files_ext_multipart_upload_max_part_size
1048
+
1049
+ # Use defaults if not determined yet
1050
+ if chosen_part_size is None:
1051
+ chosen_part_size = self._config.files_ext_multipart_upload_default_part_size
1052
+
1053
+ # 2. decide on the batch size
1054
+ if content_length is not None and chosen_part_size is not None:
1055
+ part_num = (content_length + chosen_part_size - 1) // chosen_part_size
1056
+ chosen_batch_size = int(
1057
+ math.ceil(math.sqrt(part_num))
1058
+ ) # Using the square root of the number of parts as a heuristic for batch size.
1059
+ else:
1060
+ chosen_batch_size = self._config.files_ext_multipart_upload_batch_url_count
1061
+
1062
+ return chosen_part_size, chosen_batch_size
1063
+
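A worked example of the part-size and batch-size heuristic above. The concrete option list stands in for files_ext_multipart_upload_part_size_options, whose real defaults are configuration values not shown in this diff:

import math

content_length = 3 * 1024**3  # assume a 3 GiB upload
part_size_options = [8 * 1024**2, 16 * 1024**2, 32 * 1024**2, 64 * 1024**2]  # assumed values

# Smallest option that keeps the upload at or below 100 parts (mirrors the loop above).
part_size = next(p for p in part_size_options if (content_length + p - 1) // p <= 100)  # -> 32 MiB

part_num = (content_length + part_size - 1) // part_size  # -> 96 parts
batch_size = int(math.ceil(math.sqrt(part_num)))           # -> 10 URLs requested per batch

print(part_size, part_num, batch_size)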
1064
+ def upload(
1065
+ self,
1066
+ file_path: str,
1067
+ content: BinaryIO,
1068
+ *,
1069
+ overwrite: Optional[bool] = None,
1070
+ part_size: Optional[int] = None,
1071
+ use_parallel: bool = True,
1072
+ parallelism: Optional[int] = None,
1073
+ ) -> UploadStreamResult:
1074
+ """
1075
+ Upload a file with stream interface.
1076
+
1077
+ :param file_path: str
1078
+ The absolute remote path of the target file, e.g. /Volumes/path/to/your/file
1079
+ :param content: BinaryIO
1080
+ The contents of the file to upload. This must be a BinaryIO stream.
1081
+ :param overwrite: bool (optional)
1082
+ If true, an existing file will be overwritten. When not specified, assumed True.
1083
+ :param part_size: int (optional)
1084
+ If set, multipart upload will use this value as the size of each uploaded part.
1085
+ :param use_parallel: bool (optional)
1086
+ If true, the upload will be performed using multiple threads. Be aware that this will consume more memory
1087
+ because multiple parts will be buffered in memory before being uploaded. The amount of memory used is proportional
1088
+ to `parallelism * part_size`.
1089
+ If false, the upload will be performed in a single thread.
1090
+ Default is True.
1091
+ :param parallelism: int (optional)
1092
+ The number of threads to use for parallel uploads. This is only used if `use_parallel` is True.
1093
+
1094
+ :returns: :class:`UploadStreamResult`
1095
+ """
1096
+
1097
+ if self._config.disable_experimental_files_api_client:
1098
+ _LOG.info("Experimental files API client is disabled; using the original upload method.")
1099
+ super().upload(file_path=file_path, contents=content, overwrite=overwrite)
1100
+ return UploadStreamResult()
1101
+
1102
+ _LOG.debug(f"Uploading file from BinaryIO stream")
1103
+ if parallelism is not None and not use_parallel:
1104
+ raise ValueError("parallelism can only be set if use_parallel is True")
1105
+ if parallelism is None and use_parallel:
1106
+ parallelism = self._config.files_ext_multipart_upload_default_parallelism
1107
+
1108
+ # Determine content length if the stream is seekable
1109
+ content_length = None
1110
+ if content.seekable():
1111
+ _LOG.debug(f"Uploading using seekable mode")
1112
+ # If the stream is seekable, we can read its size.
1113
+ content.seek(0, os.SEEK_END)
1114
+ content_length = content.tell()
1115
+ content.seek(0)
1116
+
1117
+ # Get optimized part size and batch size based on content length and provided part size
1118
+ optimized_part_size, optimized_batch_size = self._get_optimized_performance_parameters_for_upload(
1119
+ content_length, part_size
1120
+ )
1121
+
1122
+ # Create context with all final parameters
1123
+ ctx = self._UploadContext(
1124
+ target_path=file_path,
1125
+ overwrite=overwrite,
1126
+ part_size=optimized_part_size,
1127
+ batch_size=optimized_batch_size,
1128
+ content_length=content_length,
1129
+ use_parallel=use_parallel,
1130
+ parallelism=parallelism,
1131
+ )
1132
+
1133
+ _LOG.debug(
1134
+ f"Upload context: part_size={ctx.part_size}, batch_size={ctx.batch_size}, content_length={ctx.content_length}"
1135
+ )
1136
+
1137
+ if ctx.use_parallel:
1138
+ self._parallel_upload_from_stream(ctx, content)
1139
+ return UploadStreamResult()
1140
+ elif ctx.content_length is not None:
1141
+ self._upload_single_thread_with_known_size(ctx, content)
1142
+ return UploadStreamResult()
1143
+ else:
1144
+ _LOG.debug(f"Uploading using non-seekable mode")
1145
+ # If the stream is not seekable, we cannot determine its size.
1146
+ # We will use a multipart upload.
1147
+ _LOG.debug(f"Using multipart upload for non-seekable input stream of unknown size for file {file_path}")
1148
+ self._single_thread_multipart_upload(ctx, content)
1149
+ return UploadStreamResult()
1150
+
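An illustrative call to the streaming upload above (editor's sketch, not part of the diff; it assumes the mixin is exposed as WorkspaceClient().files and that the target volume exists):

import io
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()

# A seekable stream lets the client size parts and URL batches automatically;
# non-seekable streams fall back to the single-threaded multipart path.
data = io.BytesIO(b"hello, volumes " * 1024)

w.files.upload(
    "/Volumes/main/default/vol/hello.bin",
    data,
    overwrite=True,
    use_parallel=True,  # the default; set False for a single-threaded upload
    parallelism=4,      # only valid together with use_parallel=True
)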
1151
+ def upload_from(
1152
+ self,
1153
+ file_path: str,
1154
+ source_path: str,
1155
+ *,
1156
+ overwrite: Optional[bool] = None,
1157
+ part_size: Optional[int] = None,
1158
+ use_parallel: bool = True,
1159
+ parallelism: Optional[int] = None,
1160
+ ) -> UploadFileResult:
1161
+ """Upload a file directly from a local path.
757
1162
 
758
1163
  :param file_path: str
759
1164
  The absolute remote path of the target file.
760
- :param contents: BinaryIO
1165
+ :param source_path: str
1166
+ The local path of the file to upload. This must be a path to a local file.
1167
+ :param part_size: int
1168
+ The size of each part in bytes for multipart upload. If not set, a part size is chosen automatically based on the file size.
761
1169
  :param overwrite: bool (optional)
762
1170
  If true, an existing file will be overwritten. When not specified, assumed True.
763
- """
1171
+ :param use_parallel: bool (optional)
1172
+ If true, the upload will be performed using multiple threads. Default is True.
1173
+ :param parallelism: int (optional)
1174
+ The number of threads to use for parallel uploads. This is only used if `use_parallel` is True.
1175
+ If not specified, it defaults to config.files_ext_multipart_upload_default_parallelism.
764
1176
 
765
- # Upload empty and small files with one-shot upload.
766
- pre_read_buffer = contents.read(self._config.multipart_upload_min_stream_size)
767
- if len(pre_read_buffer) < self._config.multipart_upload_min_stream_size:
768
- _LOG.debug(
769
- f"Using one-shot upload for input stream of size {len(pre_read_buffer)} below {self._config.multipart_upload_min_stream_size} bytes"
1177
+ :returns: :class:`UploadFileResult`
1178
+ """
1179
+ if self._config.disable_experimental_files_api_client:
1180
+ raise NotImplementedError(
1181
+ "Experimental files API features are disabled, upload_from is not supported. Please use upload instead."
770
1182
  )
771
- return super().upload(file_path=file_path, contents=BytesIO(pre_read_buffer), overwrite=overwrite)
772
1183
 
1184
+ _LOG.debug(f"Uploading file from local path: {source_path}")
1185
+
1186
+ if parallelism is not None and not use_parallel:
1187
+ raise ValueError("parallelism can only be set if use_parallel is True")
1188
+ if parallelism is None and use_parallel:
1189
+ parallelism = self._config.files_ext_multipart_upload_default_parallelism
1190
+ # Get the file size
1191
+ file_size = os.path.getsize(source_path)
1192
+
1193
+ # Get optimized part size and batch size based on content length and provided part size
1194
+ optimized_part_size, optimized_batch_size = self._get_optimized_performance_parameters_for_upload(
1195
+ file_size, part_size
1196
+ )
1197
+
1198
+ # Create context with all final parameters
1199
+ ctx = self._UploadContext(
1200
+ target_path=file_path,
1201
+ overwrite=overwrite,
1202
+ part_size=optimized_part_size,
1203
+ batch_size=optimized_batch_size,
1204
+ content_length=file_size,
1205
+ source_file_path=source_path,
1206
+ use_parallel=use_parallel,
1207
+ parallelism=parallelism,
1208
+ )
1209
+ if ctx.use_parallel:
1210
+ self._parallel_upload_from_file(ctx)
1211
+ return UploadFileResult()
1212
+ else:
1213
+ with open(source_path, "rb") as f:
1214
+ self._upload_single_thread_with_known_size(ctx, f)
1215
+ return UploadFileResult()
1216
+
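And a matching sketch for the file-based variant above (illustrative only; note that the remote target path comes first, then the local source path):

from databricks.sdk import WorkspaceClient

w = WorkspaceClient()

# Part size and URL batch size are derived from the local file size unless
# part_size is passed explicitly.
w.files.upload_from(
    "/Volumes/main/default/vol/dataset.parquet",  # remote target
    "/local/data/dataset.parquet",                # local source
    overwrite=True,
    use_parallel=True,
)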
1217
+ def _upload_single_thread_with_known_size(self, ctx: _UploadContext, contents: BinaryIO) -> None:
1218
+ """Upload a file with a known size."""
1219
+ if ctx.content_length < self._config.files_ext_multipart_upload_min_stream_size:
1220
+ _LOG.debug(f"Using single-shot upload for input stream of size {ctx.content_length} bytes")
1221
+ return self._single_thread_single_shot_upload(ctx, contents)
1222
+ else:
1223
+ _LOG.debug(f"Using multipart upload for input stream of size {ctx.content_length} bytes")
1224
+ return self._single_thread_multipart_upload(ctx, contents)
1225
+
1226
+ def _single_thread_single_shot_upload(self, ctx: _UploadContext, contents: BinaryIO) -> None:
1227
+ """Upload a file with a known size."""
1228
+ _LOG.debug(f"Using single-shot upload for input stream")
1229
+ return super().upload(file_path=ctx.target_path, contents=contents, overwrite=ctx.overwrite)
1230
+
1231
+ def _initiate_multipart_upload(self, ctx: _UploadContext) -> dict:
1232
+ """Initiate a multipart upload and return the response."""
773
1233
  query = {"action": "initiate-upload"}
774
- if overwrite is not None:
775
- query["overwrite"] = overwrite
1234
+ if ctx.overwrite is not None:
1235
+ query["overwrite"] = ctx.overwrite
776
1236
 
777
1237
  # Method _api.do() takes care of retrying and will raise an exception in case of failure.
778
1238
  initiate_upload_response = self._api.do(
779
- "POST", f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(file_path)}", query=query
1239
+ "POST", f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(ctx.target_path)}", query=query
780
1240
  )
1241
+ return initiate_upload_response
1242
+
1243
+ def _single_thread_multipart_upload(self, ctx: _UploadContext, contents: BinaryIO) -> None:
1244
+
1245
+ # Upload empty and small files with one-shot upload.
1246
+ pre_read_buffer = contents.read(self._config.files_ext_multipart_upload_min_stream_size)
1247
+ if len(pre_read_buffer) < self._config.files_ext_multipart_upload_min_stream_size:
1248
+ _LOG.debug(
1249
+ f"Using one-shot upload for input stream of size {len(pre_read_buffer)} below {self._config.files_ext_multipart_upload_min_stream_size} bytes"
1250
+ )
1251
+ return self._single_thread_single_shot_upload(ctx, BytesIO(pre_read_buffer))
1252
+
1253
+ # Initiate the multipart upload.
1254
+ initiate_upload_response = self._initiate_multipart_upload(ctx)
781
1255
 
782
1256
  if initiate_upload_response.get("multipart_upload"):
783
1257
  cloud_provider_session = self._create_cloud_provider_session()
@@ -786,37 +1260,451 @@ class FilesExt(files.FilesAPI):
786
1260
  raise ValueError(f"Unexpected server response: {initiate_upload_response}")
787
1261
 
788
1262
  try:
789
- self._perform_multipart_upload(
790
- file_path, contents, session_token, pre_read_buffer, cloud_provider_session
791
- )
1263
+ self._perform_multipart_upload(ctx, contents, session_token, pre_read_buffer, cloud_provider_session)
1264
+ except FallbackToUploadUsingFilesApi as e:
1265
+ try:
1266
+ self._abort_multipart_upload(ctx, session_token, cloud_provider_session)
1267
+ except BaseException as ex:
1268
+ # Ignore abort exceptions as it is a best-effort.
1269
+ _LOG.warning(f"Failed to abort upload: {ex}")
1270
+
1271
+ _LOG.info(f"Falling back to single-shot upload with Files API: {e}")
1272
+ # Concatenate the buffered part and the rest of the stream.
1273
+ full_stream = _ConcatenatedInputStream(BytesIO(e.buffer), contents)
1274
+ return self._single_thread_single_shot_upload(ctx, full_stream)
1275
+
792
1276
  except Exception as e:
793
1277
  _LOG.info(f"Aborting multipart upload on error: {e}")
794
1278
  try:
795
- self._abort_multipart_upload(file_path, session_token, cloud_provider_session)
1279
+ self._abort_multipart_upload(ctx, session_token, cloud_provider_session)
796
1280
  except BaseException as ex:
1281
+ # Ignore abort exceptions as it is a best-effort.
797
1282
  _LOG.warning(f"Failed to abort upload: {ex}")
798
- # ignore, abort is a best-effort
799
1283
  finally:
800
- # rethrow original exception
1284
+ # Rethrow the original exception
801
1285
  raise e from None
802
1286
 
803
1287
  elif initiate_upload_response.get("resumable_upload"):
804
1288
  cloud_provider_session = self._create_cloud_provider_session()
805
1289
  session_token = initiate_upload_response["resumable_upload"]["session_token"]
806
- self._perform_resumable_upload(
807
- file_path, contents, session_token, overwrite, pre_read_buffer, cloud_provider_session
808
- )
1290
+
1291
+ try:
1292
+ self._perform_resumable_upload(ctx, contents, session_token, pre_read_buffer, cloud_provider_session)
1293
+ except FallbackToUploadUsingFilesApi as e:
1294
+ _LOG.info(f"Falling back to single-shot upload with Files API: {e}")
1295
+ # Concatenate the buffered part and the rest of the stream.
1296
+ full_stream = _ConcatenatedInputStream(BytesIO(e.buffer), contents)
1297
+ return self._single_thread_single_shot_upload(ctx, full_stream)
809
1298
  else:
810
1299
  raise ValueError(f"Unexpected server response: {initiate_upload_response}")
811
1300
 
1301
+ def _parallel_upload_from_stream(self, ctx: _UploadContext, contents: BinaryIO) -> None:
1302
+ """
1303
+ Upload a stream using multipart upload with multiple threads.
1304
+ A producer thread reads parts of the stream and consumer threads
1305
+ upload them concurrently using presigned part URLs.
1306
+ """
1307
+ initiate_upload_response = self._initiate_multipart_upload(ctx)
1308
+
1309
+ if initiate_upload_response.get("resumable_upload"):
1310
+ _LOG.warning("GCP does not support parallel resumable uploads, falling back to single-threaded upload")
1311
+ return self._single_thread_multipart_upload(ctx, contents)
1312
+ elif initiate_upload_response.get("multipart_upload"):
1313
+ session_token = initiate_upload_response["multipart_upload"].get("session_token")
1314
+ cloud_provider_session = self._create_cloud_provider_session()
1315
+ if not session_token:
1316
+ raise ValueError(f"Unexpected server response: {initiate_upload_response}")
1317
+ try:
1318
+ self._parallel_multipart_upload_from_stream(ctx, session_token, contents, cloud_provider_session)
1319
+ except FallbackToUploadUsingFilesApi as e:
1320
+ try:
1321
+ self._abort_multipart_upload(ctx, session_token, cloud_provider_session)
1322
+ except Exception as abort_ex:
1323
+ _LOG.warning(f"Failed to abort upload: {abort_ex}")
1324
+ _LOG.info(f"Falling back to single-shot upload with Files API: {e}")
1325
+ # Concatenate the buffered part and the rest of the stream.
1326
+ full_stream = _ConcatenatedInputStream(BytesIO(e.buffer), contents)
1327
+ return self._single_thread_single_shot_upload(ctx, full_stream)
1328
+ except Exception as e:
1329
+ _LOG.info(f"Aborting multipart upload on error: {e}")
1330
+ try:
1331
+ self._abort_multipart_upload(ctx, session_token, cloud_provider_session)
1332
+ except Exception as abort_ex:
1333
+ _LOG.warning(f"Failed to abort upload: {abort_ex}")
1334
+ finally:
1335
+ # Rethrow the original exception.
1336
+ raise e from None
1337
+ else:
1338
+ raise ValueError(f"Unexpected server response: {initiate_upload_response}")
1339
+
1340
+ def _parallel_upload_from_file(
1341
+ self,
1342
+ ctx: _UploadContext,
1343
+ ) -> None:
1344
+ """
1345
+ Upload a file using multipart upload with multiple threads.
1346
+ Worker threads read different parts of the file and upload them
1347
+ concurrently using presigned part URLs.
1348
+ """
1349
+
1350
+ initiate_upload_response = self._initiate_multipart_upload(ctx)
1351
+
1352
+ if initiate_upload_response.get("multipart_upload"):
1353
+ cloud_provider_session = self._create_cloud_provider_session()
1354
+ session_token = initiate_upload_response["multipart_upload"].get("session_token")
1355
+ if not session_token:
1356
+ raise ValueError(f"Unexpected server response: {initiate_upload_response}")
1357
+ try:
1358
+ self._parallel_multipart_upload_from_file(ctx, session_token)
1359
+ except FallbackToUploadUsingFilesApi as e:
1360
+ try:
1361
+ self._abort_multipart_upload(ctx, session_token, cloud_provider_session)
1362
+ except Exception as abort_ex:
1363
+ _LOG.warning(f"Failed to abort upload: {abort_ex}")
1364
+
1365
+ _LOG.info(f"Falling back to single-shot upload with Files API: {e}")
1366
+ # Concatenate the buffered part and the rest of the stream.
1367
+ with open(ctx.source_file_path, "rb") as f:
1368
+ return self._single_thread_single_shot_upload(ctx, f)
1369
+
1370
+ except Exception as e:
1371
+ _LOG.info(f"Aborting multipart upload on error: {e}")
1372
+ try:
1373
+ self._abort_multipart_upload(ctx, session_token, cloud_provider_session)
1374
+ except Exception as abort_ex:
1375
+ _LOG.warning(f"Failed to abort upload: {abort_ex}")
1376
+ finally:
1377
+ # Rethrow the original exception.
1378
+ raise e from None
1379
+
1380
+ elif initiate_upload_response.get("resumable_upload"):
1381
+ _LOG.warning("GCP does not support parallel resumable uploads, falling back to single-threaded upload")
1382
+ with open(ctx.source_file_path, "rb") as f:
1383
+ return self._upload_single_thread_with_known_size(ctx, f)
1384
+ else:
1385
+ raise ValueError(f"Unexpected server response: {initiate_upload_response}")
1386
+
1387
+ @dataclass
1388
+ class _MultipartUploadPart:
1389
+ ctx: FilesExt._UploadContext
1390
+ part_index: int
1391
+ part_offset: int
1392
+ part_size: int
1393
+ session_token: str
1394
+
1395
+ def _parallel_multipart_upload_from_file(
1396
+ self,
1397
+ ctx: _UploadContext,
1398
+ session_token: str,
1399
+ ) -> None:
1400
+ # Calculate the number of parts.
1401
+ file_size = os.path.getsize(ctx.source_file_path)
1402
+ part_size = ctx.part_size
1403
+ num_parts = (file_size + part_size - 1) // part_size
1404
+ _LOG.debug(f"Uploading file of size {file_size} bytes in {num_parts} parts using {ctx.parallelism} threads")
1405
+
1406
+ # Create queues and worker threads.
1407
+ task_queue = Queue()
1408
+ etags_result_queue = Queue()
1409
+ exception_queue = Queue()
1410
+ aborted = Event()
1411
+ workers = [
1412
+ Thread(target=self._upload_file_consumer, args=(task_queue, etags_result_queue, exception_queue, aborted))
1413
+ for _ in range(ctx.parallelism)
1414
+ ]
1415
+ _LOG.debug(f"Starting {len(workers)} worker threads for parallel upload")
1416
+
1417
+ # Enqueue all parts. Since the task queue is populated before starting the workers, we don't need to signal completion.
1418
+ for part_index in range(1, num_parts + 1):
1419
+ part_offset = (part_index - 1) * part_size
1420
+ part_size = min(part_size, file_size - part_offset)
1421
+ part = self._MultipartUploadPart(ctx, part_index, part_offset, part_size, session_token)
1422
+ task_queue.put(part)
1423
+
1424
+ # Start the worker threads for parallel upload.
1425
+ for worker in workers:
1426
+ worker.start()
1427
+
1428
+ # Wait for all tasks to be processed.
1429
+ for worker in workers:
1430
+ worker.join()
1431
+
1432
+ # Check for exceptions: if any worker encountered an exception, raise the first one.
1433
+ if not exception_queue.empty():
1434
+ first_exception = exception_queue.get()
1435
+ raise first_exception
1436
+
1437
+ # Collect results from the etags queue.
1438
+ etags: dict = {}
1439
+ while not etags_result_queue.empty():
1440
+ part_number, etag = etags_result_queue.get()
1441
+ etags[part_number] = etag
1442
+
1443
+ self._complete_multipart_upload(ctx, etags, session_token)
1444
+
1445
+ def _parallel_multipart_upload_from_stream(
1446
+ self,
1447
+ ctx: _UploadContext,
1448
+ session_token: str,
1449
+ content: BinaryIO,
1450
+ cloud_provider_session: requests.Session,
1451
+ ) -> None:
1452
+
1453
+ task_queue = Queue(maxsize=ctx.parallelism) # Limit queue size to control memory usage
1454
+ etags_result_queue = Queue()
1455
+ exception_queue = Queue()
1456
+ all_produced = Event()
1457
+ aborted = Event()
1458
+
1459
+ # Do the first part read ahead
1460
+ pre_read_buffer = content.read(ctx.part_size)
1461
+ if not pre_read_buffer:
1462
+ self._complete_multipart_upload(ctx, {}, session_token)
1463
+ return
1464
+ try:
1465
+ etag = self._do_upload_one_part(
1466
+ ctx, cloud_provider_session, 1, 0, len(pre_read_buffer), session_token, BytesIO(pre_read_buffer)
1467
+ )
1468
+ etags_result_queue.put((1, etag))
1469
+ except FallbackToUploadUsingFilesApi as e:
1470
+ raise FallbackToUploadUsingFilesApi(
1471
+ pre_read_buffer, "Falling back to single-shot upload with Files API"
1472
+ ) from e
1473
+
1474
+ if len(pre_read_buffer) < ctx.part_size:
1475
+ self._complete_multipart_upload(ctx, {1: etag}, session_token)
1476
+ return
1477
+
1478
+ def producer() -> None:
1479
+ part_index = 2
1480
+ part_size = ctx.part_size
1481
+ while not aborted.is_set():
1482
+ part_content = content.read(part_size)
1483
+ if not part_content:
1484
+ break
1485
+ part_offset = (part_index - 1) * part_size
1486
+ part = self._MultipartUploadPart(ctx, part_index, part_offset, len(part_content), session_token)
1487
+ while not aborted.is_set():
1488
+ try:
1489
+ task_queue.put((part, part_content), timeout=0.1)
1490
+ break
1491
+ except Full:
1492
+ continue
1493
+ part_index += 1
1494
+ all_produced.set()
1495
+
1496
+ producer_thread = Thread(target=producer)
1497
+ consumers = [
1498
+ Thread(
1499
+ target=self._upload_stream_consumer,
1500
+ args=(task_queue, etags_result_queue, exception_queue, all_produced, aborted),
1501
+ )
1502
+ for _ in range(ctx.parallelism)
1503
+ ]
1504
+ _LOG.debug(f"Starting {len(consumers)} worker threads for parallel upload")
1505
+ # Start producer and consumer threads
1506
+ producer_thread.start()
1507
+ for consumer in consumers:
1508
+ consumer.start()
1509
+
1510
+ # Wait for producer to finish
1511
+ _LOG.debug(f"threads started, waiting for producer to finish")
1512
+ producer_thread.join()
1513
+ # Wait for all tasks to be processed
1514
+ _LOG.debug(f"producer finished, waiting for consumers to finish")
1515
+ # task_queue.join()
1516
+ for consumer in consumers:
1517
+ consumer.join()
1518
+
1519
+ # Check for exceptions: if any worker encountered an exception, raise the first one.
1520
+ if not exception_queue.empty():
1521
+ first_exception = exception_queue.get()
1522
+ raise first_exception
1523
+
1524
+ # Collect results from the etags queue
1525
+ etags: dict = {}
1526
+ while not etags_result_queue.empty():
1527
+ part_number, etag = etags_result_queue.get()
1528
+ etags[part_number] = etag
1529
+
1530
+ self._complete_multipart_upload(ctx, etags, session_token)
1531
+
1532
+ def _complete_multipart_upload(self, ctx, etags, session_token):
1533
+ query = {"action": "complete-upload", "upload_type": "multipart", "session_token": session_token}
1534
+ headers = {"Content-Type": "application/json"}
1535
+ body: dict = {}
1536
+ parts = []
1537
+ for part_number, etag in sorted(etags.items()):
1538
+ part = {"part_number": part_number, "etag": etag}
1539
+ parts.append(part)
1540
+ body["parts"] = parts
1541
+ self._api.do(
1542
+ "POST",
1543
+ f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(ctx.target_path)}",
1544
+ query=query,
1545
+ headers=headers,
1546
+ body=body,
1547
+ )
1548
+
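For reference, the completion call above ends up posting a body of the following shape (a sketch reconstructed from the code; the etag values are placeholders):

# POST /api/2.0/fs/files{target_path}
#   ?action=complete-upload&upload_type=multipart&session_token=...
body = {
    "parts": [
        {"part_number": 1, "etag": "etag-of-part-1"},
        {"part_number": 2, "etag": "etag-of-part-2"},
        # ...one entry per uploaded part, sorted by part_number
    ]
}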
1549
+ def _upload_file_consumer(
1550
+ self,
1551
+ task_queue: Queue[FilesExt._MultipartUploadPart],
1552
+ etags_queue: Queue[tuple[int, str]],
1553
+ exception_queue: Queue[Exception],
1554
+ aborted: Event,
1555
+ ) -> None:
1556
+ cloud_provider_session = self._create_cloud_provider_session()
1557
+ while not aborted.is_set():
1558
+ try:
1559
+ part = task_queue.get(block=False)
1560
+ except Empty:
1561
+ # The task_queue was populated before the workers were started, so we can exit if it's empty.
1562
+ break
1563
+
1564
+ try:
1565
+ with open(part.ctx.source_file_path, "rb") as f:
1566
+ f.seek(part.part_offset, os.SEEK_SET)
1567
+ part_content = BytesIO(f.read(part.part_size))
1568
+ etag = self._do_upload_one_part(
1569
+ part.ctx,
1570
+ cloud_provider_session,
1571
+ part.part_index,
1572
+ part.part_offset,
1573
+ part.part_size,
1574
+ part.session_token,
1575
+ part_content,
1576
+ )
1577
+ etags_queue.put((part.part_index, etag))
1578
+ except Exception as e:
1579
+ aborted.set()
1580
+ exception_queue.put(e)
1581
+ finally:
1582
+ task_queue.task_done()
1583
+
1584
+ def _upload_stream_consumer(
1585
+ self,
1586
+ task_queue: Queue[tuple[FilesExt._MultipartUploadPart, bytes]],
1587
+ etags_queue: Queue[tuple[int, str]],
1588
+ exception_queue: Queue[Exception],
1589
+ all_produced: Event,
1590
+ aborted: Event,
1591
+ ) -> None:
1592
+ cloud_provider_session = self._create_cloud_provider_session()
1593
+ while not aborted.is_set():
1594
+ try:
1595
+ (part, content) = task_queue.get(block=False, timeout=0.1)
1596
+ except Empty:
1597
+ if all_produced.is_set():
1598
+ break # No more parts will be produced and the queue is empty
1599
+ else:
1600
+ continue
1601
+ try:
1602
+ etag = self._do_upload_one_part(
1603
+ part.ctx,
1604
+ cloud_provider_session,
1605
+ part.part_index,
1606
+ part.part_offset,
1607
+ part.part_size,
1608
+ part.session_token,
1609
+ BytesIO(content),
1610
+ )
1611
+ etags_queue.put((part.part_index, etag))
1612
+ except Exception as e:
1613
+ aborted.set()
1614
+ exception_queue.put(e)
1615
+ finally:
1616
+ task_queue.task_done()
1617
+
1618
+ def _do_upload_one_part(
1619
+ self,
1620
+ ctx: _UploadContext,
1621
+ cloud_provider_session: requests.Session,
1622
+ part_index: int,
1623
+ part_offset: int,
1624
+ part_size: int,
1625
+ session_token: str,
1626
+ part_content: BinaryIO,
1627
+ ) -> str:
1628
+ retry_count = 0
1629
+
1630
+ # Try to upload the part, retrying if the upload URL expires.
1631
+ while True:
1632
+ body: dict = {
1633
+ "path": ctx.target_path,
1634
+ "session_token": session_token,
1635
+ "start_part_number": part_index,
1636
+ "count": 1,
1637
+ "expire_time": self._get_upload_url_expire_time(),
1638
+ }
1639
+
1640
+ headers = {"Content-Type": "application/json"}
1641
+
1642
+ # Requesting URLs for the same set of parts is an idempotent operation and is safe to retry.
1643
+ try:
1644
+ # The _api.do() method handles retries and will raise an exception in case of failure.
1645
+ upload_part_urls_response = self._api.do(
1646
+ "POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body
1647
+ )
1648
+ except PermissionDenied as e:
1649
+ if self._is_presigned_urls_disabled_error(e):
1650
+ raise FallbackToUploadUsingFilesApi(None, "Presigned URLs are disabled")
1651
+ else:
1652
+ raise e from None
1653
+
1654
+ upload_part_urls = upload_part_urls_response.get("upload_part_urls", [])
1655
+ if len(upload_part_urls) == 0:
1656
+ raise ValueError(f"Unexpected server response: {upload_part_urls_response}")
1657
+ upload_part_url = upload_part_urls[0]
1658
+ url = upload_part_url["url"]
1659
+ required_headers = upload_part_url.get("headers", [])
1660
+ assert part_index == upload_part_url["part_number"]
1661
+
1662
+ headers: dict = {"Content-Type": "application/octet-stream"}
1663
+ for h in required_headers:
1664
+ headers[h["name"]] = h["value"]
1665
+
1666
+ _LOG.debug(f"Uploading part {part_index}: [{part_offset}, {part_offset + part_size - 1}]")
1667
+
1668
+ def rewind() -> None:
1669
+ part_content.seek(0, os.SEEK_SET)
1670
+
1671
+ def perform_upload() -> requests.Response:
1672
+ return cloud_provider_session.request(
1673
+ "PUT",
1674
+ url,
1675
+ headers=headers,
1676
+ data=part_content,
1677
+ timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds,
1678
+ )
1679
+
1680
+ upload_response = self._retry_cloud_idempotent_operation(perform_upload, rewind)
1681
+
1682
+ if upload_response.status_code in (200, 201):
1683
+ etag = upload_response.headers.get("ETag", "")
1684
+ return etag
1685
+ elif FilesExt._is_url_expired_response(upload_response):
1686
+ if retry_count < self._config.files_ext_multipart_upload_max_retries:
1687
+ retry_count += 1
1688
+ _LOG.debug("Upload URL expired, retrying...")
1689
+ continue
1690
+ else:
1691
+ raise ValueError(f"Unsuccessful chunk upload: upload URL expired after {retry_count} retries")
1692
+ elif upload_response.status_code == 403:
1693
+ raise FallbackToUploadUsingFilesApi(None, f"Direct upload forbidden: {upload_response.content}")
1694
+ else:
1695
+ message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}"
1696
+ _LOG.warning(message)
1697
+ mapped_error = _error_mapper(upload_response, {})
1698
+ raise mapped_error or ValueError(message)
1699
+
812
1700
  def _perform_multipart_upload(
813
1701
  self,
814
- target_path: str,
1702
+ ctx: _UploadContext,
815
1703
  input_stream: BinaryIO,
816
1704
  session_token: str,
817
1705
  pre_read_buffer: bytes,
818
1706
  cloud_provider_session: requests.Session,
819
- ):
1707
+ ) -> None:
820
1708
  """
821
1709
  Performs multipart upload using presigned URLs on AWS and Azure:
822
1710
  https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html
@@ -832,7 +1720,7 @@ class FilesExt(files.FilesAPI):
832
1720
  # AWS signed chunked upload: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html
833
1721
  # https://learn.microsoft.com/en-us/azure/storage/blobs/storage-blobs-tune-upload-download-python#buffering-during-uploads
834
1722
 
835
- chunk_offset = 0 # used only for logging
1723
+ chunk_offset = 0
836
1724
 
837
1725
  # This buffer is expected to contain at least multipart_upload_chunk_size bytes.
838
1726
  # Note that initially buffer can be bigger (from pre_read_buffer).
@@ -842,37 +1730,43 @@ class FilesExt(files.FilesAPI):
842
1730
  eof = False
843
1731
  while not eof:
844
1732
  # If needed, buffer the next chunk.
845
- buffer = FilesExt._fill_buffer(buffer, self._config.multipart_upload_chunk_size, input_stream)
1733
+ buffer = FilesExt._fill_buffer(buffer, ctx.part_size, input_stream)
846
1734
  if len(buffer) == 0:
847
1735
  # End of stream, no need to request the next block of upload URLs.
848
1736
  break
849
1737
 
850
1738
  _LOG.debug(
851
- f"Multipart upload: requesting next {self._config.multipart_upload_batch_url_count} upload URLs starting from part {current_part_number}"
1739
+ f"Multipart upload: requesting next {ctx.batch_size} upload URLs starting from part {current_part_number}"
852
1740
  )
853
1741
 
854
1742
  body: dict = {
855
- "path": target_path,
1743
+ "path": ctx.target_path,
856
1744
  "session_token": session_token,
857
1745
  "start_part_number": current_part_number,
858
- "count": self._config.multipart_upload_batch_url_count,
859
- "expire_time": self._get_url_expire_time(),
1746
+ "count": ctx.batch_size,
1747
+ "expire_time": self._get_upload_url_expire_time(),
860
1748
  }
861
1749
 
862
1750
  headers = {"Content-Type": "application/json"}
863
1751
 
864
1752
  # Requesting URLs for the same set of parts is an idempotent operation, safe to retry.
865
- # Method _api.do() takes care of retrying and will raise an exception in case of failure.
866
- upload_part_urls_response = self._api.do(
867
- "POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body
868
- )
1753
+ try:
1754
+ # Method _api.do() takes care of retrying and will raise an exception in case of failure.
1755
+ upload_part_urls_response = self._api.do(
1756
+ "POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body
1757
+ )
1758
+ except PermissionDenied as e:
1759
+ if chunk_offset == 0 and self._is_presigned_urls_disabled_error(e):
1760
+ raise FallbackToUploadUsingFilesApi(buffer, "Presigned URLs are disabled")
1761
+ else:
1762
+ raise e from None
869
1763
 
870
1764
  upload_part_urls = upload_part_urls_response.get("upload_part_urls", [])
871
1765
  if len(upload_part_urls) == 0:
872
1766
  raise ValueError(f"Unexpected server response: {upload_part_urls_response}")
873
1767
 
874
1768
  for upload_part_url in upload_part_urls:
875
- buffer = FilesExt._fill_buffer(buffer, self._config.multipart_upload_chunk_size, input_stream)
1769
+ buffer = FilesExt._fill_buffer(buffer, ctx.part_size, input_stream)
876
1770
  actual_buffer_length = len(buffer)
877
1771
  if actual_buffer_length == 0:
878
1772
  eof = True
@@ -886,7 +1780,7 @@ class FilesExt(files.FilesAPI):
886
1780
  for h in required_headers:
887
1781
  headers[h["name"]] = h["value"]
888
1782
 
889
- actual_chunk_length = min(actual_buffer_length, self._config.multipart_upload_chunk_size)
1783
+ actual_chunk_length = min(actual_buffer_length, ctx.part_size)
890
1784
  _LOG.debug(
891
1785
  f"Uploading part {current_part_number}: [{chunk_offset}, {chunk_offset + actual_chunk_length - 1}]"
892
1786
  )
@@ -902,7 +1796,7 @@ class FilesExt(files.FilesAPI):
902
1796
  url,
903
1797
  headers=headers,
904
1798
  data=chunk,
905
- timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds,
1799
+ timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds,
906
1800
  )
907
1801
 
908
1802
  upload_response = self._retry_cloud_idempotent_operation(perform, rewind)
@@ -922,7 +1816,7 @@ class FilesExt(files.FilesAPI):
922
1816
  retry_count = 0
923
1817
 
924
1818
  elif FilesExt._is_url_expired_response(upload_response):
925
- if retry_count < self._config.multipart_upload_max_retries:
1819
+ if retry_count < self._config.files_ext_multipart_upload_max_retries:
926
1820
  retry_count += 1
927
1821
  _LOG.debug("Upload URL expired")
928
1822
  # Preserve the buffer so we'll upload the current part again using next upload URL
@@ -930,6 +1824,13 @@ class FilesExt(files.FilesAPI):
930
1824
  # don't confuse user with unrelated "Permission denied" error.
931
1825
  raise ValueError(f"Unsuccessful chunk upload: upload URL expired")
932
1826
 
1827
+ elif upload_response.status_code == 403 and chunk_offset == 0:
1828
+ # We got 403 failure when uploading the very first chunk (we can't tell if it is Azure for sure yet).
1829
+ # This might happen due to Azure firewall enabled for the customer bucket.
1830
+ # Let's fallback to using Files API which might be allowlisted to upload, passing
1831
+ # currently buffered (but not yet uploaded) part of the stream.
1832
+ raise FallbackToUploadUsingFilesApi(buffer, f"Direct upload forbidden: {upload_response.content}")
1833
+
933
1834
  else:
934
1835
  message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}"
935
1836
  _LOG.warning(message)
@@ -938,9 +1839,7 @@ class FilesExt(files.FilesAPI):
938
1839
 
939
1840
  current_part_number += 1
940
1841
 
941
- _LOG.debug(
942
- f"Completing multipart upload after uploading {len(etags)} parts of up to {self._config.multipart_upload_chunk_size} bytes"
943
- )
1842
+ _LOG.debug(f"Completing multipart upload after uploading {len(etags)} parts of up to {ctx.part_size} bytes")
944
1843
 
945
1844
  query = {"action": "complete-upload", "upload_type": "multipart", "session_token": session_token}
946
1845
  headers = {"Content-Type": "application/json"}
@@ -957,14 +1856,14 @@ class FilesExt(files.FilesAPI):
957
1856
  # Method _api.do() takes care of retrying and will raise an exception in case of failure.
958
1857
  self._api.do(
959
1858
  "POST",
960
- f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(target_path)}",
1859
+ f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(ctx.target_path)}",
961
1860
  query=query,
962
1861
  headers=headers,
963
1862
  body=body,
964
1863
  )
965
1864
 
966
1865
  @staticmethod
967
- def _fill_buffer(buffer: bytes, desired_min_size: int, input_stream: BinaryIO):
1866
+ def _fill_buffer(buffer: bytes, desired_min_size: int, input_stream: BinaryIO) -> bytes:
968
1867
  """
969
1868
  Tries to fill given buffer to contain at least `desired_min_size` bytes by reading from input stream.
970
1869
  """
@@ -978,7 +1877,7 @@ class FilesExt(files.FilesAPI):
978
1877
  return buffer
979
1878
 
980
1879
  @staticmethod
981
- def _is_url_expired_response(response: requests.Response):
1880
+ def _is_url_expired_response(response: requests.Response) -> bool:
982
1881
  """
983
1882
  Checks if response matches one of the known "URL expired" responses from the cloud storage providers.
984
1883
  """
@@ -1011,15 +1910,21 @@ class FilesExt(files.FilesAPI):
1011
1910
 
1012
1911
  return False
1013
1912
 
1913
+ def _is_presigned_urls_disabled_error(self, e: PermissionDenied) -> bool:
1914
+ error_infos = e.get_error_info()
1915
+ for error_info in error_infos:
1916
+ if error_info.reason == "FILES_API_API_IS_NOT_ENABLED":
1917
+ return True
1918
+ return False
1919
+
1014
1920
  def _perform_resumable_upload(
1015
1921
  self,
1016
- target_path: str,
1922
+ ctx: _UploadContext,
1017
1923
  input_stream: BinaryIO,
1018
1924
  session_token: str,
1019
- overwrite: bool,
1020
1925
  pre_read_buffer: bytes,
1021
1926
  cloud_provider_session: requests.Session,
1022
- ):
1927
+ ) -> None:
1023
1928
  """
1024
1929
  Performs resumable upload on GCP: https://cloud.google.com/storage/docs/performing-resumable-uploads
1025
1930
  """
@@ -1047,14 +1952,20 @@ class FilesExt(files.FilesAPI):
  # On the contrary, in multipart upload we can decide to complete upload *after*
  # last chunk has been sent.

- body: dict = {"path": target_path, "session_token": session_token}
+ body: dict = {"path": ctx.target_path, "session_token": session_token}

  headers = {"Content-Type": "application/json"}

- # Method _api.do() takes care of retrying and will raise an exception in case of failure.
- resumable_upload_url_response = self._api.do(
- "POST", "/api/2.0/fs/create-resumable-upload-url", headers=headers, body=body
- )
+ try:
+ # Method _api.do() takes care of retrying and will raise an exception in case of failure.
+ resumable_upload_url_response = self._api.do(
+ "POST", "/api/2.0/fs/create-resumable-upload-url", headers=headers, body=body
+ )
+ except PermissionDenied as e:
+ if self._is_presigned_urls_disabled_error(e):
+ raise FallbackToUploadUsingFilesApi(pre_read_buffer, "Presigned URLs are disabled")
+ else:
+ raise e from None

  resumable_upload_url_node = resumable_upload_url_response.get("resumable_upload_url")
  if not resumable_upload_url_node:
@@ -1069,7 +1980,7 @@ class FilesExt(files.FilesAPI):
  try:
  # We will buffer this many bytes: one chunk + read-ahead block.
  # Note buffer may contain more data initially (from pre_read_buffer).
- min_buffer_size = self._config.multipart_upload_chunk_size + self._multipart_upload_read_ahead_bytes
+ min_buffer_size = ctx.part_size + self._multipart_upload_read_ahead_bytes

  buffer = pre_read_buffer

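The read-ahead block exists so the uploader can tell, before sending a chunk, whether it is the last one (the final chunk of a GCP resumable upload must state the total size). A small, self-contained sketch of that decision:

from io import BytesIO


def split_chunks(stream: BytesIO, part_size: int, read_ahead: int) -> list:
    chunks = []
    buffer = b""
    while True:
        # Keep at least one part plus the read-ahead block in memory.
        buffer += stream.read(part_size + read_ahead - len(buffer))
        if len(buffer) <= part_size:
            chunks.append((buffer, True))   # last chunk: total size is now known
            return chunks
        chunks.append((buffer[:part_size], False))  # more data follows
        buffer = buffer[part_size:]


parts = split_chunks(BytesIO(b"a" * 10), part_size=4, read_ahead=1)
assert [(len(c), last) for c, last in parts] == [(4, False), (4, False), (2, True)]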
@@ -1094,7 +2005,7 @@ class FilesExt(files.FilesAPI):
  file_size = chunk_offset + actual_chunk_length
  else:
  # More chunks expected, let's upload current chunk (excluding read-ahead block).
- actual_chunk_length = self._config.multipart_upload_chunk_size
+ actual_chunk_length = ctx.part_size
  file_size = "*"

  headers: dict = {"Content-Type": "application/octet-stream"}
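For reference, GCP resumable uploads communicate progress through the Content-Range header: intermediate chunks report an unknown total ("*"), while the final chunk states the full object size. A sketch of building that header value:

def content_range_header(chunk_offset: int, chunk_length: int, total_size: "int | str") -> str:
    # total_size is the real size for the final chunk, or "*" while it is still unknown.
    if chunk_length == 0:
        return f"bytes */{total_size}"  # status probe / empty object
    last_byte = chunk_offset + chunk_length - 1
    return f"bytes {chunk_offset}-{last_byte}/{total_size}"


assert content_range_header(0, 1024, "*") == "bytes 0-1023/*"
assert content_range_header(1024, 512, 1536) == "bytes 1024-1535/1536"
assert content_range_header(0, 0, "*") == "bytes */*"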
@@ -1113,7 +2024,7 @@ class FilesExt(files.FilesAPI):
  resumable_upload_url,
  headers={"Content-Range": "bytes */*"},
  data=b"",
- timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds,
+ timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds,
  )

  try:
@@ -1128,7 +2039,7 @@ class FilesExt(files.FilesAPI):
  resumable_upload_url,
  headers=headers,
  data=BytesIO(buffer[:actual_chunk_length]),
- timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds,
+ timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds,
  )

  # https://cloud.google.com/storage/docs/performing-resumable-uploads#resume-upload
@@ -1136,9 +2047,8 @@ class FilesExt(files.FilesAPI):
  # a 503 or 500 response, then you need to resume the interrupted upload from where it left off.

  # Let's follow that for all potentially retryable status codes.
- # Together with the catch block below we replicate the logic in _retry_databricks_idempotent_operation().
  if upload_response.status_code in self._RETRYABLE_STATUS_CODES:
- if retry_count < self._config.multipart_upload_max_retries:
+ if retry_count < self._config.files_ext_multipart_upload_max_retries:
  retry_count += 1
  # let original upload_response be handled as an error
  upload_response = retrieve_upload_status() or upload_response
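The retry path above does not blindly resend the chunk: it first asks the server how much it already received. A schematic of that probe-then-retry loop, detached from any real HTTP calls:

from typing import Callable, Optional


def upload_with_status_probe(
    send_chunk: Callable[[], Optional[int]],
    probe_committed_offset: Callable[[], Optional[int]],
    max_retries: int,
) -> int:
    # send_chunk returns the committed offset on success, or None on a retryable failure.
    retries = 0
    while True:
        committed = send_chunk()
        if committed is not None:
            return committed
        if retries >= max_retries:
            raise RuntimeError("upload failed after retries")
        retries += 1
        probed = probe_committed_offset()
        if probed is not None:
            # The server did receive the chunk even though the request appeared to fail.
            return probed


attempts = iter([None, 100])
assert upload_with_status_probe(lambda: next(attempts), lambda: None, max_retries=2) == 100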
@@ -1148,7 +2058,10 @@ class FilesExt(files.FilesAPI):

  except RequestException as e:
  # Let's do the same for retryable network errors.
- if _BaseClient._is_retryable(e) and retry_count < self._config.multipart_upload_max_retries:
+ if (
+ _BaseClient._is_retryable(e)
+ and retry_count < self._config.files_ext_multipart_upload_max_retries
+ ):
  retry_count += 1
  upload_response = retrieve_upload_status()
  if not upload_response:
@@ -1194,7 +2107,7 @@ class FilesExt(files.FilesAPI):
  uploaded_bytes_count = next_chunk_offset - chunk_offset
  chunk_offset = next_chunk_offset

- elif upload_response.status_code == 412 and not overwrite:
+ elif upload_response.status_code == 412 and not ctx.overwrite:
  # Assuming this is only possible reason
  # Full message in this case: "At least one of the pre-conditions you specified did not hold."
  raise AlreadyExists("The file being created already exists.")
@@ -1227,19 +2140,38 @@ class FilesExt(files.FilesAPI):
  else:
  raise ValueError(f"Cannot parse response header: Range: {range_string}")

- def _get_url_expire_time(self):
- """Generates expiration time and save it in the required format."""
- current_time = datetime.datetime.now(datetime.timezone.utc)
- expire_time = current_time + self._config.multipart_upload_url_expiration_duration
+ def _get_rfc339_timestamp_with_future_offset(self, base_time: datetime.datetime, offset: timedelta) -> str:
+ """Generates an offset timestamp in an RFC3339 format suitable for URL generation"""
+ offset_timestamp = base_time + offset
  # From Google Protobuf doc:
  # In JSON format, the Timestamp type is encoded as a string in the
  # * [RFC 3339](https://www.ietf.org/rfc/rfc3339.txt) format. That is, the
  # * format is "{year}-{month}-{day}T{hour}:{min}:{sec}[.{frac_sec}]Z"
- return expire_time.strftime("%Y-%m-%dT%H:%M:%SZ")
+ return offset_timestamp.strftime("%Y-%m-%dT%H:%M:%SZ")
+
+ def _get_upload_url_expire_time(self) -> str:
+ """Generates expiration time in the required format."""
+ current_time = datetime.datetime.now(datetime.timezone.utc)
+ return self._get_rfc339_timestamp_with_future_offset(
+ current_time, self._config.files_ext_multipart_upload_url_expiration_duration
+ )
+
+ def _get_download_url_expire_time(self) -> str:
+ """Generates expiration time in the required format."""
+ current_time = datetime.datetime.now(datetime.timezone.utc)
+ return self._get_rfc339_timestamp_with_future_offset(
+ current_time, self._config.files_ext_presigned_download_url_expiration_duration
+ )

- def _abort_multipart_upload(self, target_path: str, session_token: str, cloud_provider_session: requests.Session):
+ def _abort_multipart_upload(
+ self, ctx: _UploadContext, session_token: str, cloud_provider_session: requests.Session
+ ) -> None:
  """Aborts ongoing multipart upload session to clean up incomplete file."""
- body: dict = {"path": target_path, "session_token": session_token, "expire_time": self._get_url_expire_time()}
+ body: dict = {
+ "path": ctx.target_path,
+ "session_token": session_token,
+ "expire_time": self._get_upload_url_expire_time(),
+ }

  headers = {"Content-Type": "application/json"}

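For illustration, producing an RFC 3339 timestamp a fixed duration in the future is a thin wrapper around strftime:

import datetime


def rfc3339_in(offset: datetime.timedelta) -> str:
    # RFC 3339 / protobuf JSON form: "{year}-{month}-{day}T{hour}:{min}:{sec}Z"
    expire_at = datetime.datetime.now(datetime.timezone.utc) + offset
    return expire_at.strftime("%Y-%m-%dT%H:%M:%SZ")


stamp = rfc3339_in(datetime.timedelta(hours=1))
assert stamp.endswith("Z") and "T" in stamp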
@@ -1254,13 +2186,13 @@ class FilesExt(files.FilesAPI):
  for h in required_headers:
  headers[h["name"]] = h["value"]

- def perform():
+ def perform() -> requests.Response:
  return cloud_provider_session.request(
  "DELETE",
  abort_url,
  headers=headers,
  data=b"",
- timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds,
+ timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds,
  )

  abort_response = self._retry_cloud_idempotent_operation(perform)
@@ -1270,19 +2202,19 @@ class FilesExt(files.FilesAPI):

  def _abort_resumable_upload(
  self, resumable_upload_url: str, required_headers: list, cloud_provider_session: requests.Session
- ):
+ ) -> None:
  """Aborts ongoing resumable upload session to clean up incomplete file."""
  headers: dict = {}
  for h in required_headers:
  headers[h["name"]] = h["value"]

- def perform():
+ def perform() -> requests.Response:
  return cloud_provider_session.request(
  "DELETE",
  resumable_upload_url,
  headers=headers,
  data=b"",
- timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds,
+ timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds,
  )

  abort_response = self._retry_cloud_idempotent_operation(perform)
@@ -1290,7 +2222,7 @@ class FilesExt(files.FilesAPI):
  if abort_response.status_code not in (200, 201):
  raise ValueError(abort_response)

- def _create_cloud_provider_session(self):
+ def _create_cloud_provider_session(self) -> requests.Session:
  """Creates a separate session which does not inherit auth headers from BaseClient session."""
  session = requests.Session()

@@ -1304,7 +2236,7 @@ class FilesExt(files.FilesAPI):
  return session

  def _retry_cloud_idempotent_operation(
- self, operation: Callable[[], requests.Response], before_retry: Callable = None
+ self, operation: Callable[[], requests.Response], before_retry: Optional[Callable] = None
  ) -> requests.Response:
  """Perform given idempotent operation with necessary retries for requests to non Databricks APIs.
  For cloud APIs, we will retry on network errors and on server response codes.
@@ -1337,10 +2269,14 @@ class FilesExt(files.FilesAPI):
  # where we believe request didn't reach the server
  is_retryable=extended_is_retryable,
  before_retry=before_retry,
+ clock=self._config.clock,
  )(delegate)()

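Passing the configured clock into the retry helper makes backoff sleeps injectable (and instantaneous in tests). A generic sketch of a clock-parameterized retry wrapper; this is not the SDK's own retried helper:

import time
from typing import Callable, TypeVar

T = TypeVar("T")


class Clock:
    def sleep(self, seconds: float) -> None:
        time.sleep(seconds)


def with_retries(fn: Callable[[], T], attempts: int, clock: Clock, delay: float = 0.1) -> T:
    for attempt in range(attempts):
        try:
            return fn()
        except ConnectionError:
            if attempt == attempts - 1:
                raise
            clock.sleep(delay * (2 ** attempt))  # exponential backoff via the injected clock
    raise AssertionError("unreachable")


class FakeClock(Clock):
    def __init__(self) -> None:
        self.slept = 0.0

    def sleep(self, seconds: float) -> None:
        self.slept += seconds  # no real waiting in tests


calls = iter([ConnectionError(), 42])


def flaky() -> int:
    result = next(calls)
    if isinstance(result, Exception):
        raise result
    return result


fake = FakeClock()
assert with_retries(flaky, attempts=3, clock=fake) == 42
assert fake.slept == 0.1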
  def _open_download_stream(
- self, file_path: str, start_byte_offset: int, if_unmodified_since_timestamp: Optional[str] = None
+ self,
+ file_path: str,
+ start_byte_offset: int,
+ if_unmodified_since_timestamp: Optional[str] = None,
  ) -> DownloadResponse:
  """Opens a download stream from given offset, performing necessary retries."""
  headers = {
@@ -1350,7 +2286,7 @@ class FilesExt(files.FilesAPI):
  if start_byte_offset and not if_unmodified_since_timestamp:
  raise Exception("if_unmodified_since_timestamp is required if start_byte_offset is specified")

- if start_byte_offset:
+ if start_byte_offset > 0:
  headers["Range"] = f"bytes={start_byte_offset}-"

  if if_unmodified_since_timestamp:
@@ -1361,6 +2297,23 @@ class FilesExt(files.FilesAPI):
  "content-type",
  "last-modified",
  ]
+
+ result = self._init_download_response_mode_csp_with_fallback(file_path, headers, response_headers)
+
+ if not isinstance(result.contents, _StreamingResponse):
+ raise Exception(
+ "Internal error: response contents is of unexpected type: " + type(result.contents).__name__
+ )
+
+ return result
+
+ def _init_download_response_files_api(
+ self, file_path: str, headers: dict[str, str], response_headers: list[str]
+ ) -> DownloadResponse:
+ """
+ Initiates a download response using the Files API.
+ """
+
  # Method _api.do() takes care of retrying and will raise an exception in case of failure.
  res = self._api.do(
  "GET",
@@ -1369,22 +2322,119 @@ class FilesExt(files.FilesAPI):
  response_headers=response_headers,
  raw=True,
  )
+ return DownloadResponse.from_dict(res)

- result = DownloadResponse.from_dict(res)
- if not isinstance(result.contents, _StreamingResponse):
- raise Exception(
- "Internal error: response contents is of unexpected type: " + type(result.contents).__name__
+ def _create_download_url(self, file_path: str) -> CreateDownloadUrlResponse:
+ """
+ Creates a presigned download URL using the CSP presigned URL API.
+
+ Wrapped in similar retry logic to the internal API.do call:
+ 1. Call _.api.do to obtain the presigned URL
+ 2. Return the presigned URL
+ """
+
+ # Method _api.do() takes care of retrying and will raise an exception in case of failure.
+ try:
+ raw_response = self._api.do(
+ "POST",
+ f"/api/2.0/fs/create-download-url",
+ query={
+ "path": file_path,
+ "expire_time": self._get_download_url_expire_time(),
+ },
  )

- return result
+ return CreateDownloadUrlResponse.from_dict(raw_response)
+ except PermissionDenied as e:
+ if self._is_presigned_urls_disabled_error(e):
+ raise FallbackToDownloadUsingFilesApi(f"Presigned URLs are disabled")
+ else:
+ raise e from None
+
+ def _init_download_response_presigned_api(self, file_path: str, added_headers: dict[str, str]) -> DownloadResponse:
+ """
+ Initiates a download response using the CSP presigned URL API.
+
+ Wrapped in similar retry logic to the internal API.do call:
+ 1. Call _.api.do to obtain the presigned URL
+ 2. Attempt to establish a streaming connection via the presigned URL
+ 3. Construct a StreamingResponse from the presigned URL
+ """
+
+ url_and_headers = self._create_download_url(file_path)
+ cloud_provider_session = self._create_cloud_provider_session()
+
+ header_overlap = added_headers.keys() & url_and_headers.headers.keys()
+ if header_overlap:
+ raise ValueError(
+ f"Provided headers overlap with required headers from the CSP API bundle: {header_overlap}"
+ )
+
+ merged_headers = {**added_headers, **url_and_headers.headers}
+
+ def perform() -> requests.Response:
+ return cloud_provider_session.request(
+ "GET",
+ url_and_headers.url,
+ headers=merged_headers,
+ timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds,
+ stream=True,
+ )
+
+ csp_response: _RawResponse = self._retry_cloud_idempotent_operation(perform)

- def _wrap_stream(self, file_path: str, download_response: DownloadResponse):
+ # Mapping the error if the response is not successful.
+ if csp_response.status_code in (200, 201, 206):
+ resp = DownloadResponse(
+ content_length=int(csp_response.headers.get("content-length")),
+ content_type=csp_response.headers.get("content-type"),
+ last_modified=csp_response.headers.get("last-modified"),
+ contents=_StreamingResponse(csp_response, self._config.files_ext_client_download_streaming_chunk_size),
+ )
+ return resp
+ elif csp_response.status_code == 403:
+ # We got 403 failure when downloading the file. This might happen due to Azure firewall enabled for the customer bucket.
+ # Let's fallback to using Files API which might be allowlisted to download.
+ raise FallbackToDownloadUsingFilesApi(f"Direct download forbidden: {csp_response.content}")
+ else:
+ message = (
+ f"Unsuccessful download. Response status: {csp_response.status_code}, body: {csp_response.content}"
+ )
+ _LOG.warning(message)
+ mapped_error = _error_mapper(csp_response, {})
+ raise mapped_error or ValueError(message)
+
+ def _init_download_response_mode_csp_with_fallback(
+ self, file_path: str, headers: dict[str, str], response_headers: list[str]
+ ) -> DownloadResponse:
+ """
+ Initiates a download response using the CSP presigned URL API or the Files API, depending on the configuration.
+ If the CSP presigned download API is enabled, it will attempt to use that first.
+ If the CSP API call fails, it will fall back to the Files API.
+ If the CSP presigned download API is disabled, it will use the Files API directly.
+ """
+
+ try:
+ _LOG.debug(f"Attempting download of {file_path} via CSP APIs")
+ return self._init_download_response_presigned_api(file_path, headers)
+ except FallbackToDownloadUsingFilesApi as e:
+ _LOG.info(f"Falling back to download via Files API: {e}")
+ _LOG.debug(f"Attempt via CSP APIs for {file_path} failed. Falling back to download via Files API")
+ ret = self._init_download_response_files_api(file_path, headers, response_headers)
+ return ret
+
+ def _wrap_stream(
+ self,
+ file_path: str,
+ download_response: DownloadResponse,
+ start_byte_offset: int = 0,
+ ) -> "_ResilientResponse":
  underlying_response = _ResilientIterator._extract_raw_response(download_response)
  return _ResilientResponse(
  self,
  file_path,
  download_response.last_modified,
- offset=0,
+ offset=start_byte_offset,
  underlying_response=underlying_response,
  )

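Taken together, the download path above is a two-tier strategy: try the presigned URL first, and fall back to the Files API when the dedicated signal exception is raised. A compressed sketch of that chain with stand-in functions:

class FallbackToFilesApiDownload(Exception):
    pass


def download_via_presigned_url(path: str) -> bytes:
    # Stand-in: pretend presigned downloads are disabled or blocked by a firewall.
    raise FallbackToFilesApiDownload("Presigned URLs are disabled")


def download_via_files_api(path: str) -> bytes:
    return b"contents of " + path.encode()  # stand-in for streaming through the Files API


def download(path: str) -> bytes:
    try:
        return download_via_presigned_url(path)
    except FallbackToFilesApiDownload:
        return download_via_files_api(path)


assert download("/Volumes/cat/sch/vol/x.bin") == b"contents of /Volumes/cat/sch/vol/x.bin"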
@@ -1398,29 +2448,24 @@ class _ResilientResponse(_RawResponse):
  file_last_modified: str,
  offset: int,
  underlying_response: _RawResponse,
- ):
+ ) -> None:
  self.api = api
  self.file_path = file_path
  self.underlying_response = underlying_response
  self.offset = offset
  self.file_last_modified = file_last_modified

- def iter_content(self, chunk_size=1, decode_unicode=False):
+ def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False) -> Iterator[bytes]:
  if decode_unicode:
  raise ValueError("Decode unicode is not supported")

  iterator = self.underlying_response.iter_content(chunk_size=chunk_size, decode_unicode=False)
  self.iterator = _ResilientIterator(
- iterator,
- self.file_path,
- self.file_last_modified,
- self.offset,
- self.api,
- chunk_size,
+ iterator, self.file_path, self.file_last_modified, self.offset, self.api, chunk_size
  )
  return self.iterator

- def close(self):
+ def close(self) -> None:
  self.iterator.close()


@@ -1432,18 +2477,18 @@ class _ResilientIterator(Iterator):
  def _extract_raw_response(
  download_response: DownloadResponse,
  ) -> _RawResponse:
- streaming_response: _StreamingResponse = download_response.contents # this is an instance of _StreamingResponse
+ streaming_response: _StreamingResponse = download_response.contents
  return streaming_response._response

  def __init__(
  self,
- underlying_iterator,
+ underlying_iterator: Iterator[bytes],
  file_path: str,
  file_last_modified: str,
  offset: int,
  api: FilesExt,
  chunk_size: int,
- ):
+ ) -> None:
  self._underlying_iterator = underlying_iterator
  self._api = api
  self._file_path = file_path
@@ -1459,13 +2504,13 @@ class _ResilientIterator(Iterator):
  self._closed: bool = False

  def _should_recover(self) -> bool:
- if self._total_recovers_count == self._api._config.files_api_client_download_max_total_recovers:
+ if self._total_recovers_count == self._api._config.files_ext_client_download_max_total_recovers:
  _LOG.debug("Total recovers limit exceeded")
  return False
  if (
- self._api._config.files_api_client_download_max_total_recovers_without_progressing is not None
+ self._api._config.files_ext_client_download_max_total_recovers_without_progressing is not None
  and self._recovers_without_progressing_count
- >= self._api._config.files_api_client_download_max_total_recovers_without_progressing
+ >= self._api._config.files_ext_client_download_max_total_recovers_without_progressing
  ):
  _LOG.debug("No progression recovers limit exceeded")
  return False
@@ -1481,7 +2526,7 @@ class _ResilientIterator(Iterator):
  try:
  self._underlying_iterator.close()

- _LOG.debug("Trying to recover from offset " + str(self._offset))
+ _LOG.debug(f"Trying to recover from offset {self._offset}")

  # following call includes all the required network retries
  downloadResponse = self._api._open_download_stream(self._file_path, self._offset, self._file_last_modified)
@@ -1494,7 +2539,7 @@ class _ResilientIterator(Iterator):
  except:
  return False # recover failed, rethrow original exception

- def __next__(self):
+ def __next__(self) -> bytes:
  if self._closed:
  # following _BaseClient
  raise ValueError("I/O operation on closed file")
@@ -1514,6 +2559,6 @@ class _ResilientIterator(Iterator):
  if not self._recover():
  raise

- def close(self):
+ def close(self) -> None:
  self._underlying_iterator.close()
  self._closed = True
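The `_ResilientIterator` above tracks how many bytes it has already handed to the caller so that a dropped connection can be reopened from exactly that offset. A toy version of the same recovery idea over an in-memory source:

from typing import Callable, Iterator


def resilient_chunks(open_at: Callable[[int], Iterator[bytes]], max_recovers: int) -> Iterator[bytes]:
    """Yield chunks; if the underlying iterator fails, reopen it from the delivered offset."""
    offset = 0
    recovers = 0
    stream = open_at(offset)
    while True:
        try:
            chunk = next(stream)
        except StopIteration:
            return
        except Exception:
            if recovers >= max_recovers:
                raise
            recovers += 1
            stream = open_at(offset)  # recover: reopen from where the caller last got data
            continue
        offset += len(chunk)
        yield chunk


data = b"abcdefghij"
failures = {4}  # fail once when asked to continue from offset 4


def open_at(offset: int) -> Iterator[bytes]:
    pos = offset
    while pos < len(data):
        if pos in failures:
            failures.remove(pos)
            raise ConnectionError("stream dropped")
        yield data[pos : pos + 2]
        pos += 2


assert b"".join(resilient_chunks(open_at, max_recovers=1)) == data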