azure-storage-blob 12.21.0b1__py3-none-any.whl → 12.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. azure/storage/blob/__init__.py +19 -18
  2. azure/storage/blob/_blob_client.py +470 -1555
  3. azure/storage/blob/_blob_client_helpers.py +1242 -0
  4. azure/storage/blob/_blob_service_client.py +93 -112
  5. azure/storage/blob/_blob_service_client_helpers.py +27 -0
  6. azure/storage/blob/_container_client.py +176 -377
  7. azure/storage/blob/_container_client_helpers.py +266 -0
  8. azure/storage/blob/_deserialize.py +68 -44
  9. azure/storage/blob/_download.py +375 -241
  10. azure/storage/blob/_encryption.py +14 -7
  11. azure/storage/blob/_generated/_azure_blob_storage.py +2 -1
  12. azure/storage/blob/_generated/_serialization.py +2 -0
  13. azure/storage/blob/_generated/aio/_azure_blob_storage.py +2 -1
  14. azure/storage/blob/_generated/aio/operations/_append_blob_operations.py +1 -7
  15. azure/storage/blob/_generated/aio/operations/_blob_operations.py +21 -47
  16. azure/storage/blob/_generated/aio/operations/_block_blob_operations.py +2 -10
  17. azure/storage/blob/_generated/aio/operations/_container_operations.py +13 -26
  18. azure/storage/blob/_generated/aio/operations/_page_blob_operations.py +3 -14
  19. azure/storage/blob/_generated/aio/operations/_service_operations.py +14 -17
  20. azure/storage/blob/_generated/operations/_append_blob_operations.py +1 -7
  21. azure/storage/blob/_generated/operations/_blob_operations.py +21 -47
  22. azure/storage/blob/_generated/operations/_block_blob_operations.py +2 -10
  23. azure/storage/blob/_generated/operations/_container_operations.py +13 -26
  24. azure/storage/blob/_generated/operations/_page_blob_operations.py +3 -14
  25. azure/storage/blob/_generated/operations/_service_operations.py +14 -17
  26. azure/storage/blob/_generated/py.typed +1 -0
  27. azure/storage/blob/_lease.py +52 -63
  28. azure/storage/blob/_list_blobs_helper.py +129 -135
  29. azure/storage/blob/_models.py +480 -277
  30. azure/storage/blob/_quick_query_helper.py +30 -31
  31. azure/storage/blob/_serialize.py +39 -56
  32. azure/storage/blob/_shared/avro/datafile.py +1 -1
  33. azure/storage/blob/_shared/avro/datafile_async.py +1 -1
  34. azure/storage/blob/_shared/base_client.py +3 -1
  35. azure/storage/blob/_shared/base_client_async.py +1 -1
  36. azure/storage/blob/_shared/policies.py +16 -15
  37. azure/storage/blob/_shared/policies_async.py +21 -6
  38. azure/storage/blob/_shared/response_handlers.py +6 -2
  39. azure/storage/blob/_shared/shared_access_signature.py +21 -3
  40. azure/storage/blob/_shared/uploads.py +1 -1
  41. azure/storage/blob/_shared/uploads_async.py +1 -1
  42. azure/storage/blob/_shared_access_signature.py +110 -52
  43. azure/storage/blob/_upload_helpers.py +75 -68
  44. azure/storage/blob/_version.py +1 -1
  45. azure/storage/blob/aio/__init__.py +19 -11
  46. azure/storage/blob/aio/_blob_client_async.py +554 -301
  47. azure/storage/blob/aio/_blob_service_client_async.py +148 -97
  48. azure/storage/blob/aio/_container_client_async.py +289 -140
  49. azure/storage/blob/aio/_download_async.py +485 -337
  50. azure/storage/blob/aio/_lease_async.py +61 -60
  51. azure/storage/blob/aio/_list_blobs_helper.py +94 -96
  52. azure/storage/blob/aio/_models.py +60 -38
  53. azure/storage/blob/aio/_upload_helpers.py +75 -66
  54. {azure_storage_blob-12.21.0b1.dist-info → azure_storage_blob-12.23.0.dist-info}/METADATA +7 -7
  55. azure_storage_blob-12.23.0.dist-info/RECORD +84 -0
  56. {azure_storage_blob-12.21.0b1.dist-info → azure_storage_blob-12.23.0.dist-info}/WHEEL +1 -1
  57. azure/storage/blob/_generated/_vendor.py +0 -16
  58. azure_storage_blob-12.21.0b1.dist-info/RECORD +0 -81
  59. {azure_storage_blob-12.21.0b1.dist-info → azure_storage_blob-12.23.0.dist-info}/LICENSE +0 -0
  60. {azure_storage_blob-12.21.0b1.dist-info → azure_storage_blob-12.23.0.dist-info}/top_level.txt +0 -0
@@ -5,37 +5,39 @@
5
5
  # --------------------------------------------------------------------------
6
6
 
7
7
  from io import BytesIO
8
- from typing import Union, Iterable, IO # pylint: disable=unused-import
8
+ from typing import Any, Dict, Generator, IO, Iterable, Optional, Type, Union, TYPE_CHECKING
9
9
 
10
- from ._shared.avro.datafile import DataFileReader
11
10
  from ._shared.avro.avro_io import DatumReader
11
+ from ._shared.avro.datafile import DataFileReader
12
+
13
+ if TYPE_CHECKING:
14
+ from ._models import BlobQueryError
12
15
 
13
16
 
14
17
  class BlobQueryReader(object): # pylint: disable=too-many-instance-attributes
15
- """A streaming object to read query results.
16
-
17
- :ivar str name:
18
- The name of the blob being quered.
19
- :ivar str container:
20
- The name of the container where the blob is.
21
- :ivar dict response_headers:
22
- The response_headers of the quick query request.
23
- :ivar bytes record_delimiter:
24
- The delimiter used to separate lines, or records with the data. The `records`
25
- method will return these lines via a generator.
26
- """
18
+ """A streaming object to read query results."""
19
+
20
+ name: str
21
+ """The name of the blob being quered."""
22
+ container: str
23
+ """The name of the container where the blob is."""
24
+ response_headers: Dict[str, Any]
25
+ """The response_headers of the quick query request."""
26
+ record_delimiter: str
27
+ """The delimiter used to separate lines, or records with the data. The `records`
28
+ method will return these lines via a generator."""
27
29
 
28
30
  def __init__(
29
31
  self,
30
- name=None,
31
- container=None,
32
- errors=None,
33
- record_delimiter='\n',
34
- encoding=None,
35
- headers=None,
36
- response=None,
37
- error_cls=None,
38
- ):
32
+ name: str = None, # type: ignore [assignment]
33
+ container: str = None, # type: ignore [assignment]
34
+ errors: Any = None,
35
+ record_delimiter: str = '\n',
36
+ encoding: Optional[str] = None,
37
+ headers: Dict[str, Any] = None, # type: ignore [assignment]
38
+ response: Any = None,
39
+ error_cls: Type["BlobQueryError"] = None, # type: ignore [assignment]
40
+ ) -> None:
39
41
  self.name = name
40
42
  self.container = container
41
43
  self.response_headers = headers
@@ -51,7 +53,7 @@ class BlobQueryReader(object): # pylint: disable=too-many-instance-attributes
51
53
  def __len__(self):
52
54
  return self._size
53
55
 
54
- def _process_record(self, result):
56
+ def _process_record(self, result: Dict[str, Any]) -> Optional[bytes]:
55
57
  self._size = result.get('totalBytes', self._size)
56
58
  self._bytes_processed = result.get('bytesScanned', self._bytes_processed)
57
59
  if 'data' in result:
@@ -67,7 +69,7 @@ class BlobQueryReader(object): # pylint: disable=too-many-instance-attributes
67
69
  self._errors(error)
68
70
  return None
69
71
 
70
- def _iter_stream(self):
72
+ def _iter_stream(self) -> Generator[bytes, None, None]:
71
73
  if self._first_result is not None:
72
74
  yield self._first_result
73
75
  for next_result in self._parsed_results:
@@ -75,8 +77,7 @@ class BlobQueryReader(object): # pylint: disable=too-many-instance-attributes
75
77
  if processed_result is not None:
76
78
  yield processed_result
77
79
 
78
- def readall(self):
79
- # type: () -> Union[bytes, str]
80
+ def readall(self) -> Union[bytes, str]:
80
81
  """Return all query results.
81
82
 
82
83
  This operation is blocking until all data is downloaded.
@@ -93,8 +94,7 @@ class BlobQueryReader(object): # pylint: disable=too-many-instance-attributes
93
94
  return data.decode(self._encoding)
94
95
  return data
95
96
 
96
- def readinto(self, stream):
97
- # type: (IO) -> None
97
+ def readinto(self, stream: IO) -> None:
98
98
  """Download the query result to a stream.
99
99
 
100
100
  :param IO stream:
@@ -105,8 +105,7 @@ class BlobQueryReader(object): # pylint: disable=too-many-instance-attributes
105
105
  for record in self._iter_stream():
106
106
  stream.write(record)
107
107
 
108
- def records(self):
109
- # type: () -> Iterable[Union[bytes, str]]
108
+ def records(self) -> Iterable[Union[bytes, str]]:
110
109
  """Returns a record generator for the query result.
111
110
 
112
111
  Records will be returned line by line.
@@ -3,9 +3,7 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
- from typing import ( # pylint: disable=unused-import
7
- Any, Dict, Optional, Tuple, Union,
8
- TYPE_CHECKING)
6
+ from typing import Any, cast, Dict, Optional, Tuple, Union, TYPE_CHECKING
9
7
 
10
8
  try:
11
9
  from urllib.parse import quote
@@ -14,23 +12,22 @@ except ImportError:
14
12
 
15
13
  from azure.core import MatchConditions
16
14
 
17
- from ._models import (
18
- ContainerEncryptionScope,
19
- DelimitedJsonDialect)
20
15
  from ._generated.models import (
21
- ModifiedAccessConditions,
22
- SourceModifiedAccessConditions,
23
- CpkScopeInfo,
16
+ ArrowConfiguration,
17
+ BlobTag,
18
+ BlobTags,
24
19
  ContainerCpkScopeInfo,
25
- QueryFormat,
26
- QuerySerialization,
20
+ CpkScopeInfo,
27
21
  DelimitedTextConfiguration,
28
22
  JsonTextConfiguration,
29
- ArrowConfiguration,
23
+ LeaseAccessConditions,
24
+ ModifiedAccessConditions,
25
+ QueryFormat,
30
26
  QueryFormatType,
31
- BlobTag,
32
- BlobTags, LeaseAccessConditions
27
+ QuerySerialization,
28
+ SourceModifiedAccessConditions
33
29
  )
30
+ from ._models import ContainerEncryptionScope, DelimitedJsonDialect
34
31
 
35
32
  if TYPE_CHECKING:
36
33
  from ._lease import BlobLeaseClient
@@ -59,11 +56,15 @@ _SUPPORTED_API_VERSIONS = [
59
56
  '2023-11-03',
60
57
  '2024-05-04',
61
58
  '2024-08-04',
59
+ '2024-11-04',
62
60
  ]
63
61
 
64
62
 
65
- def _get_match_headers(kwargs, match_param, etag_param):
66
- # type: (Dict[str, Any], str, str) -> Tuple(Dict[str, Any], Optional[str], Optional[str])
63
+ def _get_match_headers(
64
+ kwargs: Dict[str, Any],
65
+ match_param: str,
66
+ etag_param: str
67
+ ) -> Tuple[Optional[str], Optional[Any]]:
67
68
  if_match = None
68
69
  if_none_match = None
69
70
  match_condition = kwargs.pop(match_param, None)
@@ -87,8 +88,7 @@ def _get_match_headers(kwargs, match_param, etag_param):
87
88
  return if_match, if_none_match
88
89
 
89
90
 
90
- def get_access_conditions(lease):
91
- # type: (Optional[Union[BlobLeaseClient, str]]) -> Union[LeaseAccessConditions, None]
91
+ def get_access_conditions(lease: Optional[Union["BlobLeaseClient", str]]) -> Optional[LeaseAccessConditions]:
92
92
  try:
93
93
  lease_id = lease.id # type: ignore
94
94
  except AttributeError:
@@ -96,8 +96,7 @@ def get_access_conditions(lease):
96
96
  return LeaseAccessConditions(lease_id=lease_id) if lease_id else None
97
97
 
98
98
 
99
- def get_modify_conditions(kwargs):
100
- # type: (Dict[str, Any]) -> ModifiedAccessConditions
99
+ def get_modify_conditions(kwargs: Dict[str, Any]) -> ModifiedAccessConditions:
101
100
  if_match, if_none_match = _get_match_headers(kwargs, 'match_condition', 'etag')
102
101
  return ModifiedAccessConditions(
103
102
  if_modified_since=kwargs.pop('if_modified_since', None),
@@ -108,8 +107,7 @@ def get_modify_conditions(kwargs):
108
107
  )
109
108
 
110
109
 
111
- def get_source_conditions(kwargs):
112
- # type: (Dict[str, Any]) -> SourceModifiedAccessConditions
110
+ def get_source_conditions(kwargs: Dict[str, Any]) -> SourceModifiedAccessConditions:
113
111
  if_match, if_none_match = _get_match_headers(kwargs, 'source_match_condition', 'source_etag')
114
112
  return SourceModifiedAccessConditions(
115
113
  source_if_modified_since=kwargs.pop('source_if_modified_since', None),
@@ -120,15 +118,13 @@ def get_source_conditions(kwargs):
120
118
  )
121
119
 
122
120
 
123
- def get_cpk_scope_info(kwargs):
124
- # type: (Dict[str, Any]) -> CpkScopeInfo
121
+ def get_cpk_scope_info(kwargs: Dict[str, Any]) -> Optional[CpkScopeInfo]:
125
122
  if 'encryption_scope' in kwargs:
126
123
  return CpkScopeInfo(encryption_scope=kwargs.pop('encryption_scope'))
127
124
  return None
128
125
 
129
126
 
130
- def get_container_cpk_scope_info(kwargs):
131
- # type: (Dict[str, Any]) -> ContainerCpkScopeInfo
127
+ def get_container_cpk_scope_info(kwargs: Dict[str, Any]) -> Optional[ContainerCpkScopeInfo]:
132
128
  encryption_scope = kwargs.pop('container_encryption_scope', None)
133
129
  if encryption_scope:
134
130
  if isinstance(encryption_scope, ContainerEncryptionScope):
@@ -145,22 +141,19 @@ def get_container_cpk_scope_info(kwargs):
145
141
  return None
146
142
 
147
143
 
148
- def get_api_version(kwargs):
149
- # type: (Dict[str, Any]) -> str
144
+ def get_api_version(kwargs: Dict[str, Any]) -> str:
150
145
  api_version = kwargs.get('api_version', None)
151
146
  if api_version and api_version not in _SUPPORTED_API_VERSIONS:
152
147
  versions = '\n'.join(_SUPPORTED_API_VERSIONS)
153
148
  raise ValueError(f"Unsupported API version '{api_version}'. Please select from:\n{versions}")
154
149
  return api_version or _SUPPORTED_API_VERSIONS[-1]
155
150
 
156
- def get_version_id(self_vid, kwargs):
157
- # type: (Optional[str], Dict[str, Any]) -> Optional[str]
151
+ def get_version_id(self_vid: Optional[str], kwargs: Dict[str, Any]) -> Optional[str]:
158
152
  if 'version_id' in kwargs:
159
- return kwargs.pop('version_id')
153
+ return cast(str, kwargs.pop('version_id'))
160
154
  return self_vid
161
155
 
162
- def serialize_blob_tags_header(tags=None):
163
- # type: (Optional[Dict[str, str]]) -> str
156
+ def serialize_blob_tags_header(tags: Optional[Dict[str, str]] = None) -> Optional[str]:
164
157
  if tags is None:
165
158
  return None
166
159
 
@@ -178,33 +171,27 @@ def serialize_blob_tags_header(tags=None):
178
171
  return ''.join(components)
179
172
 
180
173
 
181
- def serialize_blob_tags(tags=None):
182
- # type: (Optional[Dict[str, str]]) -> Union[BlobTags, None]
174
+ def serialize_blob_tags(tags: Optional[Dict[str, str]] = None) -> BlobTags:
183
175
  tag_list = []
184
176
  if tags:
185
177
  tag_list = [BlobTag(key=k, value=v) for k, v in tags.items()]
186
178
  return BlobTags(blob_tag_set=tag_list)
187
179
 
188
180
 
189
- def serialize_query_format(formater):
181
+ def serialize_query_format(formater: Union[str, DelimitedJsonDialect]) -> Optional[QuerySerialization]:
190
182
  if formater == "ParquetDialect":
191
- qq_format = QueryFormat(
192
- type=QueryFormatType.PARQUET,
193
- parquet_text_configuration=' '
194
- )
183
+ qq_format = QueryFormat(type=QueryFormatType.PARQUET, parquet_text_configuration=' ') #type: ignore [arg-type]
195
184
  elif isinstance(formater, DelimitedJsonDialect):
196
- serialization_settings = JsonTextConfiguration(
197
- record_separator=formater.delimiter
198
- )
199
- qq_format = QueryFormat(
200
- type=QueryFormatType.json,
201
- json_text_configuration=serialization_settings)
185
+ json_serialization_settings = JsonTextConfiguration(record_separator=formater.delimiter)
186
+ qq_format = QueryFormat(type=QueryFormatType.JSON, json_text_configuration=json_serialization_settings)
202
187
  elif hasattr(formater, 'quotechar'): # This supports a csv.Dialect as well
203
188
  try:
204
- headers = formater.has_header
189
+ headers = formater.has_header # type: ignore
205
190
  except AttributeError:
206
191
  headers = False
207
- serialization_settings = DelimitedTextConfiguration(
192
+ if isinstance(formater, str):
193
+ raise ValueError("Unknown string value provided. Accepted values: ParquetDialect")
194
+ csv_serialization_settings = DelimitedTextConfiguration(
208
195
  column_separator=formater.delimiter,
209
196
  field_quote=formater.quotechar,
210
197
  record_separator=formater.lineterminator,
@@ -212,16 +199,12 @@ def serialize_query_format(formater):
212
199
  headers_present=headers
213
200
  )
214
201
  qq_format = QueryFormat(
215
- type=QueryFormatType.delimited,
216
- delimited_text_configuration=serialization_settings
202
+ type=QueryFormatType.DELIMITED,
203
+ delimited_text_configuration=csv_serialization_settings
217
204
  )
218
205
  elif isinstance(formater, list):
219
- serialization_settings = ArrowConfiguration(
220
- schema=formater
221
- )
222
- qq_format = QueryFormat(
223
- type=QueryFormatType.arrow,
224
- arrow_configuration=serialization_settings)
206
+ arrow_serialization_settings = ArrowConfiguration(schema=formater)
207
+ qq_format = QueryFormat(type=QueryFormatType.arrow, arrow_configuration=arrow_serialization_settings)
225
208
  elif not formater:
226
209
  return None
227
210
  else:
@@ -185,7 +185,7 @@ class DataFileReader(object): # pylint: disable=too-many-instance-attributes
185
185
 
186
186
  # check magic number
187
187
  if header.get('magic') != MAGIC:
188
- fail_msg = f"Not an Avro data file: {header.get('magic')} doesn't match {MAGIC}."
188
+ fail_msg = f"Not an Avro data file: {header.get('magic')} doesn't match {MAGIC!r}."
189
189
  raise schema.AvroException(fail_msg)
190
190
 
191
191
  # set metadata
@@ -146,7 +146,7 @@ class AsyncDataFileReader(object): # pylint: disable=too-many-instance-attribut
146
146
 
147
147
  # check magic number
148
148
  if header.get('magic') != MAGIC:
149
- fail_msg = f"Not an Avro data file: {header.get('magic')} doesn't match {MAGIC}."
149
+ fail_msg = f"Not an Avro data file: {header.get('magic')} doesn't match {MAGIC!r}."
150
150
  raise schema.AvroException(fail_msg)
151
151
 
152
152
  # set metadata
@@ -76,6 +76,7 @@ class StorageAccountHostsMixin(object): # pylint: disable=too-many-instance-att
76
76
  self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY)
77
77
  self._hosts = kwargs.get("_hosts")
78
78
  self.scheme = parsed_url.scheme
79
+ self._is_localhost = False
79
80
 
80
81
  if service not in ["blob", "queue", "file-share", "dfs"]:
81
82
  raise ValueError(f"Invalid service: {service}")
@@ -85,6 +86,7 @@ class StorageAccountHostsMixin(object): # pylint: disable=too-many-instance-att
85
86
  self.account_name = account[0] if len(account) > 1 else None
86
87
  if not self.account_name and parsed_url.netloc.startswith("localhost") \
87
88
  or parsed_url.netloc.startswith("127.0.0.1"):
89
+ self._is_localhost = True
88
90
  self.account_name = parsed_url.path.strip("/")
89
91
 
90
92
  self.credential = _format_shared_key_credential(self.account_name, credential)
@@ -331,7 +333,7 @@ class StorageAccountHostsMixin(object): # pylint: disable=too-many-instance-att
331
333
  )
332
334
  raise error
333
335
  return iter(parts)
334
- return parts
336
+ return parts # type: ignore [no-any-return]
335
337
  except HttpResponseError as error:
336
338
  process_storage_error(error)
337
339
 
@@ -201,7 +201,7 @@ class AsyncStorageAccountHostsMixin(object):
201
201
  )
202
202
  raise error
203
203
  return AsyncList(parts_list)
204
- return parts
204
+ return parts # type: ignore [no-any-return]
205
205
  except HttpResponseError as error:
206
206
  process_storage_error(error)
207
207
 
@@ -14,10 +14,10 @@ from io import SEEK_SET, UnsupportedOperation
14
14
  from time import time
15
15
  from typing import Any, Dict, Optional, TYPE_CHECKING
16
16
  from urllib.parse import (
17
- parse_qsl,
18
- urlencode,
19
- urlparse,
20
- urlunparse,
17
+ parse_qsl,
18
+ urlencode,
19
+ urlparse,
20
+ urlunparse,
21
21
  )
22
22
  from wsgiref.handlers import format_date_time
23
23
 
@@ -28,18 +28,13 @@ from azure.core.pipeline.policies import (
28
28
  HTTPPolicy,
29
29
  NetworkTraceLoggingPolicy,
30
30
  RequestHistory,
31
- SansIOHTTPPolicy,
31
+ SansIOHTTPPolicy
32
32
  )
33
33
 
34
- from .authentication import StorageHttpChallenge
34
+ from .authentication import AzureSigningError, StorageHttpChallenge
35
35
  from .constants import DEFAULT_OAUTH_SCOPE
36
36
  from .models import LocationMode
37
37
 
38
- try:
39
- _unicode_type = unicode # type: ignore
40
- except NameError:
41
- _unicode_type = str
42
-
43
38
  if TYPE_CHECKING:
44
39
  from azure.core.credentials import TokenCredential
45
40
  from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import
@@ -52,7 +47,7 @@ _LOGGER = logging.getLogger(__name__)
52
47
 
53
48
 
54
49
  def encode_base64(data):
55
- if isinstance(data, _unicode_type):
50
+ if isinstance(data, str):
56
51
  data = data.encode('utf-8')
57
52
  encoded = base64.b64encode(data)
58
53
  return encoded.decode('utf-8')
@@ -95,10 +90,14 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements
95
90
  if status in [501, 505]:
96
91
  return False
97
92
  return True
93
+ return False
94
+
95
+
96
+ def is_checksum_retry(response):
98
97
  # retry if invalid content md5
99
98
  if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
100
99
  computed_md5 = response.http_request.headers.get('content-md5', None) or \
101
- encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
100
+ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
102
101
  if response.http_response.headers['content-md5'] != computed_md5:
103
102
  return True
104
103
  return False
@@ -301,7 +300,7 @@ class StorageResponseHook(HTTPPolicy):
301
300
 
302
301
  response = self.next.send(request)
303
302
 
304
- will_retry = is_retry(response, request.context.options.get('mode'))
303
+ will_retry = is_retry(response, request.context.options.get('mode')) or is_checksum_retry(response)
305
304
  # Auth error could come from Bearer challenge, in which case this request will be made again
306
305
  is_auth_error = response.http_response.status_code == 401
307
306
  should_update_counts = not (will_retry or is_auth_error)
@@ -527,7 +526,7 @@ class StorageRetryPolicy(HTTPPolicy):
527
526
  while retries_remaining:
528
527
  try:
529
528
  response = self.next.send(request)
530
- if is_retry(response, retry_settings['mode']):
529
+ if is_retry(response, retry_settings['mode']) or is_checksum_retry(response):
531
530
  retries_remaining = self.increment(
532
531
  retry_settings,
533
532
  request=request.http_request,
@@ -542,6 +541,8 @@ class StorageRetryPolicy(HTTPPolicy):
542
541
  continue
543
542
  break
544
543
  except AzureError as err:
544
+ if isinstance(err, AzureSigningError):
545
+ raise
545
546
  retries_remaining = self.increment(
546
547
  retry_settings, request=request.http_request, error=err)
547
548
  if retries_remaining:
@@ -10,12 +10,12 @@ import logging
10
10
  import random
11
11
  from typing import Any, Dict, TYPE_CHECKING
12
12
 
13
- from azure.core.exceptions import AzureError
13
+ from azure.core.exceptions import AzureError, StreamClosedError, StreamConsumedError
14
14
  from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy
15
15
 
16
- from .authentication import StorageHttpChallenge
16
+ from .authentication import AzureSigningError, StorageHttpChallenge
17
17
  from .constants import DEFAULT_OAUTH_SCOPE
18
- from .policies import is_retry, StorageRetryPolicy
18
+ from .policies import encode_base64, is_retry, StorageContentValidation, StorageRetryPolicy
19
19
 
20
20
  if TYPE_CHECKING:
21
21
  from azure.core.credentials_async import AsyncTokenCredential
@@ -42,6 +42,20 @@ async def retry_hook(settings, **kwargs):
42
42
  **kwargs)
43
43
 
44
44
 
45
+ async def is_checksum_retry(response):
46
+ # retry if invalid content md5
47
+ if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
48
+ try:
49
+ await response.http_response.read() # Load the body in memory and close the socket
50
+ except (StreamClosedError, StreamConsumedError):
51
+ pass
52
+ computed_md5 = response.http_request.headers.get('content-md5', None) or \
53
+ encode_base64(StorageContentValidation.get_content_md5(response.http_response.content))
54
+ if response.http_response.headers['content-md5'] != computed_md5:
55
+ return True
56
+ return False
57
+
58
+
45
59
  class AsyncStorageResponseHook(AsyncHTTPPolicy):
46
60
 
47
61
  def __init__(self, **kwargs): # pylint: disable=unused-argument
@@ -64,9 +78,8 @@ class AsyncStorageResponseHook(AsyncHTTPPolicy):
64
78
  request.context.options.pop('raw_response_hook', self._response_callback)
65
79
 
66
80
  response = await self.next.send(request)
67
- await response.http_response.load_body()
81
+ will_retry = is_retry(response, request.context.options.get('mode')) or await is_checksum_retry(response)
68
82
 
69
- will_retry = is_retry(response, request.context.options.get('mode'))
70
83
  # Auth error could come from Bearer challenge, in which case this request will be made again
71
84
  is_auth_error = response.http_response.status_code == 401
72
85
  should_update_counts = not (will_retry or is_auth_error)
@@ -112,7 +125,7 @@ class AsyncStorageRetryPolicy(StorageRetryPolicy):
112
125
  while retries_remaining:
113
126
  try:
114
127
  response = await self.next.send(request)
115
- if is_retry(response, retry_settings['mode']):
128
+ if is_retry(response, retry_settings['mode']) or await is_checksum_retry(response):
116
129
  retries_remaining = self.increment(
117
130
  retry_settings,
118
131
  request=request.http_request,
@@ -127,6 +140,8 @@ class AsyncStorageRetryPolicy(StorageRetryPolicy):
127
140
  continue
128
141
  break
129
142
  except AzureError as err:
143
+ if isinstance(err, AzureSigningError):
144
+ raise
130
145
  retries_remaining = self.increment(
131
146
  retry_settings, request=request.http_request, error=err)
132
147
  if retries_remaining:
@@ -17,7 +17,8 @@ from azure.core.exceptions import (
17
17
  )
18
18
  from azure.core.pipeline.policies import ContentDecodePolicy
19
19
 
20
- from .models import StorageErrorCode, UserDelegationKey, get_enum_value
20
+ from .authentication import AzureSigningError
21
+ from .models import get_enum_value, StorageErrorCode, UserDelegationKey
21
22
  from .parser import _to_utc_datetime
22
23
 
23
24
 
@@ -81,9 +82,12 @@ def return_raw_deserialized(response, *_):
81
82
  return response.http_response.location_mode, response.context[ContentDecodePolicy.CONTEXT_NAME]
82
83
 
83
84
 
84
- def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements
85
+ def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements, too-many-branches
85
86
  raise_error = HttpResponseError
86
87
  serialized = False
88
+ if isinstance(storage_error, AzureSigningError):
89
+ storage_error.message = storage_error.message + \
90
+ '. This is likely due to an invalid shared key. Please check your shared key and try again.'
87
91
  if not storage_error.response or storage_error.response.status_code in [200, 204]:
88
92
  raise storage_error
89
93
  # If it is one of those three then it has been serialized prior by the generated layer.
@@ -107,8 +107,17 @@ class SharedAccessSignature(object):
107
107
  self.account_key = account_key
108
108
  self.x_ms_version = x_ms_version
109
109
 
110
- def generate_account(self, services, resource_types, permission, expiry, start=None,
111
- ip=None, protocol=None, **kwargs):
110
+ def generate_account(
111
+ self, services,
112
+ resource_types,
113
+ permission,
114
+ expiry,
115
+ start=None,
116
+ ip=None,
117
+ protocol=None,
118
+ sts_hook=None,
119
+ **kwargs
120
+ ) -> str:
112
121
  '''
113
122
  Generates a shared access signature for the account.
114
123
  Use the returned signature with the sas_token parameter of the service
@@ -152,6 +161,10 @@ class SharedAccessSignature(object):
152
161
  :keyword str encryption_scope:
153
162
  Optional. If specified, this is the encryption scope to use when sending requests
154
163
  authorized with this SAS URI.
164
+ :param sts_hook:
165
+ For debugging purposes only. If provided, the hook is called with the string to sign
166
+ that was used to generate the SAS.
167
+ :type sts_hook: Optional[Callable[[str], None]]
155
168
  :returns: The generated SAS token for the account.
156
169
  :rtype: str
157
170
  '''
@@ -161,12 +174,16 @@ class SharedAccessSignature(object):
161
174
  sas.add_encryption_scope(**kwargs)
162
175
  sas.add_account_signature(self.account_name, self.account_key)
163
176
 
177
+ if sts_hook is not None:
178
+ sts_hook(sas.string_to_sign)
179
+
164
180
  return sas.get_token()
165
181
 
166
182
 
167
183
  class _SharedAccessHelper(object):
168
184
  def __init__(self):
169
185
  self.query_dict = {}
186
+ self.string_to_sign = ""
170
187
 
171
188
  def _add_query(self, name, val):
172
189
  if val:
@@ -229,6 +246,7 @@ class _SharedAccessHelper(object):
229
246
 
230
247
  self._add_query(QueryStringConstants.SIGNED_SIGNATURE,
231
248
  sign_string(account_key, string_to_sign))
249
+ self.string_to_sign = string_to_sign
232
250
 
233
- def get_token(self):
251
+ def get_token(self) -> str:
234
252
  return '&'.join([f'{n}={url_quote(v)}' for n, v in self.query_dict.items() if v is not None])
@@ -12,7 +12,7 @@ from threading import Lock
12
12
 
13
13
  from azure.core.tracing.common import with_current_context
14
14
 
15
- from . import encode_base64, url_quote
15
+ from .import encode_base64, url_quote
16
16
  from .request_handlers import get_length
17
17
  from .response_handlers import return_response_headers
18
18
 
@@ -13,7 +13,7 @@ from itertools import islice
13
13
  from math import ceil
14
14
  from typing import AsyncGenerator, Union
15
15
 
16
- from . import encode_base64, url_quote
16
+ from .import encode_base64, url_quote
17
17
  from .request_handlers import get_length
18
18
  from .response_handlers import return_response_headers
19
19
  from .uploads import SubStream, IterStreamer # pylint: disable=unused-import