corehttp 1.0.0b6__py3-none-any.whl → 1.0.0b7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. corehttp/_version.py +1 -1
  2. corehttp/credentials.py +14 -5
  3. corehttp/instrumentation/__init__.py +9 -0
  4. corehttp/instrumentation/tracing/__init__.py +14 -0
  5. corehttp/instrumentation/tracing/_decorator.py +189 -0
  6. corehttp/instrumentation/tracing/_models.py +72 -0
  7. corehttp/instrumentation/tracing/_tracer.py +69 -0
  8. corehttp/instrumentation/tracing/opentelemetry.py +277 -0
  9. corehttp/instrumentation/tracing/utils.py +31 -0
  10. corehttp/paging.py +13 -0
  11. corehttp/rest/_aiohttp.py +17 -5
  12. corehttp/rest/_http_response_impl.py +7 -7
  13. corehttp/rest/_http_response_impl_async.py +2 -0
  14. corehttp/rest/_httpx.py +8 -8
  15. corehttp/rest/_requests_basic.py +13 -5
  16. corehttp/rest/_rest_py3.py +2 -2
  17. corehttp/runtime/pipeline/__init__.py +2 -2
  18. corehttp/runtime/pipeline/_base.py +2 -1
  19. corehttp/runtime/pipeline/_base_async.py +2 -0
  20. corehttp/runtime/pipeline/_tools.py +18 -2
  21. corehttp/runtime/policies/__init__.py +2 -0
  22. corehttp/runtime/policies/_authentication.py +28 -5
  23. corehttp/runtime/policies/_authentication_async.py +22 -3
  24. corehttp/runtime/policies/_distributed_tracing.py +169 -0
  25. corehttp/runtime/policies/_retry.py +7 -11
  26. corehttp/runtime/policies/_retry_async.py +4 -8
  27. corehttp/runtime/policies/_universal.py +11 -0
  28. corehttp/serialization.py +236 -2
  29. corehttp/settings.py +59 -0
  30. corehttp/transport/_base.py +1 -3
  31. corehttp/transport/_base_async.py +1 -3
  32. corehttp/transport/aiohttp/_aiohttp.py +39 -14
  33. corehttp/transport/requests/_requests_basic.py +31 -16
  34. corehttp/utils/_utils.py +2 -1
  35. {corehttp-1.0.0b6.dist-info → corehttp-1.0.0b7.dist-info}/METADATA +52 -6
  36. corehttp-1.0.0b7.dist-info/RECORD +61 -0
  37. {corehttp-1.0.0b6.dist-info → corehttp-1.0.0b7.dist-info}/WHEEL +1 -1
  38. corehttp-1.0.0b6.dist-info/RECORD +0 -52
  39. {corehttp-1.0.0b6.dist-info → corehttp-1.0.0b7.dist-info/licenses}/LICENSE +0 -0
  40. {corehttp-1.0.0b6.dist-info → corehttp-1.0.0b7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,169 @@
1
+ # ------------------------------------
2
+ # Copyright (c) Microsoft Corporation.
3
+ # Licensed under the MIT License.
4
+ # ------------------------------------
5
+ from __future__ import annotations
6
+ import logging
7
+ import urllib.parse
8
+ from typing import Any, Optional, Tuple, Union, Type, Mapping, Dict, TYPE_CHECKING
9
+ from types import TracebackType
10
+
11
+ from ...rest import HttpRequest
12
+ from ...rest._rest_py3 import _HttpResponseBase as SansIOHttpResponse
13
+ from ._base import SansIOHTTPPolicy
14
+ from ...settings import settings
15
+ from ...instrumentation.tracing._models import SpanKind, TracingOptions
16
+ from ...instrumentation.tracing._tracer import get_tracer
17
+
18
+ if TYPE_CHECKING:
19
+ from ..pipeline import PipelineRequest, PipelineResponse
20
+ from opentelemetry.trace import Span
21
+
22
+
23
# Shape of ``sys.exc_info()`` while an exception is being handled.
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
# Shape of ``sys.exc_info()`` in general: all three entries are None when no exception is active.
OptExcInfo = Union[ExcInfo, Tuple[None, None, None]]

# Module logger; used to report (not raise) tracing failures so they never break a request.
_LOGGER = logging.getLogger(__name__)
27
+
28
+
29
class DistributedHttpTracingPolicy(SansIOHTTPPolicy[HttpRequest, SansIOHttpResponse]):
    """The policy to create tracing spans for API calls.

    A CLIENT span is started in ``on_request`` and ended in ``on_response``, or in
    ``on_exception`` when a later policy or the transport raises.

    :keyword instrumentation_config: Configuration for the instrumentation providers.
    :type instrumentation_config: dict[str, Any]
    """

    TRACING_CONTEXT = "TRACING_CONTEXT"
    _SUPPRESSION_TOKEN = "SUPPRESSION_TOKEN"
    # Context key that keeps the per-call tracing options. They are popped from the call
    # options on the first attempt, so retries of the same request read them from here.
    _TRACING_OPTIONS = "DISTRIBUTED_TRACING_OPTIONS"

    # Attribute names
    _HTTP_RESEND_COUNT = "http.request.resend_count"
    _USER_AGENT_ORIGINAL = "user_agent.original"
    _HTTP_REQUEST_METHOD = "http.request.method"
    _URL_FULL = "url.full"
    _HTTP_RESPONSE_STATUS_CODE = "http.response.status_code"
    _SERVER_ADDRESS = "server.address"
    _SERVER_PORT = "server.port"
    _ERROR_TYPE = "error.type"

    def __init__(  # pylint: disable=unused-argument
        self,
        *,
        instrumentation_config: Optional[Mapping[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        self._instrumentation_config = instrumentation_config

    def on_request(self, request: PipelineRequest[HttpRequest]) -> None:
        """Starts a span for the network call.

        Any failure while starting the span is logged as a warning and does not fail the request.

        :param request: The PipelineRequest object.
        :type request: ~corehttp.runtime.pipeline.PipelineRequest
        """
        ctxt = request.context.options
        try:
            # The retry policy re-sends the same request object, so this method can run several
            # times per call. Popping "tracing_options" only succeeds on the first attempt; stash
            # the value in the context so every attempt honors the same user settings.
            if "tracing_options" in ctxt:
                request.context[self._TRACING_OPTIONS] = ctxt.pop("tracing_options")
            tracing_options: TracingOptions = request.context.get(self._TRACING_OPTIONS) or {}

            # User can explicitly disable tracing for this request.
            user_enabled = tracing_options.get("enabled")
            if user_enabled is False:
                return

            # If tracing is disabled globally and user didn't explicitly enable it, don't trace.
            if not settings.tracing_enabled and user_enabled is None:
                return

            config = self._instrumentation_config or {}
            tracer = get_tracer(
                library_name=config.get("library_name"),
                library_version=config.get("library_version"),
                attributes=config.get("attributes"),
            )
            if not tracer:
                _LOGGER.warning(
                    "Tracing is enabled, but not able to get an OpenTelemetry tracer. "
                    "Please ensure that `opentelemetry-api` is installed."
                )
                return

            span_name = request.http_request.method
            span = tracer.start_span(
                name=span_name,
                kind=SpanKind.CLIENT,
                attributes=tracing_options.get("attributes"),
            )
            # Store the span right away so it is still ended by on_response/on_exception
            # even if propagating the trace context below fails.
            request.context[self.TRACING_CONTEXT] = span

            with tracer.use_span(span, end_on_exit=False):
                trace_context_headers = tracer.get_trace_context()
                request.http_request.headers.update(trace_context_headers)

            token = tracer._suppress_auto_http_instrumentation()  # pylint: disable=protected-access
            request.context[self._SUPPRESSION_TOKEN] = token
        except Exception as err:  # pylint: disable=broad-except
            _LOGGER.warning("Unable to start HTTP span: %s", err)  # pylint: disable=do-not-log-exceptions-if-not-debug

    def on_response(
        self,
        request: PipelineRequest[HttpRequest],
        response: PipelineResponse[HttpRequest, SansIOHttpResponse],
    ) -> None:
        """Ends the span for the network call and updates its status.

        :param request: The PipelineRequest object.
        :type request: ~corehttp.runtime.pipeline.PipelineRequest
        :param response: The PipelineResponse object.
        :type response: ~corehttp.runtime.pipeline.PipelineResponse
        """
        self._end_span(request, response=response.http_response)

    def on_exception(self, request: PipelineRequest[HttpRequest]) -> None:
        """Ends the span for the network call when a later policy or the transport raised.

        Without this hook a failed attempt would leave its span open and the instrumentation
        suppression token attached to the current context.

        :param request: The PipelineRequest object.
        :type request: ~corehttp.runtime.pipeline.PipelineRequest
        """
        import sys  # local import keeps the module's import block unchanged

        # NOTE(review): assumes the pipeline runner invokes on_exception from inside its
        # ``except`` block (as azure-core does), so sys.exc_info() describes the failure.
        exc_info: OptExcInfo = sys.exc_info()
        try:
            self._end_span(request, exc_info=exc_info)
        except Exception as err:  # pylint: disable=broad-except
            # Never let a tracing failure replace the exception that is being propagated.
            _LOGGER.warning("Unable to end HTTP span: %s", err)  # pylint: disable=do-not-log-exceptions-if-not-debug

    def _end_span(
        self,
        request: PipelineRequest[HttpRequest],
        response: Optional[SansIOHttpResponse] = None,
        exc_info: Optional[OptExcInfo] = None,
    ) -> None:
        """Ends the span stored in the request context (if any) and releases the suppression token.

        :param request: The PipelineRequest object.
        :type request: ~corehttp.runtime.pipeline.PipelineRequest
        :param response: The response received from the server, or None if none was received.
        :type response: ~corehttp.rest.HttpResponse
        :param exc_info: The ``sys.exc_info()`` tuple of the error that aborted the attempt, if any.
        :type exc_info: tuple
        """
        if self.TRACING_CONTEXT not in request.context:
            return

        span: Optional["Span"] = request.context[self.TRACING_CONTEXT]
        if span:
            self._set_http_client_span_attributes(span, request.http_request, response=response)
            if request.context.get("retry_count"):
                span.set_attribute(self._HTTP_RESEND_COUNT, request.context["retry_count"])
            exception = exc_info[1] if exc_info else None
            if exception is not None:
                exc_type = type(exception)
                span.set_attribute(self._ERROR_TYPE, f"{exc_type.__module__}.{exc_type.__qualname__}")
                span.record_exception(exception)
            span.end()

        # Pop the token so it is detached at most once, even if a later attempt fails before
        # storing a new one.
        suppression_token = request.context.pop(self._SUPPRESSION_TOKEN, None)
        if suppression_token:
            tracer = get_tracer()
            if tracer:
                tracer._detach_from_context(suppression_token)  # pylint: disable=protected-access

    def _set_http_client_span_attributes(
        self,
        span: "Span",
        request: HttpRequest,
        response: Optional[SansIOHttpResponse] = None,
    ) -> None:
        """Add attributes to an HTTP client span.

        :param span: The span to add attributes to.
        :type span: ~opentelemetry.trace.Span
        :param request: The request made
        :type request: ~corehttp.rest.HttpRequest
        :param response: The response received from the server. Is None if no response received.
        :type response: ~corehttp.rest.HttpResponse
        """
        attributes: Dict[str, Any] = {
            self._HTTP_REQUEST_METHOD: request.method,
            self._URL_FULL: request.url,
        }

        parsed_url = urllib.parse.urlparse(request.url)
        if parsed_url.hostname:
            attributes[self._SERVER_ADDRESS] = parsed_url.hostname
        if parsed_url.port:
            attributes[self._SERVER_PORT] = parsed_url.port

        user_agent = request.headers.get("User-Agent")
        if user_agent:
            attributes[self._USER_AGENT_ORIGINAL] = user_agent
        if response and response.status_code:
            attributes[self._HTTP_RESPONSE_STATUS_CODE] = response.status_code
            # Any 4xx/5xx status is reported as an error, using the status code as error type.
            if response.status_code >= 400:
                attributes[self._ERROR_TYPE] = str(response.status_code)

        span.set_attributes(attributes)
@@ -52,6 +52,8 @@ _LOGGER = logging.getLogger(__name__)
52
52
 
53
53
 
54
54
  class RetryMode(str, Enum, metaclass=CaseInsensitiveEnumMeta):
55
+ """Enum for retry modes."""
56
+
55
57
  # pylint: disable=enum-must-be-uppercase
56
58
  Exponential = "exponential"
57
59
  Fixed = "fixed"
@@ -104,7 +106,7 @@ class RetryPolicyBase:
104
106
  "read": options.pop("retry_read", self.read_retries),
105
107
  "status": options.pop("retry_status", self.status_retries),
106
108
  "backoff": options.pop("retry_backoff_factor", self.backoff_factor),
107
- "max_backoff": options.pop("retry_backoff_max", self.BACKOFF_MAX),
109
+ "max_backoff": options.pop("retry_backoff_max", self.backoff_max),
108
110
  "methods": options.pop("retry_on_methods", self._method_whitelist),
109
111
  "timeout": options.pop("timeout", self.timeout),
110
112
  "history": [],
@@ -394,28 +396,21 @@ class RetryPolicy(RetryPolicyBase, HTTPPolicy[HttpRequest, HttpResponse]):
394
396
 
395
397
  :keyword int retry_total: Total number of retries to allow. Takes precedence over other counts.
396
398
  Default value is 10.
397
-
398
399
  :keyword int retry_connect: How many connection-related errors to retry on.
399
400
  These are errors raised before the request is sent to the remote server,
400
401
  which we assume has not triggered the server to process the request. Default value is 3.
401
-
402
402
  :keyword int retry_read: How many times to retry on read errors.
403
403
  These errors are raised after the request was sent to the server, so the
404
404
  request may have side-effects. Default value is 3.
405
-
406
405
  :keyword int retry_status: How many times to retry on bad status codes. Default value is 3.
407
-
408
406
  :keyword float retry_backoff_factor: A backoff factor to apply between attempts after the second try
409
407
  (most errors are resolved immediately by a second try without a delay).
410
408
  In fixed mode, retry policy will always sleep for {backoff factor}.
411
409
  In 'exponential' mode, retry policy will sleep for: `{backoff factor} * (2 ** ({number of total retries} - 1))`
412
410
  seconds. If the backoff_factor is 0.1, then the retry will sleep
413
411
  for [0.0s, 0.2s, 0.4s, ...] between retries. The default value is 0.8.
414
-
415
412
  :keyword int retry_backoff_max: The maximum back off time. Default value is 120 seconds (2 minutes).
416
-
417
413
  :keyword RetryMode retry_mode: Fixed or exponential delay between attemps, default is exponential.
418
-
419
414
  :keyword int timeout: Timeout setting for the operation in seconds, default is 604800s (7 days).
420
415
  """
421
416
 
@@ -481,10 +476,10 @@ class RetryPolicy(RetryPolicyBase, HTTPPolicy[HttpRequest, HttpResponse]):
481
476
 
482
477
  :param request: The PipelineRequest object
483
478
  :type request: ~corehttp.runtime.pipeline.PipelineRequest
484
- :return: Returns the PipelineResponse or raises error if maximum retries exceeded.
479
+ :return: The PipelineResponse.
485
480
  :rtype: ~corehttp.runtime.pipeline.PipelineResponse
486
- :raises: ~corehttp.exceptions.BaseError if maximum retries exceeded.
487
- :raises: ~corehttp.exceptions.ClientAuthenticationError if authentication
481
+ :raises ~corehttp.exceptions.BaseError: if maximum retries exceeded.
482
+ :raises ~corehttp.exceptions.ClientAuthenticationError: if authentication fails.
488
483
  """
489
484
  retry_active = True
490
485
  response = None
@@ -505,6 +500,7 @@ class RetryPolicy(RetryPolicyBase, HTTPPolicy[HttpRequest, HttpResponse]):
505
500
  )
506
501
  try:
507
502
  self._configure_timeout(request, absolute_timeout, is_response_error)
503
+ request.context["retry_count"] = len(retry_settings["history"])
508
504
  response = self.next.send(request)
509
505
  if self.is_retry(retry_settings, response):
510
506
  retry_active = self.increment(retry_settings, response=response)
@@ -54,23 +54,18 @@ class AsyncRetryPolicy(RetryPolicyBase, AsyncHTTPPolicy[HttpRequest, AsyncHttpRe
54
54
 
55
55
  :keyword int retry_total: Total number of retries to allow. Takes precedence over other counts.
56
56
  Default value is 10.
57
-
58
57
  :keyword int retry_connect: How many connection-related errors to retry on.
59
58
  These are errors raised before the request is sent to the remote server,
60
59
  which we assume has not triggered the server to process the request. Default value is 3.
61
-
62
60
  :keyword int retry_read: How many times to retry on read errors.
63
61
  These errors are raised after the request was sent to the server, so the
64
62
  request may have side-effects. Default value is 3.
65
-
66
63
  :keyword int retry_status: How many times to retry on bad status codes. Default value is 3.
67
-
68
64
  :keyword float retry_backoff_factor: A backoff factor to apply between attempts after the second try
69
65
  (most errors are resolved immediately by a second try without a delay).
70
66
  Retry policy will sleep for: `{backoff factor} * (2 ** ({number of total retries} - 1))`
71
67
  seconds. If the backoff_factor is 0.1, then the retry will sleep
72
68
  for [0.0s, 0.2s, 0.4s, ...] between retries. The default value is 0.8.
73
-
74
69
  :keyword int retry_backoff_max: The maximum back off time. Default value is 120 seconds (2 minutes).
75
70
  """
76
71
 
@@ -138,10 +133,10 @@ class AsyncRetryPolicy(RetryPolicyBase, AsyncHTTPPolicy[HttpRequest, AsyncHttpRe
138
133
 
139
134
  :param request: The PipelineRequest object
140
135
  :type request: ~corehttp.runtime.pipeline.PipelineRequest
141
- :return: Returns the PipelineResponse or raises error if maximum retries exceeded.
136
+ :return: The PipelineResponse.
142
137
  :rtype: ~corehttp.runtime.pipeline.PipelineResponse
143
- :raise: ~corehttp.exceptions.BaseError if maximum retries exceeded.
144
- :raise: ~corehttp.exceptions.ClientAuthenticationError if authentication fails
138
+ :raises ~corehttp.exceptions.BaseError: if maximum retries exceeded.
139
+ :raises ~corehttp.exceptions.ClientAuthenticationError: if authentication fails.
145
140
  """
146
141
  retry_active = True
147
142
  response = None
@@ -162,6 +157,7 @@ class AsyncRetryPolicy(RetryPolicyBase, AsyncHTTPPolicy[HttpRequest, AsyncHttpRe
162
157
  )
163
158
  try:
164
159
  self._configure_timeout(request, absolute_timeout, is_response_error)
160
+ request.context["retry_count"] = len(retry_settings["history"])
165
161
  response = await self.next.send(request)
166
162
  if self.is_retry(retry_settings, response):
167
163
  retry_active = self.increment(retry_settings, response=response)
@@ -143,6 +143,7 @@ class UserAgentPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
143
143
 
144
144
  def add_user_agent(self, value: str) -> None:
145
145
  """Add value to current user agent with a space.
146
+
146
147
  :param str value: value to add to user agent.
147
148
  """
148
149
  self._user_agent = "{} {}".format(self._user_agent, value)
@@ -401,6 +402,11 @@ class ContentDecodePolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
401
402
  return cls.deserialize_from_text(response.text(encoding), mime_type, response=response)
402
403
 
403
404
  def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
405
+ """Set the response encoding in the request context.
406
+
407
+ :param request: The PipelineRequest object.
408
+ :type request: ~corehttp.runtime.pipeline.PipelineRequest
409
+ """
404
410
  options = request.context.options
405
411
  response_encoding = options.pop("response_encoding", self._response_encoding)
406
412
  if response_encoding:
@@ -452,6 +458,11 @@ class ProxyPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
452
458
  self.proxies = proxies
453
459
 
454
460
  def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
461
+ """Adds the proxy information to the request context.
462
+
463
+ :param request: The PipelineRequest object.
464
+ :type request: ~corehttp.runtime.pipeline.PipelineRequest
465
+ """
455
466
  ctxt = request.context.options
456
467
  if self.proxies and "proxies" not in ctxt:
457
468
  ctxt["proxies"] = self.proxies
corehttp/serialization.py CHANGED
@@ -4,14 +4,23 @@
4
4
  # Licensed under the MIT License. See License.txt in the project root for
5
5
  # license information.
6
6
  # --------------------------------------------------------------------------
7
+ # pylint: disable=protected-access
7
8
  import base64
9
+ from functools import partial
8
10
  from json import JSONEncoder
9
- from typing import Union, cast, Any
11
+ from typing import Dict, List, Optional, Union, cast, Any, Type, Callable, Tuple
10
12
  from datetime import datetime, date, time, timedelta
11
13
  from datetime import timezone
12
14
 
13
15
 
14
- __all__ = ["NULL", "CoreJSONEncoder"]
16
+ __all__ = [
17
+ "NULL",
18
+ "CoreJSONEncoder",
19
+ "is_generated_model",
20
+ "attribute_list",
21
+ "TypeHandlerRegistry",
22
+ ]
23
+ TZ_UTC = timezone.utc
15
24
 
16
25
 
17
26
  class _Null:
@@ -111,10 +120,176 @@ def _datetime_as_isostr(dt: Union[datetime, date, time, timedelta]) -> str:
111
120
  return _timedelta_as_isostr(dt)
112
121
 
113
122
 
123
class TypeHandlerRegistry:
    """A registry for custom serializers and deserializers for specific types or conditions.

    Handlers registered for a type are matched by *exact* type only (subclasses are not
    matched). If no exact match exists, the registered predicates are tried in registration
    order. Lookup results are cached per type, and registering a new handler clears the
    corresponding cache.
    """

    def __init__(self) -> None:
        self._serializer_types: Dict[Type, Callable] = {}
        self._deserializer_types: Dict[Type, Callable] = {}
        self._serializer_predicates: List[Tuple[Callable[[Any], bool], Callable]] = []
        self._deserializer_predicates: List[Tuple[Callable[[Any], bool], Callable]] = []

        self._serializer_cache: Dict[Type, Optional[Callable]] = {}
        self._deserializer_cache: Dict[Type, Optional[Callable]] = {}

    def register_serializer(
        self, condition: Union[Type, Callable[[Any], bool]]
    ) -> Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]]:
        """Decorator to register a serializer.

        The handler function is expected to take a single argument, the object to serialize,
        and return a dictionary representation of that object.

        Examples:

        .. code-block:: python

            @registry.register_serializer(CustomModel)
            def serialize_single_type(value: CustomModel) -> dict:
                return value.to_dict()

            @registry.register_serializer(lambda x: isinstance(x, BaseModel))
            def serialize_with_condition(value: BaseModel) -> dict:
                return value.to_dict()

            # Called manually for a specific type
            def custom_serializer(value: CustomModel) -> Dict[str, Any]:
                return {"custom": value.custom}

            registry.register_serializer(CustomModel)(custom_serializer)

        :param condition: A type (matched exactly against ``type(obj)``) or a callable predicate
            that takes the object being serialized and returns a bool.
        :type condition: Union[Type, Callable[[Any], bool]]
        :return: A decorator that registers the handler function.
        :rtype: Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]]
        :raises TypeError: If the condition is neither a type nor a callable.
        """

        def decorator(handler_func: Callable[[Any], Dict[str, Any]]) -> Callable[[Any], Dict[str, Any]]:
            if isinstance(condition, type):
                self._serializer_types[condition] = handler_func
            elif callable(condition):
                self._serializer_predicates.append((condition, handler_func))
            else:
                raise TypeError("Condition must be a type or a callable predicate function.")

            # Cached lookups (including cached "no handler" results) may now be stale.
            self._serializer_cache.clear()
            return handler_func

        return decorator

    def register_deserializer(
        self, condition: Union[Type, Callable[[Any], bool]]
    ) -> Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]]:
        """Decorator to register a deserializer.

        The handler function is expected to take two arguments: the target type and the data dictionary,
        and return an instance of the target type.

        Examples:

        .. code-block:: python

            @registry.register_deserializer(CustomModel)
            def deserialize_single_type(cls: Type[CustomModel], data: dict) -> CustomModel:
                return cls(**data)

            # The predicate receives the target *class*; guard issubclass for non-class inputs.
            @registry.register_deserializer(lambda t: isinstance(t, type) and issubclass(t, BaseModel))
            def deserialize_with_condition(cls: Type[BaseModel], data: dict) -> BaseModel:
                return cls(**data)

            # Called manually for a specific type
            def custom_deserializer(cls: Type[CustomModel], data: Dict[str, Any]) -> CustomModel:
                return cls(custom=data["custom"])

            registry.register_deserializer(CustomModel)(custom_deserializer)

        :param condition: A type (matched exactly against the target class) or a callable predicate
            that takes the target class and returns a bool.
        :type condition: Union[Type, Callable[[Any], bool]]
        :return: A decorator that registers the handler function.
        :rtype: Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]]
        :raises TypeError: If the condition is neither a type nor a callable.
        """

        def decorator(handler_func: Callable[[Type, Dict[str, Any]], Any]) -> Callable[[Type, Dict[str, Any]], Any]:
            if isinstance(condition, type):
                self._deserializer_types[condition] = handler_func
            elif callable(condition):
                self._deserializer_predicates.append((condition, handler_func))
            else:
                raise TypeError("Condition must be a type or a callable predicate function.")

            # Cached lookups (including cached "no handler" results) may now be stale.
            self._deserializer_cache.clear()
            return handler_func

        return decorator

    def get_serializer(self, obj: Any) -> Optional[Callable[[Any], Dict[str, Any]]]:
        """Gets the appropriate serializer for an object.

        It first checks the type dictionary for a direct type match.
        If no match is found, it iterates through the predicate list to find a match.

        Results of the lookup are cached for performance based on the object's type. As a
        consequence, a predicate is evaluated on the first object of a given type only, and
        its result is reused for later objects of the same type.

        :param obj: The object to serialize.
        :type obj: any
        :return: The serializer function if found, otherwise None.
        :rtype: Optional[Callable[[Any], Dict[str, Any]]]
        """
        obj_type = type(obj)
        if obj_type in self._serializer_cache:
            return self._serializer_cache[obj_type]

        handler = self._serializer_types.get(obj_type)
        if not handler:
            for predicate, pred_handler in self._serializer_predicates:
                if predicate(obj):
                    handler = pred_handler
                    break

        self._serializer_cache[obj_type] = handler
        return handler

    def get_deserializer(self, cls: Type) -> Optional[Callable[[Dict[str, Any]], Any]]:
        """Gets the appropriate deserializer for a class.

        It first checks the type dictionary for a direct type match.
        If no match is found, it iterates through the predicate list (called with ``cls``) to
        find a match.

        Results of the lookup are cached for performance based on the class.

        :param cls: The class to deserialize.
        :type cls: type
        :return: A deserializer function bound to the specified class that takes a dictionary and returns
            an instance of that class, or None if no deserializer is found.
        :rtype: Optional[Callable[[Dict[str, Any]], Any]]
        """
        if cls in self._deserializer_cache:
            return self._deserializer_cache[cls]

        handler = self._deserializer_types.get(cls)
        if not handler:
            for predicate, pred_handler in self._deserializer_predicates:
                if predicate(cls):
                    handler = pred_handler
                    break

        # Bind the target class so callers only need to pass the data dictionary.
        deserializer = partial(handler, cls) if handler else None
        self._deserializer_cache[cls] = deserializer
        return deserializer
280
+
281
+
114
282
  class CoreJSONEncoder(JSONEncoder):
115
283
  """A JSON encoder that's capable of serializing datetime objects and bytes."""
116
284
 
117
285
  def default(self, o: Any) -> Any:
286
+ """Override the default method to handle datetime and bytes serialization.
287
+
288
+ :param o: The object to serialize.
289
+ :type o: Any
290
+ :return: A JSON-serializable representation of the object.
291
+ :rtype: Any
292
+ """
118
293
  if isinstance(o, (bytes, bytearray)):
119
294
  return base64.b64encode(o).decode()
120
295
  try:
@@ -122,3 +297,62 @@ class CoreJSONEncoder(JSONEncoder):
122
297
  except AttributeError:
123
298
  pass
124
299
  return super(CoreJSONEncoder, self).default(o)
300
+
301
+
302
def is_generated_model(obj: Any) -> bool:
    """Tell whether ``obj`` looks like a model produced by the SDK code generators.

    Two generator flavors are recognized: models flagged with a truthy ``_is_model``
    attribute, and models that carry an ``_attribute_map``.

    :param obj: The object to check.
    :type obj: any
    :return: True if the object is a generated SDK model, False otherwise.
    :rtype: bool
    """
    if getattr(obj, "_is_model", False):
        return True
    return hasattr(obj, "_attribute_map")
311
+
312
+
313
+ def _get_flattened_attribute(obj: Any) -> Optional[str]:
314
+ """Get the name of the flattened attribute in a generated TypeSpec model if one exists.
315
+
316
+ :param any obj: The object to check.
317
+ :return: The name of the flattened attribute if it exists, otherwise None.
318
+ :rtype: Optional[str]
319
+ """
320
+ flattened_items = None
321
+ try:
322
+ flattened_items = getattr(obj, next(a for a in dir(obj) if "__flattened_items" in a), None)
323
+ except StopIteration:
324
+ return None
325
+
326
+ if flattened_items is None:
327
+ return None
328
+
329
+ for k, v in obj._attr_to_rest_field.items():
330
+ try:
331
+ if set(v._class_type._attr_to_rest_field.keys()).intersection(set(flattened_items)):
332
+ return k
333
+ except AttributeError:
334
+ # if the attribute does not have _class_type, it is not a typespec generated model
335
+ continue
336
+ return None
337
+
338
+
339
def attribute_list(obj: Any) -> List[str]:
    """Collect the attribute names of a generated SDK model.

    Models with an ``_attribute_map`` report its keys directly. For TypeSpec models, the
    flattened attribute (if any) is replaced by the attribute names of its nested model,
    collected recursively.

    :param obj: The object to get attributes from.
    :type obj: any
    :return: A list of attribute names.
    :rtype: List[str]
    :raises TypeError: If ``obj`` is not a generated SDK model.
    """
    if not is_generated_model(obj):
        raise TypeError("Object is not a generated SDK model.")
    if hasattr(obj, "_attribute_map"):
        return list(obj._attribute_map.keys())

    flattened = _get_flattened_attribute(obj)
    names: List[str] = []
    for name, field in obj._attr_to_rest_field.items():
        if name != flattened:
            names.append(name)
            continue
        # Expand the flattened property into its nested model's attributes.
        names.extend(attribute_list(field._class_type))
    return names
corehttp/settings.py ADDED
@@ -0,0 +1,59 @@
1
+ # ------------------------------------
2
+ # Copyright (c) Microsoft Corporation.
3
+ # Licensed under the MIT License.
4
+ # ------------------------------------
5
+ import os
6
+ from typing import Union
7
+
8
+
9
+ def _convert_bool(value: Union[str, bool]) -> bool:
10
+ """Convert a string to True or False
11
+
12
+ If a boolean is passed in, it is returned as-is. Otherwise the function
13
+ maps the following strings, ignoring case:
14
+
15
+ * "yes", "1", "on" -> True
16
+ * "no", "0", "off" -> False
17
+
18
+ :param value: the value to convert
19
+ :type value: str or bool
20
+ :returns: A boolean value matching the intent of the input
21
+ :rtype: bool
22
+ :raises ValueError: If conversion to bool fails
23
+
24
+ """
25
+ if isinstance(value, bool):
26
+ return value
27
+ val = value.lower()
28
+ if val in ["yes", "1", "on", "true", "True"]:
29
+ return True
30
+ if val in ["no", "0", "off", "false", "False"]:
31
+ return False
32
+ raise ValueError("Cannot convert {} to boolean value".format(value))
33
+
34
+
35
class Settings:
    """Global settings for the SDK.

    The initial tracing flag is taken from the ``SDK_TRACING_ENABLED`` environment variable
    and is off when that variable is not set.
    """

    def __init__(self) -> None:
        raw_value = os.environ.get("SDK_TRACING_ENABLED", False)
        self._tracing_enabled: bool = _convert_bool(raw_value)

    @property
    def tracing_enabled(self) -> bool:
        """Whether tracing for SDKs is enabled.

        :return: True if tracing is enabled, False otherwise.
        :rtype: bool
        """
        return self._tracing_enabled

    @tracing_enabled.setter
    def tracing_enabled(self, value: bool):
        # Normalized through _convert_bool, so strings such as "on"/"no" are accepted as well.
        converted = _convert_bool(value)
        self._tracing_enabled = converted


settings: Settings = Settings()
"""The settings global instance.

:type settings: Settings
"""
@@ -80,10 +80,8 @@ def _handle_non_stream_rest_response(response: HttpResponse) -> None:
80
80
  """
81
81
  try:
82
82
  response.read()
83
+ finally:
83
84
  response.close()
84
- except Exception as exc:
85
- response.close()
86
- raise exc
87
85
 
88
86
 
89
87
  class HttpTransport(ContextManager["HttpTransport"], abc.ABC, Generic[HTTPRequestType, HTTPResponseType]):
@@ -46,10 +46,8 @@ async def _handle_non_stream_rest_response(response: AsyncHttpResponse) -> None:
46
46
  """
47
47
  try:
48
48
  await response.read()
49
+ finally:
49
50
  await response.close()
50
- except Exception as exc:
51
- await response.close()
52
- raise exc
53
51
 
54
52
 
55
53
  class _ResponseStopIteration(Exception):