corehttp 1.0.0b6__py3-none-any.whl → 1.0.0b7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. corehttp/_version.py +1 -1
  2. corehttp/credentials.py +14 -5
  3. corehttp/instrumentation/__init__.py +9 -0
  4. corehttp/instrumentation/tracing/__init__.py +14 -0
  5. corehttp/instrumentation/tracing/_decorator.py +189 -0
  6. corehttp/instrumentation/tracing/_models.py +72 -0
  7. corehttp/instrumentation/tracing/_tracer.py +69 -0
  8. corehttp/instrumentation/tracing/opentelemetry.py +277 -0
  9. corehttp/instrumentation/tracing/utils.py +31 -0
  10. corehttp/paging.py +13 -0
  11. corehttp/rest/_aiohttp.py +17 -5
  12. corehttp/rest/_http_response_impl.py +7 -7
  13. corehttp/rest/_http_response_impl_async.py +2 -0
  14. corehttp/rest/_httpx.py +8 -8
  15. corehttp/rest/_requests_basic.py +13 -5
  16. corehttp/rest/_rest_py3.py +2 -2
  17. corehttp/runtime/pipeline/__init__.py +2 -2
  18. corehttp/runtime/pipeline/_base.py +2 -1
  19. corehttp/runtime/pipeline/_base_async.py +2 -0
  20. corehttp/runtime/pipeline/_tools.py +18 -2
  21. corehttp/runtime/policies/__init__.py +2 -0
  22. corehttp/runtime/policies/_authentication.py +28 -5
  23. corehttp/runtime/policies/_authentication_async.py +22 -3
  24. corehttp/runtime/policies/_distributed_tracing.py +169 -0
  25. corehttp/runtime/policies/_retry.py +7 -11
  26. corehttp/runtime/policies/_retry_async.py +4 -8
  27. corehttp/runtime/policies/_universal.py +11 -0
  28. corehttp/serialization.py +236 -2
  29. corehttp/settings.py +59 -0
  30. corehttp/transport/_base.py +1 -3
  31. corehttp/transport/_base_async.py +1 -3
  32. corehttp/transport/aiohttp/_aiohttp.py +39 -14
  33. corehttp/transport/requests/_requests_basic.py +31 -16
  34. corehttp/utils/_utils.py +2 -1
  35. {corehttp-1.0.0b6.dist-info → corehttp-1.0.0b7.dist-info}/METADATA +52 -6
  36. corehttp-1.0.0b7.dist-info/RECORD +61 -0
  37. {corehttp-1.0.0b6.dist-info → corehttp-1.0.0b7.dist-info}/WHEEL +1 -1
  38. corehttp-1.0.0b6.dist-info/RECORD +0 -52
  39. {corehttp-1.0.0b6.dist-info → corehttp-1.0.0b7.dist-info/licenses}/LICENSE +0 -0
  40. {corehttp-1.0.0b6.dist-info → corehttp-1.0.0b7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,31 @@
1
+ # ------------------------------------
2
+ # Copyright (c) Microsoft Corporation.
3
+ # Licensed under the MIT License.
4
+ # ------------------------------------
5
+ """Common tracing functionality for SDK libraries."""
6
+ from typing import Any, Callable
7
+
8
+ from ._tracer import get_tracer
9
+ from ...settings import settings
10
+
11
+
12
+ __all__ = [
13
+ "with_current_context",
14
+ ]
15
+
16
+
17
def with_current_context(func: Callable) -> Callable:
    """Pass the current spans to the new context that ``func`` will be run in.

    If tracing is disabled (``settings.tracing_enabled`` is falsy) or ``get_tracer()``
    does not return a tracer, ``func`` is returned unchanged; otherwise the tracer's
    ``with_current_context`` wrapper is returned.

    :param func: The function that will be run in the new context.
    :type func: callable
    :return: ``func`` wrapped with the current tracing context, or ``func`` itself when
        tracing is disabled or no tracer is available.
    :rtype: callable
    """
    if not settings.tracing_enabled:
        return func

    tracer = get_tracer()
    if not tracer:
        # No tracing backend available: nothing to propagate.
        return func
    return tracer.with_current_context(func)
corehttp/paging.py CHANGED
@@ -80,6 +80,13 @@ class PageIterator(Iterator[Iterator[ReturnType]]):
80
80
  return self
81
81
 
82
82
  def __next__(self) -> Iterator[ReturnType]:
83
+ """Get the next page in the iterator.
84
+
85
+ :returns: An iterator of objects in the next page.
86
+ :rtype: iterator[ReturnType]
87
+ :raises StopIteration: If there are no more pages to return.
88
+ :raises ~corehttp.exceptions.BaseError: If the request to get the next page fails.
89
+ """
83
90
  if self.continuation_token is None and self._did_a_call_already:
84
91
  raise StopIteration("End of paging")
85
92
  try:
@@ -129,6 +136,12 @@ class ItemPaged(Iterator[ReturnType]):
129
136
  return self
130
137
 
131
138
  def __next__(self) -> ReturnType:
139
+ """Get the next item, lazily chaining the pages from ``by_page()`` into one stream on the first call.
140
+
141
+ :returns: The next item in the iterator.
142
+ :rtype: ReturnType
143
+ :raises StopIteration: If there are no more items to return.
144
+ """
132
145
  if self._page_iterator is None:
133
146
  self._page_iterator = itertools.chain.from_iterable(self.by_page())
134
147
  return next(self._page_iterator)
corehttp/rest/_aiohttp.py CHANGED
@@ -38,6 +38,7 @@ from ..exceptions import (
38
38
  ServiceRequestError,
39
39
  ServiceResponseError,
40
40
  IncompleteReadError,
41
+ ServiceResponseTimeoutError,
41
42
  )
42
43
  from ..runtime.pipeline import AsyncPipeline
43
44
  from ..transport._base_async import _ResponseStopIteration
@@ -224,7 +225,18 @@ class RestAioHttpTransportResponse(AsyncHttpResponseImpl):
224
225
  """
225
226
  if not self._content:
226
227
  self._stream_download_check()
227
- self._content = await self._internal_response.read()
228
+ try:
229
+ self._content = await self._internal_response.read()
230
+ except aiohttp.client_exceptions.ClientPayloadError as err:
231
+ # This is the case that server closes connection before we finish the reading. aiohttp library
232
+ # raises ClientPayloadError.
233
+ raise IncompleteReadError(err, error=err) from err
234
+ except aiohttp.client_exceptions.ClientResponseError as err:
235
+ raise ServiceResponseError(err, error=err) from err
236
+ except asyncio.TimeoutError as err:
237
+ raise ServiceResponseTimeoutError(err, error=err) from err
238
+ except aiohttp.client_exceptions.ClientError as err:
239
+ raise ServiceRequestError(err, error=err) from err
228
240
  await self._set_read_checks()
229
241
  return _aiohttp_content_helper(self)
230
242
 
@@ -300,16 +312,16 @@ class AioHttpStreamDownloadGenerator(collections.abc.AsyncIterator):
300
312
  except aiohttp.client_exceptions.ClientPayloadError as err:
301
313
  # This is the case that server closes connection before we finish the reading. aiohttp library
302
314
  # raises ClientPayloadError.
303
- _LOGGER.warning("Incomplete download: %s", err)
315
+ _LOGGER.warning("Incomplete download.")
304
316
  internal_response.close()
305
317
  raise IncompleteReadError(err, error=err) from err
306
318
  except aiohttp.client_exceptions.ClientResponseError as err:
307
319
  raise ServiceResponseError(err, error=err) from err
308
320
  except asyncio.TimeoutError as err:
309
- raise ServiceResponseError(err, error=err) from err
321
+ raise ServiceResponseTimeoutError(err, error=err) from err
310
322
  except aiohttp.client_exceptions.ClientError as err:
311
323
  raise ServiceRequestError(err, error=err) from err
312
- except Exception as err:
313
- _LOGGER.warning("Unable to stream download: %s", err)
324
+ except Exception:
325
+ _LOGGER.warning("Unable to stream download.")
314
326
  internal_response.close()
315
327
  raise
@@ -54,7 +54,7 @@ class _HttpResponseBaseImpl(_HttpResponseBase): # pylint: disable=too-many-inst
54
54
  :type request: ~corehttp.rest.HttpRequest
55
55
  :keyword any internal_response: The response we get directly from the transport. For example, for our requests
56
56
  transport, this will be a requests.Response.
57
- :keyword optional[int] block_size: The block size we are using in our transport
57
+ :keyword Optional[int] block_size: The block size we are using in our transport
58
58
  :keyword int status_code: The status code of the response
59
59
  :keyword str reason: The HTTP reason
60
60
  :keyword str content_type: The content type of the response
@@ -136,7 +136,7 @@ class _HttpResponseBaseImpl(_HttpResponseBase): # pylint: disable=too-many-inst
136
136
  def content_type(self) -> Optional[str]:
137
137
  """The content type of the response.
138
138
 
139
- :rtype: optional[str]
139
+ :rtype: Optional[str]
140
140
  :return: The content type of the response.
141
141
  """
142
142
  return self._content_type
@@ -157,7 +157,7 @@ class _HttpResponseBaseImpl(_HttpResponseBase): # pylint: disable=too-many-inst
157
157
  :return: The response encoding. We either return the encoding set by the user,
158
158
  or try extracting the encoding from the response's content type. If all fails,
159
159
  we return `None`.
160
- :rtype: optional[str]
160
+ :rtype: Optional[str]
161
161
  """
162
162
  try:
163
163
  return self._encoding
@@ -166,10 +166,10 @@ class _HttpResponseBaseImpl(_HttpResponseBase): # pylint: disable=too-many-inst
166
166
  return self._encoding
167
167
 
168
168
  @encoding.setter
169
- def encoding(self, value: str) -> None:
169
+ def encoding(self, value: Optional[str]) -> None:
170
170
  """Sets the response encoding.
171
171
 
172
- :param str value: Sets the response encoding.
172
+ :param Optional[str] value: Sets the response encoding.
173
173
  """
174
174
  self._encoding = value
175
175
  self._text = None # clear text cache
@@ -178,7 +178,7 @@ class _HttpResponseBaseImpl(_HttpResponseBase): # pylint: disable=too-many-inst
178
178
  def text(self, encoding: Optional[str] = None) -> str:
179
179
  """Returns the response body as a string
180
180
 
181
- :param optional[str] encoding: The encoding you want to decode the text with. Can
181
+ :param Optional[str] encoding: The encoding you want to decode the text with. Can
182
182
  also be set independently through our encoding property
183
183
  :return: The response's content decoded as a string.
184
184
  :rtype: str
@@ -246,7 +246,7 @@ class HttpResponseImpl(_HttpResponseBaseImpl, _HttpResponse):
246
246
  :type request: ~corehttp.rest.HttpRequest
247
247
  :keyword any internal_response: The response we get directly from the transport. For example, for our requests
248
248
  transport, this will be a requests.Response.
249
- :keyword optional[int] block_size: The block size we are using in our transport
249
+ :keyword Optional[int] block_size: The block size we are using in our transport
250
250
  :keyword int status_code: The status code of the response
251
251
  :keyword str reason: The HTTP reason
252
252
  :keyword str content_type: The content type of the response
@@ -69,6 +69,7 @@ class AsyncHttpResponseImpl(_HttpResponseBaseImpl, _AsyncHttpResponse):
69
69
 
70
70
  async def iter_raw(self, **kwargs: Any) -> AsyncIterator[bytes]:
71
71
  """Asynchronously iterates over the response's bytes. Will not decompress in the process
72
+
72
73
  :return: An async iterator of bytes from the response
73
74
  :rtype: AsyncIterator[bytes]
74
75
  """
@@ -79,6 +80,7 @@ class AsyncHttpResponseImpl(_HttpResponseBaseImpl, _AsyncHttpResponse):
79
80
 
80
81
  async def iter_bytes(self, **kwargs: Any) -> AsyncIterator[bytes]:
81
82
  """Asynchronously iterates over the response's bytes. Will decompress in the process
83
+
82
84
  :return: An async iterator of bytes from the response
83
85
  :rtype: AsyncIterator[bytes]
84
86
  """
corehttp/rest/_httpx.py CHANGED
@@ -96,10 +96,10 @@ class HttpXStreamDownloadGenerator:
96
96
  except httpx.RemoteProtocolError as ex:
97
97
  msg = ex.__str__()
98
98
  if "complete message" in msg:
99
- _LOGGER.warning("Incomplete download: %s", ex)
99
+ _LOGGER.warning("Incomplete download.")
100
100
  internal_response.close()
101
101
  raise IncompleteReadError(ex, error=ex) from ex
102
- _LOGGER.warning("Unable to stream download: %s", ex)
102
+ _LOGGER.warning("Unable to stream download.")
103
103
  internal_response.close()
104
104
  raise HttpResponseError(ex, error=ex) from ex
105
105
  except httpx.DecodingError as ex:
@@ -108,8 +108,8 @@ class HttpXStreamDownloadGenerator:
108
108
  raise DecodeError("Failed to decode.", error=ex) from ex
109
109
  except httpx.RequestError as err:
110
110
  raise ServiceRequestError(err, error=err) from err
111
- except Exception as err:
112
- _LOGGER.warning("Unable to stream download: %s", err)
111
+ except Exception:
112
+ _LOGGER.warning("Unable to stream download.")
113
113
  internal_response.close()
114
114
  raise
115
115
 
@@ -186,10 +186,10 @@ class AsyncHttpXStreamDownloadGenerator(AsyncIterator):
186
186
  except httpx.RemoteProtocolError as ex:
187
187
  msg = ex.__str__()
188
188
  if "complete message" in msg:
189
- _LOGGER.warning("Incomplete download: %s", ex)
189
+ _LOGGER.warning("Incomplete download.")
190
190
  await internal_response.aclose()
191
191
  raise IncompleteReadError(ex, error=ex) from ex
192
- _LOGGER.warning("Unable to stream download: %s", ex)
192
+ _LOGGER.warning("Unable to stream download.")
193
193
  await internal_response.aclose()
194
194
  raise HttpResponseError(ex, error=ex) from ex
195
195
  except httpx.DecodingError as ex:
@@ -198,7 +198,7 @@ class AsyncHttpXStreamDownloadGenerator(AsyncIterator):
198
198
  raise DecodeError("Failed to decode.", error=ex) from ex
199
199
  except httpx.RequestError as err:
200
200
  raise ServiceRequestError(err, error=err) from err
201
- except Exception as err:
202
- _LOGGER.warning("Unable to stream download: %s", err)
201
+ except Exception:
202
+ _LOGGER.warning("Unable to stream download.")
203
203
  await internal_response.aclose()
204
204
  raise
@@ -38,8 +38,8 @@ from urllib3.exceptions import (
38
38
  from ..runtime.pipeline import Pipeline
39
39
  from ._http_response_impl import _HttpResponseBaseImpl, HttpResponseImpl
40
40
  from ..exceptions import (
41
- ServiceRequestError,
42
41
  ServiceResponseError,
42
+ ServiceResponseTimeoutError,
43
43
  IncompleteReadError,
44
44
  HttpResponseError,
45
45
  DecodeError,
@@ -156,14 +156,22 @@ class StreamDownloadGenerator:
156
156
  except requests.exceptions.ChunkedEncodingError as err:
157
157
  msg = err.__str__()
158
158
  if "IncompleteRead" in msg:
159
- _LOGGER.warning("Incomplete download: %s", err)
159
+ _LOGGER.warning("Incomplete download.")
160
160
  internal_response.close()
161
161
  raise IncompleteReadError(err, error=err) from err
162
- _LOGGER.warning("Unable to stream download: %s", err)
162
+ _LOGGER.warning("Unable to stream download.")
163
163
  internal_response.close()
164
164
  raise HttpResponseError(err, error=err) from err
165
+ except requests.ConnectionError as err:
166
+ internal_response.close()
167
+ if err.args and isinstance(err.args[0], ReadTimeoutError):
168
+ raise ServiceResponseTimeoutError(err, error=err) from err
169
+ raise ServiceResponseError(err, error=err) from err
170
+ except requests.RequestException as err:
171
+ internal_response.close()
172
+ raise ServiceResponseError(err, error=err) from err
165
173
  except Exception as err:
166
- _LOGGER.warning("Unable to stream download: %s", err)
174
+ _LOGGER.warning("Unable to stream download.")
167
175
  internal_response.close()
168
176
  raise
169
177
 
@@ -178,7 +186,7 @@ def _read_raw_stream(response, chunk_size=1):
178
186
  except CoreDecodeError as e:
179
187
  raise DecodeError(e, error=e) from e
180
188
  except ReadTimeoutError as e:
181
- raise ServiceRequestError(e, error=e) from e
189
+ raise ServiceResponseTimeoutError(e, error=e) from e
182
190
  else:
183
191
  # Standard file-like object.
184
192
  while True:
@@ -300,7 +300,7 @@ class _HttpResponseBase(abc.ABC):
300
300
 
301
301
  :return: The JSON deserialized response body
302
302
  :rtype: any
303
- :raises json.decoder.JSONDecodeError or ValueError (in python 2.7) if object is not JSON decodable:
303
+ :raises json.decoder.JSONDecodeError: if the body is not valid JSON.
304
304
  """
305
305
 
306
306
  @abc.abstractmethod
@@ -309,7 +309,7 @@ class _HttpResponseBase(abc.ABC):
309
309
 
310
310
  If response is good, does nothing.
311
311
 
312
- :raises ~corehttp.HttpResponseError if the object has an error status code.:
312
+ :raises ~corehttp.HttpResponseError: if the object has an error status code.
313
313
  """
314
314
 
315
315
 
@@ -97,7 +97,7 @@ class PipelineContext(Dict[str, Any]):
97
97
  def clear(self) -> None: # pylint: disable=docstring-missing-return, docstring-missing-rtype
98
98
  """Context objects cannot be cleared.
99
99
 
100
- :raises: TypeError
100
+ :raises TypeError: If context objects cannot be cleared.
101
101
  """
102
102
  raise TypeError("Context objects cannot be cleared.")
103
103
 
@@ -106,7 +106,7 @@ class PipelineContext(Dict[str, Any]):
106
106
  ) -> None:
107
107
  """Context objects cannot be updated.
108
108
 
109
- :raises: TypeError
109
+ :raises TypeError: If context objects cannot be updated.
110
110
  """
111
111
  raise TypeError("Context objects cannot be updated.")
112
112
 
@@ -34,7 +34,7 @@ from . import (
34
34
  PipelineContext,
35
35
  )
36
36
  from ..policies import HTTPPolicy, SansIOHTTPPolicy
37
- from ._tools import await_result as _await_result
37
+ from ._tools import sanitize_transport_options, await_result as _await_result
38
38
  from ...transport import HttpTransport
39
39
 
40
40
  HTTPResponseType = TypeVar("HTTPResponseType")
@@ -103,6 +103,7 @@ class _TransportRunner(HTTPPolicy[HTTPRequestType, HTTPResponseType]):
103
103
  :return: The PipelineResponse object.
104
104
  :rtype: ~corehttp.runtime.pipeline.PipelineResponse
105
105
  """
106
+ sanitize_transport_options(request.context.options)
106
107
  return PipelineResponse(
107
108
  request.http_request,
108
109
  self._sender.send(request.http_request, **request.context.options),
@@ -33,6 +33,7 @@ from . import PipelineRequest, PipelineResponse, PipelineContext
33
33
  from ..policies import AsyncHTTPPolicy, SansIOHTTPPolicy
34
34
  from ..pipeline._base import is_sansio_http_policy
35
35
  from ._tools_async import await_result as _await_result
36
+ from ._tools import sanitize_transport_options
36
37
  from ...transport import AsyncHttpTransport
37
38
 
38
39
  AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
@@ -97,6 +98,7 @@ class _AsyncTransportRunner(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseTy
97
98
  :return: The PipelineResponse object.
98
99
  :rtype: ~corehttp.runtime.pipeline.PipelineResponse
99
100
  """
101
+ sanitize_transport_options(request.context.options)
100
102
  return PipelineResponse(
101
103
  request.http_request,
102
104
  await self._sender.send(request.http_request, **request.context.options),
@@ -23,7 +23,7 @@
23
23
  # IN THE SOFTWARE.
24
24
  #
25
25
  # --------------------------------------------------------------------------
26
- from typing import Callable, TypeVar
26
+ from typing import Any, Callable, Dict, TypeVar
27
27
  from typing_extensions import ParamSpec
28
28
 
29
29
  P = ParamSpec("P")
@@ -39,9 +39,25 @@ def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
39
39
  :type args: list
40
40
  :rtype: any
41
41
  :return: The result of the function
42
- :raises: TypeError
42
+ :raises TypeError: If the function returns an awaitable object.
43
43
  """
44
44
  result = func(*args, **kwargs)
45
45
  if hasattr(result, "__await__"):
46
46
  raise TypeError("Policy {} returned awaitable object in non-async pipeline.".format(func))
47
47
  return result
48
+
49
+
50
def sanitize_transport_options(options: Dict[str, Any]) -> None:
    """Remove pipeline-only options so they are not forwarded to the transport layer.

    The mapping is modified in place. Keys currently removed:

    - ``"tracing_options"``: used by the DistributedHttpTracingPolicy and the tracing decorators.

    If ``options`` is empty (or ``None``), nothing is done.

    :param options: The options to sanitize. Values may be of any type.
    :type options: dict
    """
    if not options:
        return

    # Keys consumed inside the pipeline that should not reach the transport's send().
    for key in ("tracing_options",):
        options.pop(key, None)
@@ -29,6 +29,7 @@ from ._authentication import (
29
29
  BearerTokenCredentialPolicy,
30
30
  ServiceKeyCredentialPolicy,
31
31
  )
32
+ from ._distributed_tracing import DistributedHttpTracingPolicy
32
33
  from ._retry import RetryPolicy, RetryMode
33
34
  from ._universal import (
34
35
  HeadersPolicy,
@@ -57,4 +58,5 @@ __all__ = [
57
58
  "AsyncHTTPPolicy",
58
59
  "AsyncBearerTokenCredentialPolicy",
59
60
  "AsyncRetryPolicy",
61
+ "DistributedHttpTracingPolicy",
60
62
  ]
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union
10
10
  from ...credentials import TokenRequestOptions
11
11
  from ...rest import HttpResponse, HttpRequest
12
12
  from . import HTTPPolicy, SansIOHTTPPolicy
13
- from ...exceptions import ServiceRequestError
13
+ from ...exceptions import ServiceRequestError, HttpResponseError
14
14
 
15
15
  if TYPE_CHECKING:
16
16
 
@@ -93,7 +93,7 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[H
93
93
  :param str scopes: Lets you specify the type of access needed.
94
94
  :keyword auth_flows: A list of authentication flows to use for the credential.
95
95
  :paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
96
- :raises: :class:`~corehttp.exceptions.ServiceRequestError`
96
+ :raises ~corehttp.exceptions.ServiceRequestError: If the request fails.
97
97
  """
98
98
 
99
99
  def on_request(
@@ -110,6 +110,9 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[H
110
110
  :keyword auth_flows: A list of authentication flows to use for the credential.
111
111
  :paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
112
112
  """
113
+ # If auth_flows is an empty list, we should not attempt to authorize the request.
114
+ if auth_flows is not None and len(auth_flows) == 0:
115
+ return
113
116
  self._enforce_https(request)
114
117
 
115
118
  if self._token is None or self._need_new_token:
@@ -142,7 +145,9 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[H
142
145
  :return: The pipeline response object
143
146
  :rtype: ~corehttp.runtime.pipeline.PipelineResponse
144
147
  """
145
- self.on_request(request, auth_flows=self._auth_flows)
148
+ op_auth_flows = request.context.options.pop("auth_flows", None)
149
+ auth_flows = op_auth_flows if op_auth_flows is not None else self._auth_flows
150
+ self.on_request(request, auth_flows=auth_flows)
146
151
  try:
147
152
  response = self.next.send(request)
148
153
  except Exception:
@@ -153,7 +158,19 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[H
153
158
  if response.http_response.status_code == 401:
154
159
  self._token = None # any cached token is invalid
155
160
  if "WWW-Authenticate" in response.http_response.headers:
156
- request_authorized = self.on_challenge(request, response)
161
+ try:
162
+ request_authorized = self.on_challenge(request, response)
163
+ except Exception as ex:
164
+ # If the response is streamed, read it so the error message is immediately available to the user.
165
+ # Otherwise, a generic error message will be given and the user will have to read the response
166
+ # body to see the actual error.
167
+ if response.context.options.get("stream"):
168
+ try:
169
+ response.http_response.read() # type: ignore
170
+ except Exception: # pylint:disable=broad-except
171
+ pass
172
+ # Raise the exception from the token request with the original 401 response.
173
+ raise ex from HttpResponseError(response=response.http_response)
157
174
  if request_authorized:
158
175
  try:
159
176
  response = self.next.send(request)
@@ -210,7 +227,8 @@ class ServiceKeyCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseT
210
227
  :type credential: ~corehttp.credentials.ServiceKeyCredential
211
228
  :param str name: The name of the key header used for the credential.
212
229
  :keyword str prefix: The name of the prefix for the header value if any.
213
- :raises: ValueError or TypeError
230
+ :raises ValueError: if name is None or empty.
231
+ :raises TypeError: if name is not a string or if credential is not an instance of ServiceKeyCredential.
214
232
  """
215
233
 
216
234
  def __init__( # pylint: disable=unused-argument
@@ -233,4 +251,9 @@ class ServiceKeyCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseT
233
251
  self._prefix = prefix + " " if prefix else ""
234
252
 
235
253
  def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
254
+ """Set the key header on the request to the credential key, preceded by the prefix and a space if one was given.
255
+
256
+ :param request: The pipeline request whose HTTP headers receive the credential value.
257
+ :type request: ~corehttp.runtime.pipeline.PipelineRequest
258
+ """
236
259
  request.http_request.headers[self._name] = f"{self._prefix}{self._credential.key}"
@@ -13,6 +13,7 @@ from ..pipeline._tools_async import await_result
13
13
  from ._base_async import AsyncHTTPPolicy
14
14
  from ._authentication import _BearerTokenCredentialPolicyBase
15
15
  from ...rest import AsyncHttpResponse, HttpRequest
16
+ from ...exceptions import HttpResponseError
16
17
  from ...utils._utils import get_running_async_lock
17
18
 
18
19
  if TYPE_CHECKING:
@@ -66,8 +67,11 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTT
66
67
  :type request: ~corehttp.runtime.pipeline.PipelineRequest
67
68
  :keyword auth_flows: A list of authentication flows to use for the credential.
68
69
  :paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
69
- :raises: :class:`~corehttp.exceptions.ServiceRequestError`
70
+ :raises ~corehttp.exceptions.ServiceRequestError: If the request fails.
70
71
  """
72
+ # If auth_flows is an empty list, we should not attempt to authorize the request.
73
+ if auth_flows is not None and len(auth_flows) == 0:
74
+ return
71
75
  _BearerTokenCredentialPolicyBase._enforce_https(request) # pylint:disable=protected-access
72
76
 
73
77
  if self._token is None or self._need_new_token:
@@ -107,7 +111,9 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTT
107
111
  :return: The pipeline response object
108
112
  :rtype: ~corehttp.runtime.pipeline.PipelineResponse
109
113
  """
110
- await await_result(self.on_request, request, auth_flows=self._auth_flows)
114
+ op_auth_flows = request.context.options.pop("auth_flows", None)
115
+ auth_flows = op_auth_flows if op_auth_flows is not None else self._auth_flows
116
+ await await_result(self.on_request, request, auth_flows=auth_flows)
111
117
  try:
112
118
  response = await self.next.send(request)
113
119
  except Exception:
@@ -118,7 +124,20 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTT
118
124
  if response.http_response.status_code == 401:
119
125
  self._token = None # any cached token is invalid
120
126
  if "WWW-Authenticate" in response.http_response.headers:
121
- request_authorized = await self.on_challenge(request, response)
127
+ try:
128
+ request_authorized = await self.on_challenge(request, response)
129
+ except Exception as ex:
130
+ # If the response is streamed, read it so the error message is immediately available to the user.
131
+ # Otherwise, a generic error message will be given and the user will have to read the response
132
+ # body to see the actual error.
133
+ if response.context.options.get("stream"):
134
+ try:
135
+ await response.http_response.read() # type: ignore
136
+ except Exception: # pylint:disable=broad-except
137
+ pass
138
+
139
+ # Raise the exception from the token request with the original 401 response
140
+ raise ex from HttpResponseError(response=response.http_response)
122
141
  if request_authorized:
123
142
  try:
124
143
  response = await self.next.send(request)