corehttp 1.0.0b5__py3-none-any.whl → 1.0.0b7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. corehttp/_version.py +1 -1
  2. corehttp/credentials.py +66 -25
  3. corehttp/exceptions.py +7 -6
  4. corehttp/instrumentation/__init__.py +9 -0
  5. corehttp/instrumentation/tracing/__init__.py +14 -0
  6. corehttp/instrumentation/tracing/_decorator.py +189 -0
  7. corehttp/instrumentation/tracing/_models.py +72 -0
  8. corehttp/instrumentation/tracing/_tracer.py +69 -0
  9. corehttp/instrumentation/tracing/opentelemetry.py +277 -0
  10. corehttp/instrumentation/tracing/utils.py +31 -0
  11. corehttp/paging.py +13 -0
  12. corehttp/rest/_aiohttp.py +21 -9
  13. corehttp/rest/_http_response_impl.py +9 -15
  14. corehttp/rest/_http_response_impl_async.py +2 -0
  15. corehttp/rest/_httpx.py +9 -9
  16. corehttp/rest/_requests_basic.py +17 -10
  17. corehttp/rest/_rest_py3.py +6 -10
  18. corehttp/runtime/pipeline/__init__.py +5 -9
  19. corehttp/runtime/pipeline/_base.py +3 -2
  20. corehttp/runtime/pipeline/_base_async.py +6 -8
  21. corehttp/runtime/pipeline/_tools.py +18 -2
  22. corehttp/runtime/pipeline/_tools_async.py +2 -4
  23. corehttp/runtime/policies/__init__.py +2 -0
  24. corehttp/runtime/policies/_authentication.py +76 -24
  25. corehttp/runtime/policies/_authentication_async.py +66 -21
  26. corehttp/runtime/policies/_distributed_tracing.py +169 -0
  27. corehttp/runtime/policies/_retry.py +8 -12
  28. corehttp/runtime/policies/_retry_async.py +5 -9
  29. corehttp/runtime/policies/_universal.py +15 -11
  30. corehttp/serialization.py +237 -3
  31. corehttp/settings.py +59 -0
  32. corehttp/transport/_base.py +1 -3
  33. corehttp/transport/_base_async.py +1 -3
  34. corehttp/transport/aiohttp/_aiohttp.py +41 -16
  35. corehttp/transport/requests/_bigger_block_size_http_adapters.py +1 -1
  36. corehttp/transport/requests/_requests_basic.py +33 -18
  37. corehttp/utils/_enum_meta.py +1 -1
  38. corehttp/utils/_utils.py +2 -1
  39. corehttp-1.0.0b7.dist-info/METADATA +196 -0
  40. corehttp-1.0.0b7.dist-info/RECORD +61 -0
  41. {corehttp-1.0.0b5.dist-info → corehttp-1.0.0b7.dist-info}/WHEEL +1 -1
  42. corehttp-1.0.0b5.dist-info/METADATA +0 -132
  43. corehttp-1.0.0b5.dist-info/RECORD +0 -52
  44. {corehttp-1.0.0b5.dist-info → corehttp-1.0.0b7.dist-info/licenses}/LICENSE +0 -0
  45. {corehttp-1.0.0b5.dist-info → corehttp-1.0.0b7.dist-info}/top_level.txt +0 -0
@@ -300,7 +300,7 @@ class _HttpResponseBase(abc.ABC):
300
300
 
301
301
  :return: The JSON deserialized response body
302
302
  :rtype: any
303
- :raises json.decoder.JSONDecodeError or ValueError (in python 2.7) if object is not JSON decodable:
303
+ :raises json.decoder.JSONDecodeError: if the body is not valid JSON.
304
304
  """
305
305
 
306
306
  @abc.abstractmethod
@@ -309,7 +309,7 @@ class _HttpResponseBase(abc.ABC):
309
309
 
310
310
  If response is good, does nothing.
311
311
 
312
- :raises ~corehttp.HttpResponseError if the object has an error status code.:
312
+ :raises ~corehttp.HttpResponseError: if the object has an error status code.
313
313
  """
314
314
 
315
315
 
@@ -329,16 +329,13 @@ class HttpResponse(_HttpResponseBase):
329
329
  """
330
330
 
331
331
  @abc.abstractmethod
332
- def __enter__(self) -> "HttpResponse":
333
- ...
332
+ def __enter__(self) -> "HttpResponse": ...
334
333
 
335
334
  @abc.abstractmethod
336
- def __exit__(self, *args: Any) -> None:
337
- ...
335
+ def __exit__(self, *args: Any) -> None: ...
338
336
 
339
337
  @abc.abstractmethod
340
- def close(self) -> None:
341
- ...
338
+ def close(self) -> None: ...
342
339
 
343
340
  @abc.abstractmethod
344
341
  def read(self) -> bytes:
@@ -415,5 +412,4 @@ class AsyncHttpResponse(_HttpResponseBase, AsyncContextManager["AsyncHttpRespons
415
412
  yield # pylint: disable=unreachable
416
413
 
417
414
  @abc.abstractmethod
418
- async def close(self) -> None:
419
- ...
415
+ async def close(self) -> None: ...
@@ -51,9 +51,7 @@ class PipelineContext(Dict[str, Any]):
51
51
 
52
52
  _PICKLE_CONTEXT = {"deserialized_data"}
53
53
 
54
- def __init__(
55
- self, transport: Optional["TransportType"], **kwargs: Any
56
- ) -> None: # pylint: disable=super-init-not-called
54
+ def __init__(self, transport: Optional["TransportType"], **kwargs: Any) -> None:
57
55
  self.transport: Optional["TransportType"] = transport
58
56
  self.options = kwargs
59
57
  self._protected = ["transport", "options"]
@@ -99,7 +97,7 @@ class PipelineContext(Dict[str, Any]):
99
97
  def clear(self) -> None: # pylint: disable=docstring-missing-return, docstring-missing-rtype
100
98
  """Context objects cannot be cleared.
101
99
 
102
- :raises: TypeError
100
+ :raises TypeError: If context objects cannot be cleared.
103
101
  """
104
102
  raise TypeError("Context objects cannot be cleared.")
105
103
 
@@ -108,17 +106,15 @@ class PipelineContext(Dict[str, Any]):
108
106
  ) -> None:
109
107
  """Context objects cannot be updated.
110
108
 
111
- :raises: TypeError
109
+ :raises TypeError: If context objects cannot be updated.
112
110
  """
113
111
  raise TypeError("Context objects cannot be updated.")
114
112
 
115
113
  @overload
116
- def pop(self, __key: str) -> Any:
117
- ...
114
+ def pop(self, __key: str) -> Any: ...
118
115
 
119
116
  @overload
120
- def pop(self, __key: str, __default: Optional[Any]) -> Any:
121
- ...
117
+ def pop(self, __key: str, __default: Optional[Any]) -> Any: ...
122
118
 
123
119
  def pop(self, *args: Any) -> Any:
124
120
  """Removes specified key and returns the value.
@@ -34,7 +34,7 @@ from . import (
34
34
  PipelineContext,
35
35
  )
36
36
  from ..policies import HTTPPolicy, SansIOHTTPPolicy
37
- from ._tools import await_result as _await_result
37
+ from ._tools import sanitize_transport_options, await_result as _await_result
38
38
  from ...transport import HttpTransport
39
39
 
40
40
  HTTPResponseType = TypeVar("HTTPResponseType")
@@ -103,6 +103,7 @@ class _TransportRunner(HTTPPolicy[HTTPRequestType, HTTPResponseType]):
103
103
  :return: The PipelineResponse object.
104
104
  :rtype: ~corehttp.runtime.pipeline.PipelineResponse
105
105
  """
106
+ sanitize_transport_options(request.context.options)
106
107
  return PipelineResponse(
107
108
  request.http_request,
108
109
  self._sender.send(request.http_request, **request.context.options),
@@ -153,7 +154,7 @@ class Pipeline(ContextManager["Pipeline"], Generic[HTTPRequestType, HTTPResponse
153
154
  self._transport.__enter__()
154
155
  return self
155
156
 
156
- def __exit__(self, *exc_details: Any) -> None: # pylint: disable=arguments-differ
157
+ def __exit__(self, *exc_details: Any) -> None:
157
158
  self._transport.__exit__(*exc_details)
158
159
 
159
160
  def run(self, request: HTTPRequestType, **kwargs: Any) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
@@ -26,13 +26,14 @@
26
26
  from __future__ import annotations
27
27
  import inspect
28
28
  from types import TracebackType
29
- from typing import Any, Union, Generic, TypeVar, List, Optional, Iterable, Type
30
- from typing_extensions import AsyncContextManager, TypeGuard
29
+ from typing import Any, Union, Generic, TypeVar, List, Optional, Iterable, Type, AsyncContextManager
30
+ from typing_extensions import TypeGuard
31
31
 
32
32
  from . import PipelineRequest, PipelineResponse, PipelineContext
33
33
  from ..policies import AsyncHTTPPolicy, SansIOHTTPPolicy
34
34
  from ..pipeline._base import is_sansio_http_policy
35
35
  from ._tools_async import await_result as _await_result
36
+ from ._tools import sanitize_transport_options
36
37
  from ...transport import AsyncHttpTransport
37
38
 
38
39
  AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
@@ -45,9 +46,7 @@ def is_async_http_policy(policy: object) -> TypeGuard[AsyncHTTPPolicy]:
45
46
  return False
46
47
 
47
48
 
48
- class _SansIOAsyncHTTPPolicyRunner(
49
- AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]
50
- ): # pylint: disable=unsubscriptable-object
49
+ class _SansIOAsyncHTTPPolicyRunner(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]):
51
50
  """Async implementation of the SansIO policy.
52
51
 
53
52
  Modifies the request and sends to the next policy in the chain.
@@ -76,9 +75,7 @@ class _SansIOAsyncHTTPPolicyRunner(
76
75
  return response
77
76
 
78
77
 
79
- class _AsyncTransportRunner(
80
- AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]
81
- ): # pylint: disable=unsubscriptable-object
78
+ class _AsyncTransportRunner(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]):
82
79
  """Async Transport runner.
83
80
 
84
81
  Uses specified HTTP transport type to send request and returns response.
@@ -101,6 +98,7 @@ class _AsyncTransportRunner(
101
98
  :return: The PipelineResponse object.
102
99
  :rtype: ~corehttp.runtime.pipeline.PipelineResponse
103
100
  """
101
+ sanitize_transport_options(request.context.options)
104
102
  return PipelineResponse(
105
103
  request.http_request,
106
104
  await self._sender.send(request.http_request, **request.context.options),
@@ -23,7 +23,7 @@
23
23
  # IN THE SOFTWARE.
24
24
  #
25
25
  # --------------------------------------------------------------------------
26
- from typing import Callable, TypeVar
26
+ from typing import Callable, TypeVar, Dict
27
27
  from typing_extensions import ParamSpec
28
28
 
29
29
  P = ParamSpec("P")
@@ -39,9 +39,25 @@ def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
39
39
  :type args: list
40
40
  :rtype: any
41
41
  :return: The result of the function
42
- :raises: TypeError
42
+ :raises TypeError: If the function returns an awaitable object.
43
43
  """
44
44
  result = func(*args, **kwargs)
45
45
  if hasattr(result, "__await__"):
46
46
  raise TypeError("Policy {} returned awaitable object in non-async pipeline.".format(func))
47
47
  return result
48
+
49
+
50
+ def sanitize_transport_options(options: Dict[str, str]) -> None:
51
+ """Remove options that could potentially make it to the transport layer.
52
+
53
+ - "tracing_options" is used in the DistributedHttpTracingPolicy and tracing decorators
54
+
55
+ :param options: The options.
56
+ :type options: dict
57
+ """
58
+ if not options:
59
+ return
60
+
61
+ options_to_remove = ["tracing_options"]
62
+ for key in options_to_remove:
63
+ options.pop(key, None)
@@ -31,13 +31,11 @@ T = TypeVar("T")
31
31
 
32
32
 
33
33
  @overload
34
- async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T:
35
- ...
34
+ async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: ...
36
35
 
37
36
 
38
37
  @overload
39
- async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
40
- ...
38
+ async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ...
41
39
 
42
40
 
43
41
  async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T:
@@ -29,6 +29,7 @@ from ._authentication import (
29
29
  BearerTokenCredentialPolicy,
30
30
  ServiceKeyCredentialPolicy,
31
31
  )
32
+ from ._distributed_tracing import DistributedHttpTracingPolicy
32
33
  from ._retry import RetryPolicy, RetryMode
33
34
  from ._universal import (
34
35
  HeadersPolicy,
@@ -57,4 +58,5 @@ __all__ = [
57
58
  "AsyncHTTPPolicy",
58
59
  "AsyncBearerTokenCredentialPolicy",
59
60
  "AsyncRetryPolicy",
61
+ "DistributedHttpTracingPolicy",
60
62
  ]
@@ -5,16 +5,17 @@
5
5
  # -------------------------------------------------------------------------
6
6
  from __future__ import annotations
7
7
  import time
8
- from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any
8
+ from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union
9
9
 
10
+ from ...credentials import TokenRequestOptions
10
11
  from ...rest import HttpResponse, HttpRequest
11
12
  from . import HTTPPolicy, SansIOHTTPPolicy
12
- from ...exceptions import ServiceRequestError
13
+ from ...exceptions import ServiceRequestError, HttpResponseError
13
14
 
14
15
  if TYPE_CHECKING:
15
- # pylint:disable=unused-import
16
+
16
17
  from ...credentials import (
17
- AccessToken,
18
+ AccessTokenInfo,
18
19
  TokenCredential,
19
20
  ServiceKeyCredential,
20
21
  )
@@ -31,15 +32,23 @@ class _BearerTokenCredentialPolicyBase:
31
32
  :param credential: The credential.
32
33
  :type credential: ~corehttp.credentials.TokenCredential
33
34
  :param str scopes: Lets you specify the type of access needed.
35
+ :keyword auth_flows: A list of authentication flows to use for the credential.
36
+ :paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
34
37
  """
35
38
 
39
+ # pylint: disable=unused-argument
36
40
  def __init__(
37
- self, credential: "TokenCredential", *scopes: str, **kwargs: Any # pylint: disable=unused-argument
41
+ self,
42
+ credential: "TokenCredential",
43
+ *scopes: str,
44
+ auth_flows: Optional[list[dict[str, Union[str, list[dict[str, str]]]]]] = None,
45
+ **kwargs: Any,
38
46
  ) -> None:
39
47
  super(_BearerTokenCredentialPolicyBase, self).__init__()
40
48
  self._scopes = scopes
41
49
  self._credential = credential
42
- self._token: Optional["AccessToken"] = None
50
+ self._token: Optional["AccessTokenInfo"] = None
51
+ self._auth_flows = auth_flows
43
52
 
44
53
  @staticmethod
45
54
  def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None:
@@ -68,7 +77,12 @@ class _BearerTokenCredentialPolicyBase:
68
77
 
69
78
  @property
70
79
  def _need_new_token(self) -> bool:
71
- return not self._token or self._token.expires_on - time.time() < 300
80
+ now = time.time()
81
+ return (
82
+ not self._token
83
+ or (self._token.refresh_on is not None and self._token.refresh_on <= now)
84
+ or (self._token.expires_on - now < 300)
85
+ )
72
86
 
73
87
 
74
88
  class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
@@ -77,20 +91,33 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[H
77
91
  :param credential: The credential.
78
92
  :type credential: ~corehttp.TokenCredential
79
93
  :param str scopes: Lets you specify the type of access needed.
80
- :raises: :class:`~corehttp.exceptions.ServiceRequestError`
94
+ :keyword auth_flows: A list of authentication flows to use for the credential.
95
+ :paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
96
+ :raises ~corehttp.exceptions.ServiceRequestError: If the request fails.
81
97
  """
82
98
 
83
- def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
99
+ def on_request(
100
+ self,
101
+ request: PipelineRequest[HTTPRequestType],
102
+ *,
103
+ auth_flows: Optional[list[dict[str, Union[str, list[dict[str, str]]]]]] = None,
104
+ ) -> None:
84
105
  """Called before the policy sends a request.
85
106
 
86
107
  The base implementation authorizes the request with a bearer token.
87
108
 
88
109
  :param ~corehttp.runtime.pipeline.PipelineRequest request: the request
110
+ :keyword auth_flows: A list of authentication flows to use for the credential.
111
+ :paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
89
112
  """
113
+ # If auth_flows is an empty list, we should not attempt to authorize the request.
114
+ if auth_flows is not None and len(auth_flows) == 0:
115
+ return
90
116
  self._enforce_https(request)
91
117
 
92
118
  if self._token is None or self._need_new_token:
93
- self._token = self._credential.get_token(*self._scopes)
119
+ options: TokenRequestOptions = {"auth_flows": auth_flows} if auth_flows else {} # type: ignore
120
+ self._token = self._credential.get_token_info(*self._scopes, options=options)
94
121
  self._update_headers(request.http_request.headers, self._token.token)
95
122
 
96
123
  def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
@@ -102,7 +129,12 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[H
102
129
  :param ~corehttp.runtime.pipeline.PipelineRequest request: the request
103
130
  :param str scopes: required scopes of authentication
104
131
  """
105
- self._token = self._credential.get_token(*scopes, **kwargs)
132
+ options: TokenRequestOptions = {}
133
+ # Loop through all the keyword arguments and check if they are part of the TokenRequestOptions.
134
+ for key in list(kwargs.keys()):
135
+ if key in TokenRequestOptions.__annotations__: # pylint:disable=no-member
136
+ options[key] = kwargs.pop(key) # type: ignore[literal-required]
137
+ self._token = self._credential.get_token_info(*scopes, options=options)
106
138
  self._update_headers(request.http_request.headers, self._token.token)
107
139
 
108
140
  def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
@@ -113,25 +145,39 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[H
113
145
  :return: The pipeline response object
114
146
  :rtype: ~corehttp.runtime.pipeline.PipelineResponse
115
147
  """
116
- self.on_request(request)
148
+ op_auth_flows = request.context.options.pop("auth_flows", None)
149
+ auth_flows = op_auth_flows if op_auth_flows is not None else self._auth_flows
150
+ self.on_request(request, auth_flows=auth_flows)
117
151
  try:
118
152
  response = self.next.send(request)
119
- self.on_response(request, response)
120
- except Exception: # pylint:disable=broad-except
153
+ except Exception:
121
154
  self.on_exception(request)
122
155
  raise
123
- else:
124
- if response.http_response.status_code == 401:
125
- self._token = None # any cached token is invalid
126
- if "WWW-Authenticate" in response.http_response.headers:
156
+
157
+ self.on_response(request, response)
158
+ if response.http_response.status_code == 401:
159
+ self._token = None # any cached token is invalid
160
+ if "WWW-Authenticate" in response.http_response.headers:
161
+ try:
127
162
  request_authorized = self.on_challenge(request, response)
128
- if request_authorized:
163
+ except Exception as ex:
164
+ # If the response is streamed, read it so the error message is immediately available to the user.
165
+ # Otherwise, a generic error message will be given and the user will have to read the response
166
+ # body to see the actual error.
167
+ if response.context.options.get("stream"):
129
168
  try:
130
- response = self.next.send(request)
131
- self.on_response(request, response)
169
+ response.http_response.read() # type: ignore
132
170
  except Exception: # pylint:disable=broad-except
133
- self.on_exception(request)
134
- raise
171
+ pass
172
+ # Raise the exception from the token request with the original 401 response.
173
+ raise ex from HttpResponseError(response=response.http_response)
174
+ if request_authorized:
175
+ try:
176
+ response = self.next.send(request)
177
+ self.on_response(request, response)
178
+ except Exception:
179
+ self.on_exception(request)
180
+ raise
135
181
 
136
182
  return response
137
183
 
@@ -181,7 +227,8 @@ class ServiceKeyCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseT
181
227
  :type credential: ~corehttp.credentials.ServiceKeyCredential
182
228
  :param str name: The name of the key header used for the credential.
183
229
  :keyword str prefix: The name of the prefix for the header value if any.
184
- :raises: ValueError or TypeError
230
+ :raises ValueError: if name is None or empty.
231
+ :raises TypeError: if name is not a string or if credential is not an instance of ServiceKeyCredential.
185
232
  """
186
233
 
187
234
  def __init__( # pylint: disable=unused-argument
@@ -204,4 +251,9 @@ class ServiceKeyCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseT
204
251
  self._prefix = prefix + " " if prefix else ""
205
252
 
206
253
  def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
254
+ """Called before the policy sends a request.
255
+
256
+ :param request: The request to be modified before sending.
257
+ :type request: ~corehttp.runtime.pipeline.PipelineRequest
258
+ """
207
259
  request.http_request.headers[self._name] = f"{self._prefix}{self._credential.key}"
@@ -5,14 +5,15 @@
5
5
  # -------------------------------------------------------------------------
6
6
  from __future__ import annotations
7
7
  import time
8
- from typing import TYPE_CHECKING, Any, Awaitable, Optional, cast, TypeVar
8
+ from typing import TYPE_CHECKING, Any, Awaitable, Optional, cast, TypeVar, Union
9
9
 
10
- from ...credentials import AccessToken
10
+ from ...credentials import AccessTokenInfo, TokenRequestOptions
11
11
  from ..pipeline import PipelineRequest, PipelineResponse
12
12
  from ..pipeline._tools_async import await_result
13
13
  from ._base_async import AsyncHTTPPolicy
14
14
  from ._authentication import _BearerTokenCredentialPolicyBase
15
15
  from ...rest import AsyncHttpResponse, HttpRequest
16
+ from ...exceptions import HttpResponseError
16
17
  from ...utils._utils import get_running_async_lock
17
18
 
18
19
  if TYPE_CHECKING:
@@ -29,16 +30,24 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTT
29
30
  :param credential: The credential.
30
31
  :type credential: ~corehttp.credentials.TokenCredential
31
32
  :param str scopes: Lets you specify the type of access needed.
33
+ :keyword auth_flows: A list of authentication flows to use for the credential.
34
+ :paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
32
35
  """
33
36
 
37
+ # pylint: disable=unused-argument
34
38
  def __init__(
35
- self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: Any # pylint: disable=unused-argument
39
+ self,
40
+ credential: "AsyncTokenCredential",
41
+ *scopes: str,
42
+ auth_flows: Optional[list[dict[str, Union[str, list[dict[str, str]]]]]] = None,
43
+ **kwargs: Any,
36
44
  ) -> None:
37
45
  super().__init__()
38
46
  self._credential = credential
39
47
  self._lock_instance = None
40
48
  self._scopes = scopes
41
- self._token: Optional["AccessToken"] = None
49
+ self._token: Optional[AccessTokenInfo] = None
50
+ self._auth_flows = auth_flows
42
51
 
43
52
  @property
44
53
  def _lock(self):
@@ -46,21 +55,32 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTT
46
55
  self._lock_instance = get_running_async_lock()
47
56
  return self._lock_instance
48
57
 
49
- async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
58
+ async def on_request(
59
+ self,
60
+ request: PipelineRequest[HTTPRequestType],
61
+ *,
62
+ auth_flows: Optional[list[dict[str, Union[str, list[dict[str, str]]]]]] = None,
63
+ ) -> None:
50
64
  """Adds a bearer token Authorization header to request and sends request to next policy.
51
65
 
52
66
  :param request: The pipeline request object to be modified.
53
67
  :type request: ~corehttp.runtime.pipeline.PipelineRequest
54
- :raises: :class:`~corehttp.exceptions.ServiceRequestError`
68
+ :keyword auth_flows: A list of authentication flows to use for the credential.
69
+ :paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
70
+ :raises ~corehttp.exceptions.ServiceRequestError: If the request fails.
55
71
  """
72
+ # If auth_flows is an empty list, we should not attempt to authorize the request.
73
+ if auth_flows is not None and len(auth_flows) == 0:
74
+ return
56
75
  _BearerTokenCredentialPolicyBase._enforce_https(request) # pylint:disable=protected-access
57
76
 
58
- if self._token is None or self._need_new_token():
77
+ if self._token is None or self._need_new_token:
59
78
  async with self._lock:
60
79
  # double check because another coroutine may have acquired a token while we waited to acquire the lock
61
- if self._token is None or self._need_new_token():
62
- self._token = await await_result(self._credential.get_token, *self._scopes)
63
- request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token
80
+ if self._token is None or self._need_new_token:
81
+ options: TokenRequestOptions = {"auth_flows": auth_flows} if auth_flows else {} # type: ignore
82
+ self._token = await await_result(self._credential.get_token_info, *self._scopes, options=options)
83
+ request.http_request.headers["Authorization"] = "Bearer " + cast(AccessTokenInfo, self._token).token
64
84
 
65
85
  async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
66
86
  """Acquire a token from the credential and authorize the request with it.
@@ -71,9 +91,15 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTT
71
91
  :param ~corehttp.runtime.pipeline.PipelineRequest request: the request
72
92
  :param str scopes: required scopes of authentication
73
93
  """
94
+ options: TokenRequestOptions = {}
95
+ # Loop through all the keyword arguments and check if they are part of the TokenRequestOptions.
96
+ for key in list(kwargs.keys()):
97
+ if key in TokenRequestOptions.__annotations__: # pylint:disable=no-member
98
+ options[key] = kwargs.pop(key) # type: ignore[literal-required]
99
+
74
100
  async with self._lock:
75
- self._token = await await_result(self._credential.get_token, *scopes, **kwargs)
76
- request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token
101
+ self._token = await await_result(self._credential.get_token_info, *scopes, options=options)
102
+ request.http_request.headers["Authorization"] = "Bearer " + cast(AccessTokenInfo, self._token).token
77
103
 
78
104
  async def send(
79
105
  self, request: PipelineRequest[HTTPRequestType]
@@ -85,27 +111,40 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTT
85
111
  :return: The pipeline response object
86
112
  :rtype: ~corehttp.runtime.pipeline.PipelineResponse
87
113
  """
88
- await await_result(self.on_request, request)
114
+ op_auth_flows = request.context.options.pop("auth_flows", None)
115
+ auth_flows = op_auth_flows if op_auth_flows is not None else self._auth_flows
116
+ await await_result(self.on_request, request, auth_flows=auth_flows)
89
117
  try:
90
118
  response = await self.next.send(request)
91
- except Exception: # pylint:disable=broad-except
119
+ except Exception:
92
120
  await await_result(self.on_exception, request)
93
121
  raise
94
- else:
95
- await await_result(self.on_response, request, response)
122
+ await await_result(self.on_response, request, response)
96
123
 
97
124
  if response.http_response.status_code == 401:
98
125
  self._token = None # any cached token is invalid
99
126
  if "WWW-Authenticate" in response.http_response.headers:
100
- request_authorized = await self.on_challenge(request, response)
127
+ try:
128
+ request_authorized = await self.on_challenge(request, response)
129
+ except Exception as ex:
130
+ # If the response is streamed, read it so the error message is immediately available to the user.
131
+ # Otherwise, a generic error message will be given and the user will have to read the response
132
+ # body to see the actual error.
133
+ if response.context.options.get("stream"):
134
+ try:
135
+ await response.http_response.read() # type: ignore
136
+ except Exception: # pylint:disable=broad-except
137
+ pass
138
+
139
+ # Raise the exception from the token request with the original 401 response
140
+ raise ex from HttpResponseError(response=response.http_response)
101
141
  if request_authorized:
102
142
  try:
103
143
  response = await self.next.send(request)
104
- except Exception: # pylint:disable=broad-except
144
+ except Exception:
105
145
  await await_result(self.on_exception, request)
106
146
  raise
107
- else:
108
- await await_result(self.on_response, request, response)
147
+ await await_result(self.on_response, request, response)
109
148
 
110
149
  return response
111
150
 
@@ -151,5 +190,11 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTT
151
190
  # pylint: disable=unused-argument
152
191
  return
153
192
 
193
+ @property
154
194
  def _need_new_token(self) -> bool:
155
- return not self._token or self._token.expires_on - time.time() < 300
195
+ now = time.time()
196
+ return (
197
+ not self._token
198
+ or (self._token.refresh_on is not None and self._token.refresh_on <= now)
199
+ or (self._token.expires_on - now < 300)
200
+ )