@autorest/python 6.7.0 → 6.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,6 +28,7 @@ from .primitive_types import (
28
28
  BooleanType,
29
29
  AnyObjectType,
30
30
  UnixTimeType,
31
+ AzureCoreType,
31
32
  )
32
33
  from .enum_type import EnumType
33
34
  from .base import BaseType
@@ -144,6 +145,7 @@ TYPE_TO_OBJECT = {
144
145
  "any-object": AnyObjectType,
145
146
  "unixtime": UnixTimeType,
146
147
  "credential": StringType,
148
+ "azurecore": AzureCoreType,
147
149
  }
148
150
  _LOGGER = logging.getLogger(__name__)
149
151
 
@@ -76,6 +76,8 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
76
76
  for op_group in self.yaml_data.get("operationGroups", [])
77
77
  ]
78
78
  self.link_lro_initial_operations()
79
+ self.request_id_header_name = self.yaml_data.get("requestIdHeaderName", None)
80
+ self.has_etag: bool = yaml_data.get("hasEtag", False)
79
81
 
80
82
  def _build_request_builders(
81
83
  self,
@@ -15,10 +15,10 @@ from .constant_type import ConstantType
15
15
 
16
16
 
17
17
  def _is_legacy(options) -> bool:
18
- return not (options["version_tolerant"] or options["low_level_client"])
18
+ return not (options.get("version_tolerant") or options.get("low_level_client"))
19
19
 
20
20
 
21
- class CodeModel: # pylint: disable=too-many-public-methods
21
+ class CodeModel: # pylint: disable=too-many-public-methods, disable=too-many-instance-attributes
22
22
  """Top level code model
23
23
 
24
24
  :param options: Options of the code model. I.e., whether this is for management generation
@@ -77,6 +77,10 @@ class CodeModel: # pylint: disable=too-many-public-methods
77
77
  t for t in self.types_map.values() if isinstance(t, CombinedType) and t.name
78
78
  ]
79
79
 
80
+ @property
81
+ def has_etag(self) -> bool:
82
+ return any(client.has_etag for client in self.clients)
83
+
80
84
  @property
81
85
  def has_operations(self) -> bool:
82
86
  if any(c for c in self.clients if c.has_operations):
@@ -30,6 +30,10 @@ class DictionaryType(BaseType):
30
30
  super().__init__(yaml_data=yaml_data, code_model=code_model)
31
31
  self.element_type = element_type
32
32
 
33
+ @property
34
+ def format(self) -> Optional[str]:
35
+ return self.element_type.format if hasattr(self.element_type, "format") else None # type: ignore
36
+
33
37
  @property
34
38
  def serialization_type(self) -> str:
35
39
  """Returns the serialization value for msrest.
@@ -25,6 +25,10 @@ class ListType(BaseType):
25
25
  self.min_items: Optional[int] = yaml_data.get("minItems")
26
26
  self.unique_items: bool = yaml_data.get("uniqueItems", False)
27
27
 
28
+ @property
29
+ def format(self) -> Optional[str]:
30
+ return self.element_type.format if hasattr(self.element_type, "format") else None # type: ignore
31
+
28
32
  @property
29
33
  def serialization_type(self) -> str:
30
34
  return f"[{self.element_type.serialization_type}]"
@@ -90,6 +90,7 @@ class OperationBase( # pylint: disable=too-many-public-methods
90
90
  self.internal: bool = self.yaml_data.get("internal", False)
91
91
  if self.internal:
92
92
  self.name = "_" + self.name
93
+ self.has_etag: bool = self.yaml_data.get("hasEtag", False)
93
94
 
94
95
  @property
95
96
  def expose_stream_keyword(self) -> bool:
@@ -331,7 +332,7 @@ class OperationBase( # pylint: disable=too-many-public-methods
331
332
  )
332
333
  return file_import
333
334
 
334
- def imports( # pylint: disable=too-many-branches
335
+ def imports( # pylint: disable=too-many-branches, disable=too-many-statements
335
336
  self, async_mode: bool, **kwargs: Any
336
337
  ) -> FileImport:
337
338
  if self.abstract:
@@ -394,11 +395,22 @@ class OperationBase( # pylint: disable=too-many-public-methods
394
395
  if self.deprecated:
395
396
  file_import.add_import("warnings", ImportType.STDLIB)
396
397
 
398
+ relative_path = "..." if async_mode else ".."
397
399
  if self.code_model.need_request_converter:
398
- relative_path = "..." if async_mode else ".."
399
400
  file_import.add_submodule_import(
400
401
  f"{relative_path}_vendor", "_convert_request", ImportType.LOCAL
401
402
  )
403
+ if self.has_etag:
404
+ file_import.add_submodule_import(
405
+ "azure.core.exceptions", "ResourceModifiedError", ImportType.AZURECORE
406
+ )
407
+ if not async_mode:
408
+ file_import.add_submodule_import(
409
+ f"{relative_path}_vendor", "prep_if_match", ImportType.LOCAL
410
+ )
411
+ file_import.add_submodule_import(
412
+ f"{relative_path}_vendor", "prep_if_none_match", ImportType.LOCAL
413
+ )
402
414
  if self.code_model.need_request_converter:
403
415
  if async_mode:
404
416
  file_import.add_submodule_import(
@@ -52,9 +52,6 @@ class ParameterDelimeter(str, Enum):
52
52
  COMMA = "comma"
53
53
 
54
54
 
55
- SPECIAL_HANDLE_HEADERS = ["repeatability-request-id", "repeatability-first-sent"]
56
-
57
-
58
55
  class _ParameterBase(
59
56
  BaseModel, abc.ABC
60
57
  ): # pylint: disable=too-many-instance-attributes
@@ -93,10 +90,7 @@ class _ParameterBase(
93
90
  self.default_to_unset_sentinel: bool = self.yaml_data.get(
94
91
  "defaultToUnsetSentinel", False
95
92
  )
96
- self.is_special_handle_header: bool = (
97
- self.location == ParameterLocation.HEADER
98
- and self.wire_name.lower() in SPECIAL_HANDLE_HEADERS
99
- )
93
+ self.hide_in_method: bool = self.yaml_data.get("hideInMethod", False)
100
94
 
101
95
  @property
102
96
  def constant(self) -> bool:
@@ -215,7 +215,7 @@ class _ParameterListBase(
215
215
  for p in self.parameters
216
216
  if p.in_method_signature
217
217
  and p.implementation == self.implementation
218
- and not p.is_special_handle_header
218
+ and (self.code_model.is_legacy or not p.hide_in_method)
219
219
  ]
220
220
  if self._body_parameter:
221
221
  if self._body_parameter.in_method_signature:
@@ -4,7 +4,6 @@
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
6
  import datetime
7
- from enum import Enum
8
7
  from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
9
8
 
10
9
  from .base import BaseType
@@ -356,17 +355,17 @@ class StringType(PrimitiveType):
356
355
  class DatetimeType(PrimitiveType):
357
356
  def __init__(self, yaml_data: Dict[str, Any], code_model: "CodeModel") -> None:
358
357
  super().__init__(yaml_data=yaml_data, code_model=code_model)
359
- self.format = self.Formats(yaml_data.get("format", "date-time"))
360
-
361
- class Formats(str, Enum):
362
- datetime = "date-time"
363
- rfc1123 = "date-time-rfc1123"
358
+ self.format = (
359
+ "rfc3339"
360
+ if yaml_data.get("format", "date-time") == "date-time"
361
+ else "rfc7231"
362
+ )
364
363
 
365
364
  @property
366
365
  def serialization_type(self) -> str:
367
366
  formats_to_attribute_type = {
368
- self.Formats.datetime: "iso-8601",
369
- self.Formats.rfc1123: "rfc-1123",
367
+ "rfc3339": "iso-8601",
368
+ "rfc7231": "rfc-1123",
370
369
  }
371
370
  return formats_to_attribute_type[self.format]
372
371
 
@@ -454,6 +453,10 @@ class TimeType(PrimitiveType):
454
453
 
455
454
 
456
455
  class UnixTimeType(PrimitiveType):
456
+ @property
457
+ def format(self) -> str:
458
+ return "unix-timestamp"
459
+
457
460
  @property
458
461
  def serialization_type(self) -> str:
459
462
  return "unix-time"
@@ -588,7 +591,7 @@ class DurationType(PrimitiveType):
588
591
  class ByteArraySchema(PrimitiveType):
589
592
  def __init__(self, yaml_data: Dict[str, Any], code_model: "CodeModel") -> None:
590
593
  super().__init__(yaml_data=yaml_data, code_model=code_model)
591
- self.format = yaml_data.get("format", "bytes")
594
+ self.format = yaml_data.get("format", "base64")
592
595
 
593
596
  @property
594
597
  def serialization_type(self) -> str:
@@ -605,3 +608,28 @@ class ByteArraySchema(PrimitiveType):
605
608
  @property
606
609
  def instance_check_template(self) -> str:
607
610
  return "isinstance({}, bytes)"
611
+
612
+
613
+ class AzureCoreType(PrimitiveType):
614
+ def __init__(self, yaml_data: Dict[str, Any], code_model: "CodeModel") -> None:
615
+ super().__init__(yaml_data=yaml_data, code_model=code_model)
616
+ self.name = yaml_data.get("name", "")
617
+
618
+ def docstring_type(self, **kwargs: Any) -> str:
619
+ return "~azure.core." + self.type_annotation(**kwargs)
620
+
621
+ def type_annotation(self, **kwargs: Any) -> str:
622
+ return self.name
623
+
624
+ def imports(self, **kwargs: Any) -> FileImport:
625
+ file_import = FileImport()
626
+ file_import.add_submodule_import("azure.core", self.name, ImportType.AZURECORE)
627
+ return file_import
628
+
629
+ @property
630
+ def instance_check_template(self) -> str:
631
+ return f"isinstance({{}}, {self.name})"
632
+
633
+ @property
634
+ def serialization_type(self) -> str:
635
+ return self.name
@@ -52,6 +52,7 @@ _PACKAGE_FILES = [
52
52
 
53
53
  _REGENERATE_FILES = {"setup.py", "MANIFEST.in"}
54
54
 
55
+
55
56
  # extract sub folders. For example, source_file_path is like:
56
57
  # "xxx/resource-manager/Microsoft.XX/stable/2023-04-01/examples/Compute/createOrUpdate/AKSCompute.json",
57
58
  # and we want to extract the sub folders after "examples/", which is "compute/create_or_update"
@@ -18,7 +18,6 @@ from ..models import (
18
18
  ModelType,
19
19
  DictionaryType,
20
20
  ListType,
21
- Parameter,
22
21
  RequestBuilder,
23
22
  ParameterLocation,
24
23
  Response,
@@ -422,36 +421,6 @@ class _BuilderBaseSerializer(Generic[BuilderType]): # pylint: disable=abstract-
422
421
  )
423
422
  return template
424
423
 
425
- def _serialize_parameter(self, param: Parameter, kwarg_name: str) -> List[str]:
426
- set_parameter = "_{}['{}'] = {}".format(
427
- kwarg_name,
428
- param.wire_name,
429
- self.parameter_serializer.serialize_parameter(param, self.serializer_name),
430
- )
431
- if not param.optional:
432
- retval = [set_parameter]
433
- else:
434
- retval = [
435
- f"if {param.full_client_name} is not None:",
436
- f" {set_parameter}",
437
- ]
438
- return retval
439
-
440
- @staticmethod
441
- def _serialize_special_handle_header(param: Parameter) -> List[str]:
442
- if param.wire_name.lower() == "repeatability-request-id":
443
- return [
444
- """if "Repeatability-Request-ID" not in _headers:""",
445
- """ _headers["Repeatability-Request-ID"] = str(uuid.uuid4())""",
446
- ]
447
- if param.wire_name.lower() == "repeatability-first-sent":
448
- return [
449
- """if "Repeatability-First-Sent" not in _headers:""",
450
- """ _headers["Repeatability-First-Sent"] = _SERIALIZER.serialize_data(datetime.datetime.now(),
451
- "rfc-1123")""",
452
- ]
453
- raise ValueError(f"Unsupported special header: {param}")
454
-
455
424
  def serialize_path(self, builder: BuilderType) -> List[str]:
456
425
  return self.parameter_serializer.serialize_path(
457
426
  builder.parameters.path, self.serializer_name
@@ -563,24 +532,25 @@ class RequestBuilderSerializer(
563
532
  def serialize_headers(self, builder: RequestBuilderType) -> List[str]:
564
533
  retval = ["# Construct headers"]
565
534
  for parameter in builder.parameters.headers:
566
- if parameter.is_special_handle_header:
567
- retval.extend(self._serialize_special_handle_header(parameter))
568
- else:
569
- retval.extend(
570
- self._serialize_parameter(
571
- parameter,
572
- kwarg_name="headers",
573
- )
535
+ retval.extend(
536
+ self.parameter_serializer.serialize_query_header(
537
+ parameter,
538
+ "headers",
539
+ self.serializer_name,
540
+ self.code_model.is_legacy,
574
541
  )
542
+ )
575
543
  return retval
576
544
 
577
545
  def serialize_query(self, builder: RequestBuilderType) -> List[str]:
578
546
  retval = ["# Construct parameters"]
579
547
  for parameter in builder.parameters.query:
580
548
  retval.extend(
581
- self._serialize_parameter(
549
+ self.parameter_serializer.serialize_query_header(
582
550
  parameter,
583
- kwarg_name="params",
551
+ "params",
552
+ self.serializer_name,
553
+ self.code_model.is_legacy,
584
554
  )
585
555
  )
586
556
  return retval
@@ -1119,9 +1089,17 @@ class _OperationSerializer(
1119
1089
  return retval
1120
1090
 
1121
1091
  def handle_error_response(self, builder: OperationType) -> List[str]:
1092
+ async_await = "await " if self.async_mode else ""
1122
1093
  retval = [
1123
1094
  f"if response.status_code not in {str(builder.success_status_codes)}:"
1124
1095
  ]
1096
+ if not self.code_model.need_request_converter:
1097
+ retval.extend(
1098
+ [
1099
+ " if _stream:",
1100
+ f" {async_await} response.read() # Load the body in memory and close the socket",
1101
+ ]
1102
+ )
1125
1103
  retval.append(
1126
1104
  " map_error(status_code=response.status_code, response=response, error_map=error_map)"
1127
1105
  )
@@ -1283,6 +1261,17 @@ class _OperationSerializer(
1283
1261
  "304: ResourceNotModifiedError"
1284
1262
  )
1285
1263
  retval.append("}")
1264
+ if builder.has_etag:
1265
+ retval.extend(
1266
+ [
1267
+ "if match_condition == MatchConditions.IfNotModified:",
1268
+ " error_map[412] = ResourceModifiedError",
1269
+ "elif match_condition == MatchConditions.IfPresent:",
1270
+ " error_map[412] = ResourceNotFoundError",
1271
+ "elif match_condition == MatchConditions.IfMissing:",
1272
+ " error_map[412] = ResourceExistsError",
1273
+ ]
1274
+ )
1286
1275
  retval.append("error_map.update(kwargs.pop('error_map', {}) or {})")
1287
1276
  return retval
1288
1277
 
@@ -109,9 +109,15 @@ class ClientSerializer:
109
109
 
110
110
  def initialize_pipeline_client(self, async_mode: bool) -> str:
111
111
  pipeline_client_name = self.client.pipeline_class(async_mode)
112
+ params = {
113
+ "base_url": self.host_variable_name,
114
+ "config": "self._config",
115
+ }
116
+ if not self.client.code_model.is_legacy and self.client.request_id_header_name:
117
+ params["request_id_header_name"] = f'"{self.client.request_id_header_name}"'
112
118
  return (
113
- f"self._client: {pipeline_client_name} = {pipeline_client_name}(base_url={self.host_variable_name}, "
114
- "config=self._config, **kwargs)"
119
+ f"self._client: {pipeline_client_name} = {pipeline_client_name}("
120
+ f"{', '.join(f'{k}={v}' for k, v in params.items())}, **kwargs)"
115
121
  )
116
122
 
117
123
  def serializers_and_operation_groups_properties(self) -> List[str]:
@@ -133,6 +133,13 @@ class GeneralSerializer:
133
133
  f"{client.name}Configuration",
134
134
  ImportType.LOCAL,
135
135
  )
136
+ if self.code_model.has_etag:
137
+ file_import.add_submodule_import("typing", "Optional", ImportType.STDLIB)
138
+ file_import.add_submodule_import(
139
+ "azure.core",
140
+ "MatchConditions",
141
+ ImportType.AZURECORE,
142
+ )
136
143
 
137
144
  return template.render(
138
145
  code_model=self.code_model,
@@ -255,13 +255,13 @@ class DpgModelSerializer(_ModelSerializer):
255
255
  args = []
256
256
  if prop.client_name != prop.wire_name or prop.is_discriminator:
257
257
  args.append(f'name="{prop.wire_name}"')
258
- if prop.readonly:
259
- args.append("readonly=True")
260
258
  if prop.visibility:
261
259
  v_list = ", ".join(f'"{x}"' for x in prop.visibility)
262
260
  args.append(f"visibility=[{v_list}]")
263
261
  if prop.client_default_value is not None:
264
262
  args.append(f"default={prop.client_default_value_declaration}")
263
+ if hasattr(prop.type, "format") and prop.type.format: # type: ignore
264
+ args.append(f'format="{prop.type.format}"') # type: ignore
265
265
 
266
266
  field = "rest_discriminator" if prop.is_discriminator else "rest_field"
267
267
  type_ignore = (
@@ -3,10 +3,9 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
- from typing import List, Sequence, Union, Optional
6
+ from typing import List, Sequence, Union, Optional, Dict
7
7
  from enum import Enum, auto
8
8
 
9
-
10
9
  from ..models import (
11
10
  Parameter,
12
11
  ParameterLocation,
@@ -26,6 +25,32 @@ class PopKwargType(Enum):
26
25
  CASE_INSENSITIVE = auto()
27
26
 
28
27
 
28
+ SPECIAL_HEADER_SERIALIZATION: Dict[str, List[str]] = {
29
+ "repeatability-request-id": [
30
+ """if "Repeatability-Request-ID" not in _headers:""",
31
+ """ _headers["Repeatability-Request-ID"] = str(uuid.uuid4())""",
32
+ ],
33
+ "repeatability-first-sent": [
34
+ """if "Repeatability-First-Sent" not in _headers:""",
35
+ """ _headers["Repeatability-First-Sent"] = _SERIALIZER.serialize_data(datetime.datetime.now(),
36
+ "rfc-1123")""",
37
+ ],
38
+ "client-request-id": [],
39
+ "x-ms-client-request-id": [],
40
+ "return-client-request-id": [],
41
+ "etag": [
42
+ """if_match = prep_if_match(etag, match_condition)""",
43
+ """if if_match is not None:""",
44
+ """ _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str")""",
45
+ ],
46
+ "match-condition": [
47
+ """if_none_match = prep_if_none_match(etag, match_condition)""",
48
+ """if if_none_match is not None:""",
49
+ """ _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str")""",
50
+ ],
51
+ }
52
+
53
+
29
54
  class ParameterSerializer:
30
55
  @staticmethod
31
56
  def serialize_parameter(parameter: ParameterType, serializer_name: str) -> str:
@@ -81,8 +106,8 @@ class ParameterSerializer:
81
106
  return f"[{serialize_line} if q is not None else '' for q in {origin_name}]"
82
107
  return serialize_line
83
108
 
109
+ @staticmethod
84
110
  def serialize_path(
85
- self,
86
111
  parameters: Union[
87
112
  List[Parameter],
88
113
  List[RequestBuilderParameter],
@@ -96,7 +121,9 @@ class ParameterSerializer:
96
121
  [
97
122
  ' "{}": {},'.format(
98
123
  path_parameter.wire_name,
99
- self.serialize_parameter(path_parameter, serializer_name),
124
+ ParameterSerializer.serialize_parameter(
125
+ path_parameter, serializer_name
126
+ ),
100
127
  )
101
128
  for path_parameter in parameters
102
129
  ]
@@ -104,6 +131,34 @@ class ParameterSerializer:
104
131
  retval.append("}")
105
132
  return retval
106
133
 
134
+ @staticmethod
135
+ def serialize_query_header(
136
+ param: Parameter,
137
+ kwarg_name: str,
138
+ serializer_name: str,
139
+ is_legacy: bool,
140
+ ) -> List[str]:
141
+ if (
142
+ not is_legacy
143
+ and param.location == ParameterLocation.HEADER
144
+ and param.wire_name.lower() in SPECIAL_HEADER_SERIALIZATION
145
+ ):
146
+ return SPECIAL_HEADER_SERIALIZATION[param.wire_name.lower()]
147
+
148
+ set_parameter = "_{}['{}'] = {}".format(
149
+ kwarg_name,
150
+ param.wire_name,
151
+ ParameterSerializer.serialize_parameter(param, serializer_name),
152
+ )
153
+ if not param.optional:
154
+ retval = [set_parameter]
155
+ else:
156
+ retval = [
157
+ f"if {param.full_client_name} is not None:",
158
+ f" {set_parameter}",
159
+ ]
160
+ return retval
161
+
107
162
  @staticmethod
108
163
  def pop_kwargs_from_signature(
109
164
  parameters: Sequence[_ParameterBase],
@@ -7,6 +7,7 @@
7
7
  # pylint: disable=protected-access, arguments-differ, signature-differs, broad-except
8
8
  # pyright: reportGeneralTypeIssues=false
9
9
 
10
+ import calendar
10
11
  import functools
11
12
  import sys
12
13
  import logging
@@ -14,13 +15,14 @@ import base64
14
15
  import re
15
16
  import copy
16
17
  import typing
18
+ import email
17
19
  from datetime import datetime, date, time, timedelta, timezone
18
20
  from json import JSONEncoder
19
21
  import isodate
20
22
  from azure.core.exceptions import DeserializationError
21
23
  from azure.core import CaseInsensitiveEnumMeta
22
24
  from azure.core.pipeline import PipelineResponse
23
- from azure.core.serialization import _Null # pylint: disable=protected-access
25
+ from azure.core.serialization import _Null
24
26
 
25
27
  if sys.version_info >= (3, 9):
26
28
  from collections.abc import MutableMapping
@@ -31,9 +33,9 @@ _LOGGER = logging.getLogger(__name__)
31
33
 
32
34
  __all__ = ["AzureJSONEncoder", "Model", "rest_field", "rest_discriminator"]
33
35
 
34
-
35
36
  TZ_UTC = timezone.utc
36
37
 
38
+
37
39
  def _timedelta_as_isostr(td: timedelta) -> str:
38
40
  """Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S'
39
41
 
@@ -91,38 +93,20 @@ def _timedelta_as_isostr(td: timedelta) -> str:
91
93
  return "P" + date_str + time_str
92
94
 
93
95
 
94
- def _datetime_as_isostr(dt: typing.Union[datetime, date, time, timedelta]) -> str:
95
- """Converts a datetime.(datetime|date|time|timedelta) object into an ISO 8601 formatted string
96
-
97
- :param timedelta dt: The date object to convert
98
- :rtype: str
99
- :return: ISO8601 version of this datetime
100
- """
101
- # First try datetime.datetime
102
- if hasattr(dt, "year") and hasattr(dt, "hour"):
103
- dt = typing.cast(datetime, dt)
104
- # astimezone() fails for naive times in Python 2.7, so make make sure dt is aware (tzinfo is set)
105
- if not dt.tzinfo:
106
- iso_formatted = dt.replace(tzinfo=TZ_UTC).isoformat()
107
- else:
108
- iso_formatted = dt.astimezone(TZ_UTC).isoformat()
109
- # Replace the trailing "+00:00" UTC offset with "Z" (RFC 3339: https://www.ietf.org/rfc/rfc3339.txt)
110
- return iso_formatted.replace("+00:00", "Z")
111
- # Next try datetime.date or datetime.time
112
- try:
113
- dt = typing.cast(typing.Union[date, time], dt)
114
- return dt.isoformat()
115
- # Last, try datetime.timedelta
116
- except AttributeError:
117
- dt = typing.cast(timedelta, dt)
118
- return _timedelta_as_isostr(dt)
119
-
120
- def _serialize_bytes(o) -> str:
121
- return base64.b64encode(o).decode()
96
+ def _serialize_bytes(o, format: typing.Optional[str] = None) -> str:
97
+ encoded = base64.b64encode(o).decode()
98
+ if format == "base64url":
99
+ return encoded.strip("=").replace("+", "-").replace("/", "_")
100
+ return encoded
122
101
 
123
102
 
124
- def _serialize_datetime(o):
103
+ def _serialize_datetime(o, format: typing.Optional[str] = None):
125
104
  if hasattr(o, "year") and hasattr(o, "hour"):
105
+ if format == "rfc7231":
106
+ return email.utils.format_datetime(o, usegmt=True)
107
+ if format == "unix-timestamp":
108
+ return int(calendar.timegm(o.utctimetuple()))
109
+
126
110
  # astimezone() fails for naive times in Python 2.7, so make make sure o is aware (tzinfo is set)
127
111
  if not o.tzinfo:
128
112
  iso_formatted = o.replace(tzinfo=TZ_UTC).isoformat()
@@ -146,7 +130,7 @@ class AzureJSONEncoder(JSONEncoder):
146
130
 
147
131
  def default(self, o): # pylint: disable=too-many-return-statements
148
132
  if _is_model(o):
149
- readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] # pylint: disable=protected-access
133
+ readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)]
150
134
  return {k: v for k, v in o.items() if k not in readonly_props}
151
135
  if isinstance(o, (bytes, bytearray)):
152
136
  return base64.b64encode(o).decode()
@@ -172,6 +156,10 @@ class AzureJSONEncoder(JSONEncoder):
172
156
 
173
157
 
174
158
  _VALID_DATE = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}" + r"\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?")
159
+ _VALID_RFC7231 = re.compile(
160
+ r"(Mon|Tue|Wed|Thu|Fri|Sat|Sun),\s\d{2}\s"
161
+ r"(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s\d{4}\s\d{2}:\d{2}:\d{2}\sGMT"
162
+ )
175
163
 
176
164
 
177
165
  def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime:
@@ -207,6 +195,36 @@ def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime:
207
195
  return date_obj
208
196
 
209
197
 
198
+ def _deserialize_datetime_rfc7231(attr: typing.Union[str, datetime]) -> datetime:
199
+ """Deserialize RFC7231 formatted string into Datetime object.
200
+
201
+ :param str attr: response string to be deserialized.
202
+ :rtype: ~datetime.datetime
203
+ :returns: The datetime object from that input
204
+ """
205
+ if isinstance(attr, datetime):
206
+ # i'm already deserialized
207
+ return attr
208
+ match = _VALID_RFC7231.match(attr)
209
+ if not match:
210
+ raise ValueError("Invalid datetime string: " + attr)
211
+
212
+ return email.utils.parsedate_to_datetime(attr)
213
+
214
+
215
+ def _deserialize_datetime_unix_timestamp(attr: typing.Union[float, datetime]) -> datetime:
216
+ """Deserialize unix timestamp into Datetime object.
217
+
218
+ :param str attr: response string to be deserialized.
219
+ :rtype: ~datetime.datetime
220
+ :returns: The datetime object from that input
221
+ """
222
+ if isinstance(attr, datetime):
223
+ # i'm already deserialized
224
+ return attr
225
+ return datetime.fromtimestamp(attr, TZ_UTC)
226
+
227
+
210
228
  def _deserialize_date(attr: typing.Union[str, date]) -> date:
211
229
  """Deserialize ISO-8601 formatted string into Date object.
212
230
  :param str attr: response string to be deserialized.
@@ -231,13 +249,22 @@ def _deserialize_time(attr: typing.Union[str, time]) -> time:
231
249
  return isodate.parse_time(attr)
232
250
 
233
251
 
234
- def deserialize_bytes(attr):
252
+ def _deserialize_bytes(attr):
235
253
  if isinstance(attr, (bytes, bytearray)):
236
254
  return attr
237
255
  return bytes(base64.b64decode(attr))
238
256
 
239
257
 
240
- def deserialize_duration(attr):
258
+ def _deserialize_bytes_base64(attr):
259
+ if isinstance(attr, (bytes, bytearray)):
260
+ return attr
261
+ padding = "=" * (3 - (len(attr) + 3) % 4) # type: ignore
262
+ attr = attr + padding # type: ignore
263
+ encoded = attr.replace("-", "+").replace("_", "/")
264
+ return bytes(base64.b64decode(encoded))
265
+
266
+
267
+ def _deserialize_duration(attr):
241
268
  if isinstance(attr, timedelta):
242
269
  return attr
243
270
  return isodate.parse_duration(attr)
@@ -247,11 +274,26 @@ _DESERIALIZE_MAPPING = {
247
274
  datetime: _deserialize_datetime,
248
275
  date: _deserialize_date,
249
276
  time: _deserialize_time,
250
- bytes: deserialize_bytes,
251
- timedelta: deserialize_duration,
277
+ bytes: _deserialize_bytes,
278
+ bytearray: _deserialize_bytes,
279
+ timedelta: _deserialize_duration,
252
280
  typing.Any: lambda x: x,
253
281
  }
254
282
 
283
+ _DESERIALIZE_MAPPING_WITHFORMAT = {
284
+ "rfc3339": _deserialize_datetime,
285
+ "rfc7231": _deserialize_datetime_rfc7231,
286
+ "unix-timestamp": _deserialize_datetime_unix_timestamp,
287
+ "base64": _deserialize_bytes,
288
+ "base64url": _deserialize_bytes_base64,
289
+ }
290
+
291
+
292
+ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = None):
293
+ if rf and rf._format:
294
+ return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format)
295
+ return _DESERIALIZE_MAPPING.get(annotation)
296
+
255
297
 
256
298
  def _get_model(module_name: str, model_name: str):
257
299
  models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)}
@@ -358,12 +400,20 @@ def _is_model(obj: typing.Any) -> bool:
358
400
  return getattr(obj, "_is_model", False)
359
401
 
360
402
 
361
- def _serialize(o):
403
+ def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-many-return-statements
404
+ if isinstance(o, list):
405
+ return [_serialize(x, format) for x in o]
406
+ if isinstance(o, dict):
407
+ return {k: _serialize(v, format) for k, v in o.items()}
408
+ if isinstance(o, set):
409
+ return {_serialize(x, format) for x in o}
410
+ if isinstance(o, tuple):
411
+ return tuple(_serialize(x, format) for x in o)
362
412
  if isinstance(o, (bytes, bytearray)):
363
- return _serialize_bytes(o)
413
+ return _serialize_bytes(o, format)
364
414
  try:
365
415
  # First try datetime.datetime
366
- return _serialize_datetime(o)
416
+ return _serialize_datetime(o, format)
367
417
  except AttributeError:
368
418
  pass
369
419
  # Last, try datetime.timedelta
@@ -385,7 +435,7 @@ def _get_rest_field(
385
435
 
386
436
 
387
437
  def _create_value(rf: typing.Optional["_RestField"], value: typing.Any) -> typing.Any:
388
- return _deserialize(rf._type, value) if (rf and rf._is_model) else _serialize(value)
438
+ return _deserialize(rf._type, value) if (rf and rf._is_model) else _serialize(value, rf._format if rf else None)
389
439
 
390
440
 
391
441
  class Model(_MyMutableMapping):
@@ -409,10 +459,13 @@ class Model(_MyMutableMapping):
409
459
  if non_attr_kwargs:
410
460
  # actual type errors only throw the first wrong keyword arg they see, so following that.
411
461
  raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'")
412
- dict_to_pass.update({
413
- self._attr_to_rest_field[k]._rest_name: _serialize(v)
414
- for k, v in kwargs.items() if v is not None
415
- })
462
+ dict_to_pass.update(
463
+ {
464
+ self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format)
465
+ for k, v in kwargs.items()
466
+ if v is not None
467
+ }
468
+ )
416
469
  super().__init__(dict_to_pass)
417
470
 
418
471
  def copy(self) -> "Model":
@@ -464,145 +517,157 @@ class Model(_MyMutableMapping):
464
517
 
465
518
 
466
519
  def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements
467
- annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None
468
- ) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]:
469
- if not annotation or annotation in [int, float]:
470
- return None
520
+ annotation: typing.Any,
521
+ module: typing.Optional[str],
522
+ rf: typing.Optional["_RestField"] = None,
523
+ ) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]:
524
+ if not annotation or annotation in [int, float]:
525
+ return None
471
526
 
472
- try:
473
- if module and _is_model(_get_model(module, annotation)):
474
- if rf:
475
- rf._is_model = True
476
- def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj):
477
- if _is_model(obj):
478
- return obj
479
- return _deserialize(model_deserializer, obj)
527
+ try:
528
+ if module and _is_model(_get_model(module, annotation)):
529
+ if rf:
530
+ rf._is_model = True
480
531
 
481
- return functools.partial(_deserialize_model, _get_model(module, annotation))
482
- except Exception:
483
- pass
532
+ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj):
533
+ if _is_model(obj):
534
+ return obj
535
+ return _deserialize(model_deserializer, obj)
484
536
 
485
- # is it a literal?
486
- try:
487
- if sys.version_info >= (3, 8):
488
- from typing import Literal # pylint: disable=no-name-in-module, ungrouped-imports
489
- else:
490
- from typing_extensions import Literal # type: ignore # pylint: disable=ungrouped-imports
537
+ return functools.partial(_deserialize_model, _get_model(module, annotation))
538
+ except Exception:
539
+ pass
491
540
 
492
- if annotation.__origin__ == Literal:
493
- return None
494
- except AttributeError:
495
- pass
541
+ # is it a literal?
542
+ try:
543
+ if sys.version_info >= (3, 8):
544
+ from typing import (
545
+ Literal,
546
+ ) # pylint: disable=no-name-in-module, ungrouped-imports
547
+ else:
548
+ from typing_extensions import Literal # type: ignore # pylint: disable=ungrouped-imports
496
549
 
497
- if getattr(annotation, "__origin__", None) is typing.Union:
498
- def _deserialize_with_union(union_annotation, obj):
499
- for t in union_annotation.__args__:
500
- try:
501
- return _deserialize(t, obj, module)
502
- except DeserializationError:
503
- pass
504
- raise DeserializationError()
505
- return functools.partial(_deserialize_with_union, annotation)
506
-
507
- # is it optional?
508
- try:
509
- # right now, assuming we don't have unions, since we're getting rid of the only
510
- # union we used to have in msrest models, which was union of str and enum
511
- if any(a for a in annotation.__args__ if a == type(None)):
550
+ if annotation.__origin__ == Literal:
551
+ return None
552
+ except AttributeError:
553
+ pass
512
554
 
513
- if_obj_deserializer = _get_deserialize_callable_from_annotation(
514
- next(a for a in annotation.__args__ if a != type(None)), module, rf
515
- )
555
+ if getattr(annotation, "__origin__", None) is typing.Union:
516
556
 
517
- def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Callable], obj):
518
- if obj is None:
519
- return obj
520
- return _deserialize_with_callable(if_obj_deserializer, obj)
557
+ def _deserialize_with_union(union_annotation, obj):
558
+ for t in union_annotation.__args__:
559
+ try:
560
+ return _deserialize(t, obj, module, rf)
561
+ except DeserializationError:
562
+ pass
563
+ raise DeserializationError()
521
564
 
522
- return functools.partial(_deserialize_with_optional, if_obj_deserializer)
523
- except AttributeError:
524
- pass
565
+ return functools.partial(_deserialize_with_union, annotation)
525
566
 
526
- # is it a forward ref / in quotes?
527
- if isinstance(annotation, (str, typing.ForwardRef)):
528
- try:
529
- model_name = annotation.__forward_arg__ # type: ignore
530
- except AttributeError:
531
- model_name = annotation
532
- if module is not None:
533
- annotation = _get_model(module, model_name)
567
+ # is it optional?
568
+ try:
569
+ # right now, assuming we don't have unions, since we're getting rid of the only
570
+ # union we used to have in msrest models, which was union of str and enum
571
+ if any(a for a in annotation.__args__ if a == type(None)):
572
+ if_obj_deserializer = _get_deserialize_callable_from_annotation(
573
+ next(a for a in annotation.__args__ if a != type(None)), module, rf
574
+ )
534
575
 
576
+ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Callable], obj):
577
+ if obj is None:
578
+ return obj
579
+ return _deserialize_with_callable(if_obj_deserializer, obj)
580
+
581
+ return functools.partial(_deserialize_with_optional, if_obj_deserializer)
582
+ except AttributeError:
583
+ pass
584
+
585
+ # is it a forward ref / in quotes?
586
+ if isinstance(annotation, (str, typing.ForwardRef)):
535
587
  try:
536
- if annotation._name == "Dict":
537
- key_deserializer = _get_deserialize_callable_from_annotation(annotation.__args__[0], module, rf)
538
- value_deserializer = _get_deserialize_callable_from_annotation(annotation.__args__[1], module, rf)
539
-
540
- def _deserialize_dict(
541
- key_deserializer: typing.Optional[typing.Callable],
542
- value_deserializer: typing.Optional[typing.Callable],
543
- obj: typing.Dict[typing.Any, typing.Any],
544
- ):
545
- if obj is None:
546
- return obj
547
- return {
548
- _deserialize(key_deserializer, k, module): _deserialize(value_deserializer, v, module) for k, v in obj.items()
549
- }
550
-
551
- return functools.partial(
552
- _deserialize_dict,
553
- key_deserializer,
554
- value_deserializer,
555
- )
556
- except (AttributeError, IndexError):
557
- pass
558
- try:
559
- if annotation._name in ["List", "Set", "Tuple", "Sequence"]:
560
- if len(annotation.__args__) > 1:
561
-
562
- def _deserialize_multiple_sequence(
563
- entry_deserializers: typing.List[typing.Optional[typing.Callable]], obj
564
- ):
565
- if obj is None:
566
- return obj
567
- return type(obj)(
568
- _deserialize(deserializer, entry, module) for entry, deserializer in zip(obj, entry_deserializers)
569
- )
570
-
571
- entry_deserializers = [
572
- _get_deserialize_callable_from_annotation(dt, module, rf) for dt in annotation.__args__
573
- ]
574
- return functools.partial(_deserialize_multiple_sequence, entry_deserializers)
575
- deserializer = _get_deserialize_callable_from_annotation(annotation.__args__[0], module, rf)
576
-
577
- def _deserialize_sequence(
578
- deserializer: typing.Optional[typing.Callable],
588
+ model_name = annotation.__forward_arg__ # type: ignore
589
+ except AttributeError:
590
+ model_name = annotation
591
+ if module is not None:
592
+ annotation = _get_model(module, model_name)
593
+
594
+ try:
595
+ if annotation._name == "Dict":
596
+ key_deserializer = _get_deserialize_callable_from_annotation(annotation.__args__[0], module, rf)
597
+ value_deserializer = _get_deserialize_callable_from_annotation(annotation.__args__[1], module, rf)
598
+
599
+ def _deserialize_dict(
600
+ key_deserializer: typing.Optional[typing.Callable],
601
+ value_deserializer: typing.Optional[typing.Callable],
602
+ obj: typing.Dict[typing.Any, typing.Any],
603
+ ):
604
+ if obj is None:
605
+ return obj
606
+ return {
607
+ _deserialize(key_deserializer, k, module): _deserialize(value_deserializer, v, module)
608
+ for k, v in obj.items()
609
+ }
610
+
611
+ return functools.partial(
612
+ _deserialize_dict,
613
+ key_deserializer,
614
+ value_deserializer,
615
+ )
616
+ except (AttributeError, IndexError):
617
+ pass
618
+ try:
619
+ if annotation._name in ["List", "Set", "Tuple", "Sequence"]:
620
+ if len(annotation.__args__) > 1:
621
+
622
+ def _deserialize_multiple_sequence(
623
+ entry_deserializers: typing.List[typing.Optional[typing.Callable]],
579
624
  obj,
580
625
  ):
581
626
  if obj is None:
582
627
  return obj
583
- return type(obj)(_deserialize(deserializer, entry, module) for entry in obj)
628
+ return type(obj)(
629
+ _deserialize(deserializer, entry, module)
630
+ for entry, deserializer in zip(obj, entry_deserializers)
631
+ )
632
+
633
+ entry_deserializers = [
634
+ _get_deserialize_callable_from_annotation(dt, module, rf) for dt in annotation.__args__
635
+ ]
636
+ return functools.partial(_deserialize_multiple_sequence, entry_deserializers)
637
+ deserializer = _get_deserialize_callable_from_annotation(annotation.__args__[0], module, rf)
638
+
639
+ def _deserialize_sequence(
640
+ deserializer: typing.Optional[typing.Callable],
641
+ obj,
642
+ ):
643
+ if obj is None:
644
+ return obj
645
+ return type(obj)(_deserialize(deserializer, entry, module) for entry in obj)
646
+
647
+ return functools.partial(_deserialize_sequence, deserializer)
648
+ except (TypeError, IndexError, AttributeError, SyntaxError):
649
+ pass
584
650
 
585
- return functools.partial(_deserialize_sequence, deserializer)
586
- except (TypeError, IndexError, AttributeError, SyntaxError):
651
+ def _deserialize_default(
652
+ annotation,
653
+ deserializer_from_mapping,
654
+ obj,
655
+ ):
656
+ if obj is None:
657
+ return obj
658
+ try:
659
+ return _deserialize_with_callable(annotation, obj)
660
+ except Exception:
587
661
  pass
662
+ return _deserialize_with_callable(deserializer_from_mapping, obj)
588
663
 
589
- def _deserialize_default(
590
- annotation,
591
- deserializer_from_mapping,
592
- obj,
593
- ):
594
- if obj is None:
595
- return obj
596
- try:
597
- return _deserialize_with_callable(annotation, obj)
598
- except Exception:
599
- pass
600
- return _deserialize_with_callable(deserializer_from_mapping, obj)
601
-
602
- return functools.partial(_deserialize_default, annotation, _DESERIALIZE_MAPPING.get(annotation))
664
+ return functools.partial(_deserialize_default, annotation, get_deserializer(annotation, rf))
603
665
 
604
666
 
605
- def _deserialize_with_callable(deserializer: typing.Optional[typing.Callable[[typing.Any], typing.Any]], value: typing.Any):
667
+ def _deserialize_with_callable(
668
+ deserializer: typing.Optional[typing.Callable[[typing.Any], typing.Any]],
669
+ value: typing.Any,
670
+ ):
606
671
  try:
607
672
  if value is None:
608
673
  return None
@@ -621,12 +686,18 @@ def _deserialize_with_callable(deserializer: typing.Optional[typing.Callable[[ty
621
686
  raise DeserializationError() from e
622
687
 
623
688
 
624
- def _deserialize(deserializer: typing.Any, value: typing.Any, module: typing.Optional[str] = None) -> typing.Any:
689
+ def _deserialize(
690
+ deserializer: typing.Any,
691
+ value: typing.Any,
692
+ module: typing.Optional[str] = None,
693
+ rf: typing.Optional["_RestField"] = None,
694
+ ) -> typing.Any:
625
695
  if isinstance(value, PipelineResponse):
626
696
  value = value.http_response.json()
627
- deserializer = _get_deserialize_callable_from_annotation(deserializer, module)
697
+ deserializer = _get_deserialize_callable_from_annotation(deserializer, module, rf)
628
698
  return _deserialize_with_callable(deserializer, value)
629
699
 
700
+
630
701
  class _RestField:
631
702
  def __init__(
632
703
  self,
@@ -636,6 +707,7 @@ class _RestField:
636
707
  is_discriminator: bool = False,
637
708
  visibility: typing.Optional[typing.List[str]] = None,
638
709
  default: typing.Any = _UNSET,
710
+ format: typing.Optional[str] = None,
639
711
  ):
640
712
  self._type = type
641
713
  self._rest_name_input = name
@@ -644,6 +716,7 @@ class _RestField:
644
716
  self._visibility = visibility
645
717
  self._is_model = False
646
718
  self._default = default
719
+ self._format = format
647
720
 
648
721
  @property
649
722
  def _rest_name(self) -> str:
@@ -657,7 +730,7 @@ class _RestField:
657
730
  item = obj.get(self._rest_name)
658
731
  if item is None:
659
732
  return item
660
- return _deserialize(self._type, _serialize(item))
733
+ return _deserialize(self._type, _serialize(item, self._format), rf=self)
661
734
 
662
735
  def __set__(self, obj: Model, value) -> None:
663
736
  if value is None:
@@ -669,7 +742,7 @@ class _RestField:
669
742
  return
670
743
  if self._is_model and not _is_model(value):
671
744
  obj.__setitem__(self._rest_name, _deserialize(self._type, value))
672
- obj.__setitem__(self._rest_name, _serialize(value))
745
+ obj.__setitem__(self._rest_name, _serialize(value, self._format))
673
746
 
674
747
  def _get_deserialize_callable_from_annotation(
675
748
  self, annotation: typing.Any
@@ -683,8 +756,9 @@ def rest_field(
683
756
  type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin
684
757
  visibility: typing.Optional[typing.List[str]] = None,
685
758
  default: typing.Any = _UNSET,
759
+ format: typing.Optional[str] = None,
686
760
  ) -> typing.Any:
687
- return _RestField(name=name, type=type, visibility=visibility, default=default)
761
+ return _RestField(name=name, type=type, visibility=visibility, default=default, format=format)
688
762
 
689
763
 
690
764
  def rest_discriminator(
@@ -90,7 +90,7 @@ setup(
90
90
  {% if azure_arm %}
91
91
  "azure-mgmt-core<2.0.0,>=1.3.2",
92
92
  {% else %}
93
- "azure-core<2.0.0,>=1.27.0",
93
+ "azure-core<2.0.0,>=1.28.0",
94
94
  {% endif %}
95
95
  {% if code_model.need_typing_extensions %}
96
96
  "typing-extensions>=4.3.0; python_version<'3.8.0'",
@@ -744,6 +744,8 @@ class Serializer(object):
744
744
 
745
745
  :param data: The data to be serialized.
746
746
  :param str data_type: The type to be serialized from.
747
+ :keyword bool skip_quote: Whether to skip quote the serialized result.
748
+ Defaults to False.
747
749
  :rtype: str
748
750
  :raises: TypeError if serialization fails.
749
751
  :raises: ValueError if data is None
@@ -752,10 +754,8 @@ class Serializer(object):
752
754
  # Treat the list aside, since we don't want to encode the div separator
753
755
  if data_type.startswith("["):
754
756
  internal_data_type = data_type[1:-1]
755
- data = [self.serialize_data(d, internal_data_type, **kwargs) if d is not None else "" for d in data]
756
- if not kwargs.get("skip_quote", False):
757
- data = [quote(str(d), safe="") for d in data]
758
- return str(self.serialize_iter(data, internal_data_type, **kwargs))
757
+ do_quote = not kwargs.get('skip_quote', False)
758
+ return str(self.serialize_iter(data, internal_data_type, do_quote=do_quote, **kwargs))
759
759
 
760
760
  # Not a list, regular serialization
761
761
  output = self.serialize_data(data, data_type, **kwargs)
@@ -894,6 +894,8 @@ class Serializer(object):
894
894
  not be None or empty.
895
895
  :param str div: If set, this str will be used to combine the elements
896
896
  in the iterable into a combined string. Default is 'None'.
897
+ :keyword bool do_quote: Whether to quote the serialized result of each iterable element.
898
+ Defaults to False.
897
899
  :rtype: list, str
898
900
  """
899
901
  if isinstance(data, str):
@@ -911,6 +913,13 @@ class Serializer(object):
911
913
  raise
912
914
  serialized.append(None)
913
915
 
916
+ if kwargs.get('do_quote', False):
917
+ serialized = [
918
+ '' if s is None else quote(str(s), safe='')
919
+ for s
920
+ in serialized
921
+ ]
922
+
914
923
  if div:
915
924
  serialized = ["" if s is None else str(s) for s in serialized]
916
925
  serialized = div.join(serialized)
@@ -32,3 +32,32 @@ def raise_if_not_implemented(cls, abstract_methods):
32
32
  cls.__name__, '\', \''.join(not_implemented))
33
33
  )
34
34
  {% endif %}
35
+
36
+ {% if code_model.has_etag %}
37
+ def quote_etag(etag: Optional[str]) -> Optional[str]:
38
+ if not etag or etag == "*":
39
+ return etag
40
+ if etag.startswith('"') and etag.endswith('"'):
41
+ return etag
42
+ if etag.startswith("'") and etag.endswith("'"):
43
+ return etag
44
+ return '"' + etag + '"'
45
+
46
+
47
+ def prep_if_match(etag: Optional[str], match_condition: Optional[MatchConditions]) -> Optional[str]:
48
+ if match_condition == MatchConditions.IfNotModified:
49
+ if_match = quote_etag(etag) if etag else None
50
+ return if_match
51
+ if match_condition == MatchConditions.IfPresent:
52
+ return "*"
53
+ return None
54
+
55
+
56
+ def prep_if_none_match(etag: Optional[str], match_condition: Optional[MatchConditions]) -> Optional[str]:
57
+ if match_condition == MatchConditions.IfModified:
58
+ if_none_match = quote_etag(etag) if etag else None
59
+ return if_none_match
60
+ if match_condition == MatchConditions.IfMissing:
61
+ return "*"
62
+ return None
63
+ {% endif %}
@@ -176,6 +176,37 @@ def update_paging_response(yaml_data: Dict[str, Any]) -> None:
176
176
  )
177
177
 
178
178
 
179
+ HEADERS_HIDE_IN_METHOD = (
180
+ "repeatability-request-id",
181
+ "repeatability-first-sent",
182
+ "x-ms-client-request-id",
183
+ "client-request-id",
184
+ "return-client-request-id",
185
+ )
186
+ HEADERS_CONVERT_IN_METHOD = {
187
+ "if-match": {
188
+ "clientName": "etag",
189
+ "wireName": "etag",
190
+ "description": "check if resource is changed. Set None to skip checking etag.",
191
+ },
192
+ "if-none-match": {
193
+ "clientName": "match_condition",
194
+ "wireName": "match-condition",
195
+ "description": "The match condition to use upon the etag.",
196
+ "type": {
197
+ "type": "azurecore",
198
+ "name": "MatchConditions",
199
+ },
200
+ },
201
+ }
202
+
203
+
204
+ def headers_convert(yaml_data: Dict[str, Any], replace_data: Any) -> None:
205
+ if isinstance(replace_data, dict):
206
+ for k, v in replace_data.items():
207
+ yaml_data[k] = v
208
+
209
+
179
210
  class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
180
211
  """Add Python naming information."""
181
212
 
@@ -242,6 +273,11 @@ class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
242
273
  )
243
274
  type["values"].extend(values_to_add)
244
275
 
276
+ # add type for reference
277
+ for v in HEADERS_CONVERT_IN_METHOD.values():
278
+ if isinstance(v, dict) and "type" in v:
279
+ yaml_data.append(v["type"])
280
+
245
281
  def update_client(self, yaml_data: Dict[str, Any]) -> None:
246
282
  yaml_data["description"] = update_description(
247
283
  yaml_data["description"], default_description=yaml_data["name"]
@@ -253,6 +289,21 @@ class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
253
289
  if prop_name.endswith("Client"):
254
290
  prop_name = prop_name[: len(prop_name) - len("Client")]
255
291
  yaml_data["builderPadName"] = to_snake_case(prop_name)
292
+ for og in yaml_data["operationGroups"]:
293
+ for o in og["operations"]:
294
+ for p in o["parameters"]:
295
+ if (
296
+ p["location"] == "header"
297
+ and p["wireName"] == "client-request-id"
298
+ ):
299
+ yaml_data["requestIdHeaderName"] = p["wireName"]
300
+ if (
301
+ self.version_tolerant
302
+ and p["location"] == "header"
303
+ and p["clientName"] in ("if_match", "if_none_match")
304
+ ):
305
+ o["hasEtag"] = True
306
+ yaml_data["hasEtag"] = True
256
307
 
257
308
  def get_operation_updater(
258
309
  self, yaml_data: Dict[str, Any]
@@ -282,6 +333,18 @@ class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
282
333
  .lower()
283
334
  for prop, param_name in yaml_data["propertyToParameterName"].items()
284
335
  }
336
+ wire_name_lower = (yaml_data.get("wireName") or "").lower()
337
+ if (
338
+ yaml_data["location"] == "header"
339
+ and wire_name_lower in HEADERS_HIDE_IN_METHOD
340
+ ):
341
+ yaml_data["hideInMethod"] = True
342
+ if (
343
+ self.version_tolerant
344
+ and yaml_data["location"] == "header"
345
+ and wire_name_lower in HEADERS_CONVERT_IN_METHOD
346
+ ):
347
+ headers_convert(yaml_data, HEADERS_CONVERT_IN_METHOD[wire_name_lower])
285
348
 
286
349
  def update_operation(
287
350
  self,
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@autorest/python",
3
- "version": "6.7.0",
3
+ "version": "6.7.1",
4
4
  "description": "The Python extension for generators in AutoRest.",
5
5
  "main": "index.js",
6
6
  "repository": {