corehttp 1.0.0b5__py3-none-any.whl → 1.0.0b7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- corehttp/_version.py +1 -1
- corehttp/credentials.py +66 -25
- corehttp/exceptions.py +7 -6
- corehttp/instrumentation/__init__.py +9 -0
- corehttp/instrumentation/tracing/__init__.py +14 -0
- corehttp/instrumentation/tracing/_decorator.py +189 -0
- corehttp/instrumentation/tracing/_models.py +72 -0
- corehttp/instrumentation/tracing/_tracer.py +69 -0
- corehttp/instrumentation/tracing/opentelemetry.py +277 -0
- corehttp/instrumentation/tracing/utils.py +31 -0
- corehttp/paging.py +13 -0
- corehttp/rest/_aiohttp.py +21 -9
- corehttp/rest/_http_response_impl.py +9 -15
- corehttp/rest/_http_response_impl_async.py +2 -0
- corehttp/rest/_httpx.py +9 -9
- corehttp/rest/_requests_basic.py +17 -10
- corehttp/rest/_rest_py3.py +6 -10
- corehttp/runtime/pipeline/__init__.py +5 -9
- corehttp/runtime/pipeline/_base.py +3 -2
- corehttp/runtime/pipeline/_base_async.py +6 -8
- corehttp/runtime/pipeline/_tools.py +18 -2
- corehttp/runtime/pipeline/_tools_async.py +2 -4
- corehttp/runtime/policies/__init__.py +2 -0
- corehttp/runtime/policies/_authentication.py +76 -24
- corehttp/runtime/policies/_authentication_async.py +66 -21
- corehttp/runtime/policies/_distributed_tracing.py +169 -0
- corehttp/runtime/policies/_retry.py +8 -12
- corehttp/runtime/policies/_retry_async.py +5 -9
- corehttp/runtime/policies/_universal.py +15 -11
- corehttp/serialization.py +237 -3
- corehttp/settings.py +59 -0
- corehttp/transport/_base.py +1 -3
- corehttp/transport/_base_async.py +1 -3
- corehttp/transport/aiohttp/_aiohttp.py +41 -16
- corehttp/transport/requests/_bigger_block_size_http_adapters.py +1 -1
- corehttp/transport/requests/_requests_basic.py +33 -18
- corehttp/utils/_enum_meta.py +1 -1
- corehttp/utils/_utils.py +2 -1
- corehttp-1.0.0b7.dist-info/METADATA +196 -0
- corehttp-1.0.0b7.dist-info/RECORD +61 -0
- {corehttp-1.0.0b5.dist-info → corehttp-1.0.0b7.dist-info}/WHEEL +1 -1
- corehttp-1.0.0b5.dist-info/METADATA +0 -132
- corehttp-1.0.0b5.dist-info/RECORD +0 -52
- {corehttp-1.0.0b5.dist-info → corehttp-1.0.0b7.dist-info/licenses}/LICENSE +0 -0
- {corehttp-1.0.0b5.dist-info → corehttp-1.0.0b7.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
# ------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# ------------------------------------
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
import logging
|
|
7
|
+
import urllib.parse
|
|
8
|
+
from typing import Any, Optional, Tuple, Union, Type, Mapping, Dict, TYPE_CHECKING
|
|
9
|
+
from types import TracebackType
|
|
10
|
+
|
|
11
|
+
from ...rest import HttpRequest
|
|
12
|
+
from ...rest._rest_py3 import _HttpResponseBase as SansIOHttpResponse
|
|
13
|
+
from ._base import SansIOHTTPPolicy
|
|
14
|
+
from ...settings import settings
|
|
15
|
+
from ...instrumentation.tracing._models import SpanKind, TracingOptions
|
|
16
|
+
from ...instrumentation.tracing._tracer import get_tracer
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from ..pipeline import PipelineRequest, PipelineResponse
|
|
20
|
+
from opentelemetry.trace import Span
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
|
|
24
|
+
OptExcInfo = Union[ExcInfo, Tuple[None, None, None]]
|
|
25
|
+
|
|
26
|
+
_LOGGER = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class DistributedHttpTracingPolicy(SansIOHTTPPolicy[HttpRequest, SansIOHttpResponse]):
    """The policy to create tracing spans for API calls.

    A CLIENT span is started in :meth:`on_request` for each request passing through this policy, and ended in
    :meth:`on_response`, or in :meth:`on_exception` when the rest of the pipeline raises.

    :keyword instrumentation_config: Configuration for the instrumentation providers.
    :type instrumentation_config: dict[str, Any]
    """

    TRACING_CONTEXT = "TRACING_CONTEXT"
    _SUPPRESSION_TOKEN = "SUPPRESSION_TOKEN"
    # Request-context key holding the caller's per-operation tracing options. They must be popped from
    # ``request.context.options`` (which is forwarded further down the pipeline), so they are remembered here
    # to make retried attempts, which run on_request again, honor the same settings.
    _TRACING_OPTIONS = "TRACING_OPTIONS"

    # Attribute names
    _HTTP_RESEND_COUNT = "http.request.resend_count"
    _USER_AGENT_ORIGINAL = "user_agent.original"
    _HTTP_REQUEST_METHOD = "http.request.method"
    _URL_FULL = "url.full"
    _HTTP_RESPONSE_STATUS_CODE = "http.response.status_code"
    _SERVER_ADDRESS = "server.address"
    _SERVER_PORT = "server.port"
    _ERROR_TYPE = "error.type"

    def __init__(  # pylint: disable=unused-argument
        self,
        *,
        instrumentation_config: Optional[Mapping[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        self._instrumentation_config = instrumentation_config

    def on_request(self, request: PipelineRequest[HttpRequest]) -> None:
        """Starts a span for the network call.

        :param request: The PipelineRequest object.
        :type request: ~corehttp.runtime.pipeline.PipelineRequest
        """
        ctxt = request.context.options
        try:
            if "tracing_options" in ctxt:
                request.context[self._TRACING_OPTIONS] = ctxt.pop("tracing_options")
            tracing_options: TracingOptions = request.context.get(self._TRACING_OPTIONS) or {}

            # User can explicitly disable tracing for this request.
            user_enabled = tracing_options.get("enabled")
            if user_enabled is False:
                return

            # If tracing is disabled globally and user didn't explicitly enable it, don't trace.
            if not settings.tracing_enabled and user_enabled is None:
                return

            config = self._instrumentation_config or {}
            tracer = get_tracer(
                library_name=config.get("library_name"),
                library_version=config.get("library_version"),
                attributes=config.get("attributes"),
            )
            if not tracer:
                _LOGGER.warning(
                    "Tracing is enabled, but not able to get an OpenTelemetry tracer. "
                    "Please ensure that `opentelemetry-api` is installed."
                )
                return

            span_name = request.http_request.method
            span = tracer.start_span(
                name=span_name,
                kind=SpanKind.CLIENT,
                attributes=tracing_options.get("attributes"),
            )
            # Register the span immediately so that it is always ended by on_response/on_exception,
            # even if propagating the trace context below fails.
            request.context[self.TRACING_CONTEXT] = span

            with tracer.use_span(span, end_on_exit=False):
                trace_context_headers = tracer.get_trace_context()
                request.http_request.headers.update(trace_context_headers)

            token = tracer._suppress_auto_http_instrumentation()  # pylint: disable=protected-access
            request.context[self._SUPPRESSION_TOKEN] = token
        except Exception as err:  # pylint: disable=broad-except
            _LOGGER.warning("Unable to start HTTP span: %s", err)  # pylint: disable=do-not-log-exceptions-if-not-debug

    def on_response(
        self,
        request: PipelineRequest[HttpRequest],
        response: PipelineResponse[HttpRequest, SansIOHttpResponse],
    ) -> None:
        """Ends the span for the network call and updates its status.

        :param request: The PipelineRequest object.
        :type request: ~corehttp.runtime.pipeline.PipelineRequest
        :param response: The PipelineResponse object.
        :type response: ~corehttp.runtime.pipeline.PipelineResponse
        """
        self._end_span(request, response=response)

    def on_exception(self, request: PipelineRequest[HttpRequest]) -> None:
        """Ends the span for the network call when the rest of the pipeline raised an exception.

        Without this hook a failed call (e.g. a connection error) would leave its span unended and the
        auto-instrumentation suppression token attached to the context.

        :param request: The PipelineRequest object.
        :type request: ~corehttp.runtime.pipeline.PipelineRequest
        """
        import sys  # pylint: disable=import-outside-toplevel

        # NOTE: assumes the pipeline invokes this hook while it is handling the exception, so that
        # sys.exc_info() returns it; otherwise the span is still ended, just without an error.type.
        self._end_span(request, exception=sys.exc_info()[1])

    def _end_span(
        self,
        request: PipelineRequest[HttpRequest],
        response: Optional[PipelineResponse[HttpRequest, SansIOHttpResponse]] = None,
        exception: Optional[BaseException] = None,
    ) -> None:
        """Finish the span started by on_request for this request, if any, and detach the suppression token.

        :param request: The PipelineRequest object.
        :type request: ~corehttp.runtime.pipeline.PipelineRequest
        :param response: The PipelineResponse object, or None if no response was received.
        :type response: ~corehttp.runtime.pipeline.PipelineResponse
        :param exception: The exception raised by the rest of the pipeline, if any.
        :type exception: BaseException
        """
        # Pop rather than read: the same context object is reused when the request is retried, and leftover
        # entries would otherwise make a later attempt that started no span of its own end this span again and
        # detach the same token twice.
        span: Optional["Span"] = request.context.pop(self.TRACING_CONTEXT, None)
        suppression_token = request.context.pop(self._SUPPRESSION_TOKEN, None)

        if span is not None:
            try:
                http_response = response.http_response if response is not None else None
                self._set_http_client_span_attributes(span, request.http_request, response=http_response)
                retry_count = request.context.get("retry_count")
                if retry_count:
                    span.set_attribute(self._HTTP_RESEND_COUNT, retry_count)
                if exception is not None:
                    span.set_attribute(self._ERROR_TYPE, self._exception_type_name(exception))
            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.warning(  # pylint: disable=do-not-log-exceptions-if-not-debug
                    "Unable to set HTTP span attributes: %s", err
                )
            finally:
                span.end()

        if suppression_token:
            try:
                tracer = get_tracer()
                if tracer:
                    tracer._detach_from_context(suppression_token)  # pylint: disable=protected-access
            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.warning(  # pylint: disable=do-not-log-exceptions-if-not-debug
                    "Unable to detach tracing context: %s", err
                )

    @staticmethod
    def _exception_type_name(exception: BaseException) -> str:
        """Return the exception's class name, module-qualified unless it is a builtin.

        :param exception: The exception to describe.
        :type exception: BaseException
        :return: The (fully-qualified) exception type name.
        :rtype: str
        """
        exc_type = type(exception)
        module = getattr(exc_type, "__module__", None)
        if module in (None, "builtins"):
            return exc_type.__qualname__
        return f"{module}.{exc_type.__qualname__}"

    def _set_http_client_span_attributes(
        self,
        span: "Span",
        request: HttpRequest,
        response: Optional[SansIOHttpResponse] = None,
    ) -> None:
        """Add attributes to an HTTP client span.

        :param span: The span to add attributes to.
        :type span: ~opentelemetry.trace.Span
        :param request: The request made
        :type request: ~corehttp.rest.HttpRequest
        :param response: The response received from the server. Is None if no response received.
        :type response: ~corehttp.rest.HttpResponse
        """
        attributes: Dict[str, Any] = {
            self._HTTP_REQUEST_METHOD: request.method,
            self._URL_FULL: request.url,
        }

        try:
            parsed_url = urllib.parse.urlparse(request.url)
            hostname = parsed_url.hostname
            port = parsed_url.port
        except ValueError:
            # urllib raises ValueError for e.g. malformed IPv6 hosts or invalid ports; tracing must not
            # fail the call because of that, so the server attributes are simply omitted.
            hostname, port = None, None
        if hostname:
            attributes[self._SERVER_ADDRESS] = hostname
        if port:
            attributes[self._SERVER_PORT] = port

        user_agent = request.headers.get("User-Agent")
        if user_agent:
            attributes[self._USER_AGENT_ORIGINAL] = user_agent
        if response is not None and response.status_code:
            attributes[self._HTTP_RESPONSE_STATUS_CODE] = response.status_code
            if response.status_code >= 400:
                attributes[self._ERROR_TYPE] = str(response.status_code)

        span.set_attributes(attributes)
|
|
@@ -52,6 +52,8 @@ _LOGGER = logging.getLogger(__name__)
|
|
|
52
52
|
|
|
53
53
|
|
|
54
54
|
class RetryMode(str, Enum, metaclass=CaseInsensitiveEnumMeta):
|
|
55
|
+
"""Enum for retry modes."""
|
|
56
|
+
|
|
55
57
|
# pylint: disable=enum-must-be-uppercase
|
|
56
58
|
Exponential = "exponential"
|
|
57
59
|
Fixed = "fixed"
|
|
@@ -104,7 +106,7 @@ class RetryPolicyBase:
|
|
|
104
106
|
"read": options.pop("retry_read", self.read_retries),
|
|
105
107
|
"status": options.pop("retry_status", self.status_retries),
|
|
106
108
|
"backoff": options.pop("retry_backoff_factor", self.backoff_factor),
|
|
107
|
-
"max_backoff": options.pop("retry_backoff_max", self.
|
|
109
|
+
"max_backoff": options.pop("retry_backoff_max", self.backoff_max),
|
|
108
110
|
"methods": options.pop("retry_on_methods", self._method_whitelist),
|
|
109
111
|
"timeout": options.pop("timeout", self.timeout),
|
|
110
112
|
"history": [],
|
|
@@ -394,28 +396,21 @@ class RetryPolicy(RetryPolicyBase, HTTPPolicy[HttpRequest, HttpResponse]):
|
|
|
394
396
|
|
|
395
397
|
:keyword int retry_total: Total number of retries to allow. Takes precedence over other counts.
|
|
396
398
|
Default value is 10.
|
|
397
|
-
|
|
398
399
|
:keyword int retry_connect: How many connection-related errors to retry on.
|
|
399
400
|
These are errors raised before the request is sent to the remote server,
|
|
400
401
|
which we assume has not triggered the server to process the request. Default value is 3.
|
|
401
|
-
|
|
402
402
|
:keyword int retry_read: How many times to retry on read errors.
|
|
403
403
|
These errors are raised after the request was sent to the server, so the
|
|
404
404
|
request may have side-effects. Default value is 3.
|
|
405
|
-
|
|
406
405
|
:keyword int retry_status: How many times to retry on bad status codes. Default value is 3.
|
|
407
|
-
|
|
408
406
|
:keyword float retry_backoff_factor: A backoff factor to apply between attempts after the second try
|
|
409
407
|
(most errors are resolved immediately by a second try without a delay).
|
|
410
408
|
In fixed mode, retry policy will always sleep for {backoff factor}.
|
|
411
409
|
In 'exponential' mode, retry policy will sleep for: `{backoff factor} * (2 ** ({number of total retries} - 1))`
|
|
412
410
|
seconds. If the backoff_factor is 0.1, then the retry will sleep
|
|
413
411
|
for [0.0s, 0.2s, 0.4s, ...] between retries. The default value is 0.8.
|
|
414
|
-
|
|
415
412
|
:keyword int retry_backoff_max: The maximum back off time. Default value is 120 seconds (2 minutes).
|
|
416
|
-
|
|
417
413
|
:keyword RetryMode retry_mode: Fixed or exponential delay between attemps, default is exponential.
|
|
418
|
-
|
|
419
414
|
:keyword int timeout: Timeout setting for the operation in seconds, default is 604800s (7 days).
|
|
420
415
|
"""
|
|
421
416
|
|
|
@@ -481,10 +476,10 @@ class RetryPolicy(RetryPolicyBase, HTTPPolicy[HttpRequest, HttpResponse]):
|
|
|
481
476
|
|
|
482
477
|
:param request: The PipelineRequest object
|
|
483
478
|
:type request: ~corehttp.runtime.pipeline.PipelineRequest
|
|
484
|
-
:return:
|
|
479
|
+
:return: The PipelineResponse.
|
|
485
480
|
:rtype: ~corehttp.runtime.pipeline.PipelineResponse
|
|
486
|
-
:raises
|
|
487
|
-
:raises
|
|
481
|
+
:raises ~corehttp.exceptions.BaseError: if maximum retries exceeded.
|
|
482
|
+
:raises ~corehttp.exceptions.ClientAuthenticationError: if authentication fails.
|
|
488
483
|
"""
|
|
489
484
|
retry_active = True
|
|
490
485
|
response = None
|
|
@@ -505,6 +500,7 @@ class RetryPolicy(RetryPolicyBase, HTTPPolicy[HttpRequest, HttpResponse]):
|
|
|
505
500
|
)
|
|
506
501
|
try:
|
|
507
502
|
self._configure_timeout(request, absolute_timeout, is_response_error)
|
|
503
|
+
request.context["retry_count"] = len(retry_settings["history"])
|
|
508
504
|
response = self.next.send(request)
|
|
509
505
|
if self.is_retry(retry_settings, response):
|
|
510
506
|
retry_active = self.increment(retry_settings, response=response)
|
|
@@ -513,7 +509,7 @@ class RetryPolicy(RetryPolicyBase, HTTPPolicy[HttpRequest, HttpResponse]):
|
|
|
513
509
|
is_response_error = True
|
|
514
510
|
continue
|
|
515
511
|
break
|
|
516
|
-
except ClientAuthenticationError:
|
|
512
|
+
except ClientAuthenticationError:
|
|
517
513
|
# the authentication policy failed such that the client's request can't
|
|
518
514
|
# succeed--we'll never have a response to it, so propagate the exception
|
|
519
515
|
raise
|
|
@@ -54,23 +54,18 @@ class AsyncRetryPolicy(RetryPolicyBase, AsyncHTTPPolicy[HttpRequest, AsyncHttpRe
|
|
|
54
54
|
|
|
55
55
|
:keyword int retry_total: Total number of retries to allow. Takes precedence over other counts.
|
|
56
56
|
Default value is 10.
|
|
57
|
-
|
|
58
57
|
:keyword int retry_connect: How many connection-related errors to retry on.
|
|
59
58
|
These are errors raised before the request is sent to the remote server,
|
|
60
59
|
which we assume has not triggered the server to process the request. Default value is 3.
|
|
61
|
-
|
|
62
60
|
:keyword int retry_read: How many times to retry on read errors.
|
|
63
61
|
These errors are raised after the request was sent to the server, so the
|
|
64
62
|
request may have side-effects. Default value is 3.
|
|
65
|
-
|
|
66
63
|
:keyword int retry_status: How many times to retry on bad status codes. Default value is 3.
|
|
67
|
-
|
|
68
64
|
:keyword float retry_backoff_factor: A backoff factor to apply between attempts after the second try
|
|
69
65
|
(most errors are resolved immediately by a second try without a delay).
|
|
70
66
|
Retry policy will sleep for: `{backoff factor} * (2 ** ({number of total retries} - 1))`
|
|
71
67
|
seconds. If the backoff_factor is 0.1, then the retry will sleep
|
|
72
68
|
for [0.0s, 0.2s, 0.4s, ...] between retries. The default value is 0.8.
|
|
73
|
-
|
|
74
69
|
:keyword int retry_backoff_max: The maximum back off time. Default value is 120 seconds (2 minutes).
|
|
75
70
|
"""
|
|
76
71
|
|
|
@@ -138,10 +133,10 @@ class AsyncRetryPolicy(RetryPolicyBase, AsyncHTTPPolicy[HttpRequest, AsyncHttpRe
|
|
|
138
133
|
|
|
139
134
|
:param request: The PipelineRequest object
|
|
140
135
|
:type request: ~corehttp.runtime.pipeline.PipelineRequest
|
|
141
|
-
:return:
|
|
136
|
+
:return: The PipelineResponse.
|
|
142
137
|
:rtype: ~corehttp.runtime.pipeline.PipelineResponse
|
|
143
|
-
:
|
|
144
|
-
:
|
|
138
|
+
:raises ~corehttp.exceptions.BaseError: if maximum retries exceeded.
|
|
139
|
+
:raises ~corehttp.exceptions.ClientAuthenticationError: if authentication fails.
|
|
145
140
|
"""
|
|
146
141
|
retry_active = True
|
|
147
142
|
response = None
|
|
@@ -162,6 +157,7 @@ class AsyncRetryPolicy(RetryPolicyBase, AsyncHTTPPolicy[HttpRequest, AsyncHttpRe
|
|
|
162
157
|
)
|
|
163
158
|
try:
|
|
164
159
|
self._configure_timeout(request, absolute_timeout, is_response_error)
|
|
160
|
+
request.context["retry_count"] = len(retry_settings["history"])
|
|
165
161
|
response = await self.next.send(request)
|
|
166
162
|
if self.is_retry(retry_settings, response):
|
|
167
163
|
retry_active = self.increment(retry_settings, response=response)
|
|
@@ -174,7 +170,7 @@ class AsyncRetryPolicy(RetryPolicyBase, AsyncHTTPPolicy[HttpRequest, AsyncHttpRe
|
|
|
174
170
|
is_response_error = True
|
|
175
171
|
continue
|
|
176
172
|
break
|
|
177
|
-
except ClientAuthenticationError:
|
|
173
|
+
except ClientAuthenticationError:
|
|
178
174
|
# the authentication policy failed such that the client's request can't
|
|
179
175
|
# succeed--we'll never have a response to it, so propagate the exception
|
|
180
176
|
raise
|
|
@@ -64,9 +64,7 @@ class HeadersPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
|
|
|
64
64
|
:param dict base_headers: Headers to send with the request.
|
|
65
65
|
"""
|
|
66
66
|
|
|
67
|
-
def __init__(
|
|
68
|
-
self, base_headers: Optional[Dict[str, str]] = None, **kwargs: Any
|
|
69
|
-
) -> None: # pylint: disable=super-init-not-called
|
|
67
|
+
def __init__(self, base_headers: Optional[Dict[str, str]] = None, **kwargs: Any) -> None:
|
|
70
68
|
self._headers: Dict[str, str] = base_headers or {}
|
|
71
69
|
self._headers.update(kwargs.pop("headers", {}))
|
|
72
70
|
|
|
@@ -114,9 +112,7 @@ class UserAgentPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
|
|
|
114
112
|
_USERAGENT = "User-Agent"
|
|
115
113
|
_ENV_ADDITIONAL_USER_AGENT = "CORE_HTTP_USER_AGENT"
|
|
116
114
|
|
|
117
|
-
def __init__(
|
|
118
|
-
self, base_user_agent: Optional[str] = None, **kwargs: Any
|
|
119
|
-
) -> None: # pylint: disable=super-init-not-called
|
|
115
|
+
def __init__(self, base_user_agent: Optional[str] = None, **kwargs: Any) -> None:
|
|
120
116
|
self.overwrite: bool = kwargs.pop("user_agent_overwrite", False)
|
|
121
117
|
self.use_env: bool = kwargs.pop("user_agent_use_env", True)
|
|
122
118
|
application_id: Optional[str] = kwargs.pop("user_agent", None)
|
|
@@ -147,6 +143,7 @@ class UserAgentPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
|
|
|
147
143
|
|
|
148
144
|
def add_user_agent(self, value: str) -> None:
|
|
149
145
|
"""Add value to current user agent with a space.
|
|
146
|
+
|
|
150
147
|
:param str value: value to add to user agent.
|
|
151
148
|
"""
|
|
152
149
|
self._user_agent = "{} {}".format(self._user_agent, value)
|
|
@@ -172,7 +169,6 @@ class UserAgentPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
|
|
|
172
169
|
|
|
173
170
|
|
|
174
171
|
class NetworkTraceLoggingPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
|
|
175
|
-
|
|
176
172
|
"""The logging policy in the pipeline is used to output HTTP network trace to the configured logger.
|
|
177
173
|
|
|
178
174
|
This accepts both global configuration, and per-request level with "enable_http_logger"
|
|
@@ -183,9 +179,7 @@ class NetworkTraceLoggingPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseTy
|
|
|
183
179
|
def __init__(self, logging_enable: bool = False, **kwargs: Any): # pylint: disable=unused-argument
|
|
184
180
|
self.enable_http_logger = logging_enable
|
|
185
181
|
|
|
186
|
-
def on_request(
|
|
187
|
-
self, request: PipelineRequest[HTTPRequestType]
|
|
188
|
-
) -> None: # pylint: disable=too-many-return-statements
|
|
182
|
+
def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
|
|
189
183
|
"""Logs HTTP request to the DEBUG logger.
|
|
190
184
|
|
|
191
185
|
:param request: The PipelineRequest object.
|
|
@@ -408,6 +402,11 @@ class ContentDecodePolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
|
|
|
408
402
|
return cls.deserialize_from_text(response.text(encoding), mime_type, response=response)
|
|
409
403
|
|
|
410
404
|
def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
|
|
405
|
+
"""Set the response encoding in the request context.
|
|
406
|
+
|
|
407
|
+
:param request: The PipelineRequest object.
|
|
408
|
+
:type request: ~corehttp.runtime.pipeline.PipelineRequest
|
|
409
|
+
"""
|
|
411
410
|
options = request.context.options
|
|
412
411
|
response_encoding = options.pop("response_encoding", self._response_encoding)
|
|
413
412
|
if response_encoding:
|
|
@@ -455,10 +454,15 @@ class ProxyPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
|
|
|
455
454
|
|
|
456
455
|
def __init__(
|
|
457
456
|
self, proxies: Optional[MutableMapping[str, str]] = None, **kwargs: Any
|
|
458
|
-
): # pylint: disable=unused-argument
|
|
457
|
+
): # pylint: disable=unused-argument
|
|
459
458
|
self.proxies = proxies
|
|
460
459
|
|
|
461
460
|
def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
|
|
461
|
+
"""Adds the proxy information to the request context.
|
|
462
|
+
|
|
463
|
+
:param request: The PipelineRequest object.
|
|
464
|
+
:type request: ~corehttp.runtime.pipeline.PipelineRequest
|
|
465
|
+
"""
|
|
462
466
|
ctxt = request.context.options
|
|
463
467
|
if self.proxies and "proxies" not in ctxt:
|
|
464
468
|
ctxt["proxies"] = self.proxies
|
corehttp/serialization.py
CHANGED
|
@@ -4,14 +4,23 @@
|
|
|
4
4
|
# Licensed under the MIT License. See License.txt in the project root for
|
|
5
5
|
# license information.
|
|
6
6
|
# --------------------------------------------------------------------------
|
|
7
|
+
# pylint: disable=protected-access
|
|
7
8
|
import base64
|
|
9
|
+
from functools import partial
|
|
8
10
|
from json import JSONEncoder
|
|
9
|
-
from typing import Union, cast, Any
|
|
11
|
+
from typing import Dict, List, Optional, Union, cast, Any, Type, Callable, Tuple
|
|
10
12
|
from datetime import datetime, date, time, timedelta
|
|
11
13
|
from datetime import timezone
|
|
12
14
|
|
|
13
15
|
|
|
14
|
-
__all__ = [
|
|
16
|
+
__all__ = [
|
|
17
|
+
"NULL",
|
|
18
|
+
"CoreJSONEncoder",
|
|
19
|
+
"is_generated_model",
|
|
20
|
+
"attribute_list",
|
|
21
|
+
"TypeHandlerRegistry",
|
|
22
|
+
]
|
|
23
|
+
TZ_UTC = timezone.utc
|
|
15
24
|
|
|
16
25
|
|
|
17
26
|
class _Null:
|
|
@@ -111,10 +120,176 @@ def _datetime_as_isostr(dt: Union[datetime, date, time, timedelta]) -> str:
|
|
|
111
120
|
return _timedelta_as_isostr(dt)
|
|
112
121
|
|
|
113
122
|
|
|
123
|
+
class TypeHandlerRegistry:
    """A registry for custom serializers and deserializers for specific types or conditions.

    Handlers are registered either for an exact type or for a predicate. A lookup checks the exact-type table
    first and then the predicates in registration order. Lookup results (including "no handler") are cached,
    and the relevant cache is cleared on every new registration.
    """

    def __init__(self) -> None:
        self._serializer_types: Dict[Type, Callable] = {}
        self._deserializer_types: Dict[Type, Callable] = {}
        self._serializer_predicates: List[Tuple[Callable[[Any], bool], Callable]] = []
        self._deserializer_predicates: List[Tuple[Callable[[Any], bool], Callable]] = []

        self._serializer_cache: Dict[Type, Optional[Callable]] = {}
        self._deserializer_cache: Dict[Type, Optional[Callable]] = {}

    @staticmethod
    def _check_condition(condition: Any) -> None:
        """Validate a registration condition.

        :param any condition: The condition passed to a ``register_*`` method.
        :raises TypeError: If the condition is neither a type nor a callable.
        """
        if not (isinstance(condition, type) or callable(condition)):
            raise TypeError("Condition must be a type or a callable predicate function.")

    @staticmethod
    def _add_handler(
        condition: Union[Type, Callable[[Any], bool]],
        handler_func: Callable,
        types: Dict[Type, Callable],
        predicates: List[Tuple[Callable[[Any], bool], Callable]],
        cache: Dict[Type, Optional[Callable]],
    ) -> None:
        """Store a handler in the type table or the predicate list, then invalidate the lookup cache.

        :param condition: A type or a predicate.
        :type condition: Union[Type, Callable[[Any], bool]]
        :param handler_func: The handler to register.
        :type handler_func: Callable
        :param types: The exact-type table to update.
        :type types: Dict[Type, Callable]
        :param predicates: The predicate list to update.
        :type predicates: List[Tuple[Callable[[Any], bool], Callable]]
        :param cache: The lookup cache to clear.
        :type cache: Dict[Type, Optional[Callable]]
        """
        if isinstance(condition, type):
            types[condition] = handler_func
        else:
            predicates.append((condition, handler_func))
        # Any cached lookup, including a cached "no handler" result, may be stale now.
        cache.clear()

    @staticmethod
    def _lookup(
        key: Type,
        probe: Any,
        types: Dict[Type, Callable],
        predicates: List[Tuple[Callable[[Any], bool], Callable]],
    ) -> Optional[Callable]:
        """Return the handler registered for exactly ``key``, else the first predicate handler matching ``probe``.

        :param key: The type looked up in the exact-type table.
        :type key: type
        :param probe: The value passed to each predicate.
        :type probe: any
        :param types: The exact-type table.
        :type types: Dict[Type, Callable]
        :param predicates: The predicate list, checked in registration order.
        :type predicates: List[Tuple[Callable[[Any], bool], Callable]]
        :return: The matching handler, or None.
        :rtype: Optional[Callable]
        """
        handler = types.get(key)
        if handler is not None:
            return handler
        for predicate, pred_handler in predicates:
            if predicate(probe):
                return pred_handler
        return None

    def register_serializer(
        self, condition: Union[Type, Callable[[Any], bool]]
    ) -> Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]]:
        """Decorator to register a serializer.

        The handler function is expected to take a single argument, the object to serialize,
        and return a dictionary representation of that object.

        Examples:

        .. code-block:: python

            @registry.register_serializer(CustomModel)
            def serialize_single_type(value: CustomModel) -> dict:
                return value.to_dict()

            @registry.register_serializer(lambda x: isinstance(x, BaseModel))
            def serialize_with_condition(value: BaseModel) -> dict:
                return value.to_dict()

            # Called manually for a specific type
            def custom_serializer(value: CustomModel) -> Dict[str, Any]:
                return {"custom": value.custom}

            registry.register_serializer(CustomModel)(custom_serializer)

        Predicates receive the object being serialized, but lookups are cached per object type, so a predicate
        should depend only on the object's type.

        :param condition: A type or a callable predicate function that takes an object and returns a bool.
        :type condition: Union[Type, Callable[[Any], bool]]
        :return: A decorator that registers the handler function.
        :rtype: Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]]
        :raises TypeError: If the condition is neither a type nor a callable (checked immediately).
        """
        self._check_condition(condition)

        def decorator(handler_func: Callable[[Any], Dict[str, Any]]) -> Callable[[Any], Dict[str, Any]]:
            self._add_handler(
                condition,
                handler_func,
                self._serializer_types,
                self._serializer_predicates,
                self._serializer_cache,
            )
            return handler_func

        return decorator

    def register_deserializer(
        self, condition: Union[Type, Callable[[Any], bool]]
    ) -> Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]]:
        """Decorator to register a deserializer.

        The handler function is expected to take two arguments: the target type and the data dictionary,
        and return an instance of the target type.

        Examples:

        .. code-block:: python

            @registry.register_deserializer(CustomModel)
            def deserialize_single_type(cls: Type[CustomModel], data: dict) -> CustomModel:
                return cls(**data)

            @registry.register_deserializer(lambda t: issubclass(t, BaseModel))
            def deserialize_with_condition(cls: Type[BaseModel], data: dict) -> BaseModel:
                return cls(**data)

            # Called manually for a specific type
            def custom_deserializer(cls: Type[CustomModel], data: Dict[str, Any]) -> CustomModel:
                return cls(custom=data["custom"])

            registry.register_deserializer(CustomModel)(custom_deserializer)

        :param condition: A type or a callable predicate function that takes the target class and returns a bool.
        :type condition: Union[Type, Callable[[Any], bool]]
        :return: A decorator that registers the handler function.
        :rtype: Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]]
        :raises TypeError: If the condition is neither a type nor a callable (checked immediately).
        """
        self._check_condition(condition)

        def decorator(handler_func: Callable[[Type, Dict[str, Any]], Any]) -> Callable[[Type, Dict[str, Any]], Any]:
            self._add_handler(
                condition,
                handler_func,
                self._deserializer_types,
                self._deserializer_predicates,
                self._deserializer_cache,
            )
            return handler_func

        return decorator

    def get_serializer(self, obj: Any) -> Optional[Callable[[Any], Dict[str, Any]]]:
        """Gets the appropriate serializer for an object.

        It first checks the type dictionary for a direct type match.
        If no match is found, it iterates through the predicate list to find a match.

        Results of the lookup are cached for performance based on the object's type.

        :param obj: The object to serialize.
        :type obj: any
        :return: The serializer function if found, otherwise None.
        :rtype: Optional[Callable[[Any], Dict[str, Any]]]
        """
        obj_type = type(obj)
        if obj_type in self._serializer_cache:
            return self._serializer_cache[obj_type]

        handler = self._lookup(obj_type, obj, self._serializer_types, self._serializer_predicates)
        self._serializer_cache[obj_type] = handler
        return handler

    def get_deserializer(self, cls: Type) -> Optional[Callable[[Dict[str, Any]], Any]]:
        """Gets the appropriate deserializer for a class.

        It first checks the type dictionary for a direct type match.
        If no match is found, it iterates through the predicate list to find a match.

        Results of the lookup are cached for performance based on the class.

        :param cls: The class to deserialize.
        :type cls: type
        :return: A deserializer function bound to the specified class that takes a dictionary and returns
            an instance of that class, or None if no deserializer is found.
        :rtype: Optional[Callable[[Dict[str, Any]], Any]]
        """
        if cls in self._deserializer_cache:
            return self._deserializer_cache[cls]

        handler = self._lookup(cls, cls, self._deserializer_types, self._deserializer_predicates)
        # Bind the target class so callers only need to supply the data dictionary.
        bound = partial(handler, cls) if handler is not None else None
        self._deserializer_cache[cls] = bound
        return bound
|
|
280
|
+
|
|
281
|
+
|
|
114
282
|
class CoreJSONEncoder(JSONEncoder):
|
|
115
283
|
"""A JSON encoder that's capable of serializing datetime objects and bytes."""
|
|
116
284
|
|
|
117
|
-
def default(self, o: Any) -> Any:
|
|
285
|
+
def default(self, o: Any) -> Any:
|
|
286
|
+
"""Override the default method to handle datetime and bytes serialization.
|
|
287
|
+
|
|
288
|
+
:param o: The object to serialize.
|
|
289
|
+
:type o: Any
|
|
290
|
+
:return: A JSON-serializable representation of the object.
|
|
291
|
+
:rtype: Any
|
|
292
|
+
"""
|
|
118
293
|
if isinstance(o, (bytes, bytearray)):
|
|
119
294
|
return base64.b64encode(o).decode()
|
|
120
295
|
try:
|
|
@@ -122,3 +297,62 @@ class CoreJSONEncoder(JSONEncoder):
|
|
|
122
297
|
except AttributeError:
|
|
123
298
|
pass
|
|
124
299
|
return super(CoreJSONEncoder, self).default(o)
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def is_generated_model(obj: Any) -> bool:
    """Report whether ``obj`` looks like a generated SDK model.

    An object qualifies if it has a truthy ``_is_model`` attribute, or if it defines an ``_attribute_map``
    attribute at all. Classes qualify as well as instances, since only attribute lookup is used.

    :param obj: The object to check.
    :type obj: any
    :return: True if the object is a generated SDK model, False otherwise.
    :rtype: bool
    """
    if getattr(obj, "_is_model", False):
        return True
    return hasattr(obj, "_attribute_map")
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def _get_flattened_attribute(obj: Any) -> Optional[str]:
    """Get the name of the flattened attribute in a generated TypeSpec model if one exists.

    The flattened inner property names are read from an attribute whose name contains ``__flattened_items``.
    The result is the first entry of ``obj._attr_to_rest_field`` whose ``_class_type`` declares any of those
    names in its own ``_attr_to_rest_field``.

    :param any obj: The object to check.
    :return: The name of the flattened attribute if it exists, otherwise None.
    :rtype: Optional[str]
    """
    # The attribute is name-mangled by its class (e.g. ``_MyModel__flattened_items``), so search by substring.
    flattened_attr_name = next((a for a in dir(obj) if "__flattened_items" in a), None)
    if flattened_attr_name is None:
        return None

    flattened_items = getattr(obj, flattened_attr_name, None)
    if flattened_items is None:
        return None

    # Loop-invariant: build the lookup set once rather than once per rest field.
    flattened = set(flattened_items)
    for attr_name, rest_field in obj._attr_to_rest_field.items():
        try:
            inner_keys = rest_field._class_type._attr_to_rest_field.keys()
        except AttributeError:
            # if the attribute does not have _class_type, it is not a typespec generated model
            continue
        if not flattened.isdisjoint(inner_keys):
            return attr_name
    return None
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def attribute_list(obj: Any) -> List[str]:
    """Return the attribute names of a generated SDK model.

    If the model defines ``_attribute_map``, its keys are returned directly. Otherwise the keys of
    ``_attr_to_rest_field`` are returned in order, except that the flattened attribute (if any) is replaced,
    recursively, by the attribute names of its ``_class_type``.

    :param obj: The object to get attributes from.
    :type obj: any
    :return: A list of attribute names.
    :rtype: List[str]
    :raises TypeError: If ``obj`` is not a generated SDK model.
    """
    if not is_generated_model(obj):
        raise TypeError("Object is not a generated SDK model.")

    if hasattr(obj, "_attribute_map"):
        return list(obj._attribute_map.keys())

    flattened_name = _get_flattened_attribute(obj)
    names: List[str] = []
    for name, field in obj._attr_to_rest_field.items():
        if name != flattened_name:
            names.append(name)
        else:
            # Splice in the inner model's attributes in place of the flattened container.
            names += attribute_list(field._class_type)
    return names
|