@autorest/python 6.7.0 → 6.7.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -52,6 +52,9 @@ class BlackScriptPlugin(Plugin): # pylint: disable=abstract-method
52
52
  )
53
53
  except NothingChanged:
54
54
  pass
55
+ except: # pylint: disable=bare-except
56
+ _LOGGER.error("Error: failed to format %s", file)
57
+ raise
55
58
  self.write_file(file, file_content)
56
59
 
57
60
 
@@ -28,6 +28,7 @@ from .primitive_types import (
28
28
  BooleanType,
29
29
  AnyObjectType,
30
30
  UnixTimeType,
31
+ AzureCoreType,
31
32
  )
32
33
  from .enum_type import EnumType
33
34
  from .base import BaseType
@@ -144,6 +145,7 @@ TYPE_TO_OBJECT = {
144
145
  "any-object": AnyObjectType,
145
146
  "unixtime": UnixTimeType,
146
147
  "credential": StringType,
148
+ "azurecore": AzureCoreType,
147
149
  }
148
150
  _LOGGER = logging.getLogger(__name__)
149
151
 
@@ -76,6 +76,8 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
76
76
  for op_group in self.yaml_data.get("operationGroups", [])
77
77
  ]
78
78
  self.link_lro_initial_operations()
79
+ self.request_id_header_name = self.yaml_data.get("requestIdHeaderName", None)
80
+ self.has_etag: bool = yaml_data.get("hasEtag", False)
79
81
 
80
82
  def _build_request_builders(
81
83
  self,
@@ -15,10 +15,10 @@ from .constant_type import ConstantType
15
15
 
16
16
 
17
17
  def _is_legacy(options) -> bool:
18
- return not (options["version_tolerant"] or options["low_level_client"])
18
+ return not (options.get("version_tolerant") or options.get("low_level_client"))
19
19
 
20
20
 
21
- class CodeModel: # pylint: disable=too-many-public-methods
21
+ class CodeModel: # pylint: disable=too-many-public-methods, disable=too-many-instance-attributes
22
22
  """Top level code model
23
23
 
24
24
  :param options: Options of the code model. I.e., whether this is for management generation
@@ -77,6 +77,10 @@ class CodeModel: # pylint: disable=too-many-public-methods
77
77
  t for t in self.types_map.values() if isinstance(t, CombinedType) and t.name
78
78
  ]
79
79
 
80
+ @property
81
+ def has_etag(self) -> bool:
82
+ return any(client.has_etag for client in self.clients)
83
+
80
84
  @property
81
85
  def has_operations(self) -> bool:
82
86
  if any(c for c in self.clients if c.has_operations):
@@ -30,6 +30,10 @@ class DictionaryType(BaseType):
30
30
  super().__init__(yaml_data=yaml_data, code_model=code_model)
31
31
  self.element_type = element_type
32
32
 
33
+ @property
34
+ def format(self) -> Optional[str]:
35
+ return self.element_type.format if hasattr(self.element_type, "format") else None # type: ignore
36
+
33
37
  @property
34
38
  def serialization_type(self) -> str:
35
39
  """Returns the serialization value for msrest.
@@ -25,6 +25,10 @@ class ListType(BaseType):
25
25
  self.min_items: Optional[int] = yaml_data.get("minItems")
26
26
  self.unique_items: bool = yaml_data.get("uniqueItems", False)
27
27
 
28
+ @property
29
+ def format(self) -> Optional[str]:
30
+ return self.element_type.format if hasattr(self.element_type, "format") else None # type: ignore
31
+
28
32
  @property
29
33
  def serialization_type(self) -> str:
30
34
  return f"[{self.element_type.serialization_type}]"
@@ -90,6 +90,7 @@ class OperationBase( # pylint: disable=too-many-public-methods
90
90
  self.internal: bool = self.yaml_data.get("internal", False)
91
91
  if self.internal:
92
92
  self.name = "_" + self.name
93
+ self.has_etag: bool = self.yaml_data.get("hasEtag", False)
93
94
 
94
95
  @property
95
96
  def expose_stream_keyword(self) -> bool:
@@ -331,7 +332,7 @@ class OperationBase( # pylint: disable=too-many-public-methods
331
332
  )
332
333
  return file_import
333
334
 
334
- def imports( # pylint: disable=too-many-branches
335
+ def imports( # pylint: disable=too-many-branches, disable=too-many-statements
335
336
  self, async_mode: bool, **kwargs: Any
336
337
  ) -> FileImport:
337
338
  if self.abstract:
@@ -394,11 +395,22 @@ class OperationBase( # pylint: disable=too-many-public-methods
394
395
  if self.deprecated:
395
396
  file_import.add_import("warnings", ImportType.STDLIB)
396
397
 
398
+ relative_path = "..." if async_mode else ".."
397
399
  if self.code_model.need_request_converter:
398
- relative_path = "..." if async_mode else ".."
399
400
  file_import.add_submodule_import(
400
401
  f"{relative_path}_vendor", "_convert_request", ImportType.LOCAL
401
402
  )
403
+ if self.has_etag:
404
+ file_import.add_submodule_import(
405
+ "azure.core.exceptions", "ResourceModifiedError", ImportType.AZURECORE
406
+ )
407
+ if not async_mode:
408
+ file_import.add_submodule_import(
409
+ f"{relative_path}_vendor", "prep_if_match", ImportType.LOCAL
410
+ )
411
+ file_import.add_submodule_import(
412
+ f"{relative_path}_vendor", "prep_if_none_match", ImportType.LOCAL
413
+ )
402
414
  if self.code_model.need_request_converter:
403
415
  if async_mode:
404
416
  file_import.add_submodule_import(
@@ -52,9 +52,6 @@ class ParameterDelimeter(str, Enum):
52
52
  COMMA = "comma"
53
53
 
54
54
 
55
- SPECIAL_HANDLE_HEADERS = ["repeatability-request-id", "repeatability-first-sent"]
56
-
57
-
58
55
  class _ParameterBase(
59
56
  BaseModel, abc.ABC
60
57
  ): # pylint: disable=too-many-instance-attributes
@@ -93,10 +90,7 @@ class _ParameterBase(
93
90
  self.default_to_unset_sentinel: bool = self.yaml_data.get(
94
91
  "defaultToUnsetSentinel", False
95
92
  )
96
- self.is_special_handle_header: bool = (
97
- self.location == ParameterLocation.HEADER
98
- and self.wire_name.lower() in SPECIAL_HANDLE_HEADERS
99
- )
93
+ self.hide_in_method: bool = self.yaml_data.get("hideInMethod", False)
100
94
 
101
95
  @property
102
96
  def constant(self) -> bool:
@@ -215,7 +215,7 @@ class _ParameterListBase(
215
215
  for p in self.parameters
216
216
  if p.in_method_signature
217
217
  and p.implementation == self.implementation
218
- and not p.is_special_handle_header
218
+ and (self.code_model.is_legacy or not p.hide_in_method)
219
219
  ]
220
220
  if self._body_parameter:
221
221
  if self._body_parameter.in_method_signature:
@@ -4,7 +4,6 @@
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
6
  import datetime
7
- from enum import Enum
8
7
  from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
9
8
 
10
9
  from .base import BaseType
@@ -356,17 +355,17 @@ class StringType(PrimitiveType):
356
355
  class DatetimeType(PrimitiveType):
357
356
  def __init__(self, yaml_data: Dict[str, Any], code_model: "CodeModel") -> None:
358
357
  super().__init__(yaml_data=yaml_data, code_model=code_model)
359
- self.format = self.Formats(yaml_data.get("format", "date-time"))
360
-
361
- class Formats(str, Enum):
362
- datetime = "date-time"
363
- rfc1123 = "date-time-rfc1123"
358
+ self.format = (
359
+ "rfc3339"
360
+ if yaml_data.get("format", "date-time") == "date-time"
361
+ else "rfc7231"
362
+ )
364
363
 
365
364
  @property
366
365
  def serialization_type(self) -> str:
367
366
  formats_to_attribute_type = {
368
- self.Formats.datetime: "iso-8601",
369
- self.Formats.rfc1123: "rfc-1123",
367
+ "rfc3339": "iso-8601",
368
+ "rfc7231": "rfc-1123",
370
369
  }
371
370
  return formats_to_attribute_type[self.format]
372
371
 
@@ -454,6 +453,10 @@ class TimeType(PrimitiveType):
454
453
 
455
454
 
456
455
  class UnixTimeType(PrimitiveType):
456
+ @property
457
+ def format(self) -> str:
458
+ return "unix-timestamp"
459
+
457
460
  @property
458
461
  def serialization_type(self) -> str:
459
462
  return "unix-time"
@@ -588,7 +591,7 @@ class DurationType(PrimitiveType):
588
591
  class ByteArraySchema(PrimitiveType):
589
592
  def __init__(self, yaml_data: Dict[str, Any], code_model: "CodeModel") -> None:
590
593
  super().__init__(yaml_data=yaml_data, code_model=code_model)
591
- self.format = yaml_data.get("format", "bytes")
594
+ self.format = yaml_data.get("format", "base64")
592
595
 
593
596
  @property
594
597
  def serialization_type(self) -> str:
@@ -605,3 +608,28 @@ class ByteArraySchema(PrimitiveType):
605
608
  @property
606
609
  def instance_check_template(self) -> str:
607
610
  return "isinstance({}, bytes)"
611
+
612
+
613
+ class AzureCoreType(PrimitiveType):
614
+ def __init__(self, yaml_data: Dict[str, Any], code_model: "CodeModel") -> None:
615
+ super().__init__(yaml_data=yaml_data, code_model=code_model)
616
+ self.name = yaml_data.get("name", "")
617
+
618
+ def docstring_type(self, **kwargs: Any) -> str:
619
+ return "~azure.core." + self.type_annotation(**kwargs)
620
+
621
+ def type_annotation(self, **kwargs: Any) -> str:
622
+ return self.name
623
+
624
+ def imports(self, **kwargs: Any) -> FileImport:
625
+ file_import = FileImport()
626
+ file_import.add_submodule_import("azure.core", self.name, ImportType.AZURECORE)
627
+ return file_import
628
+
629
+ @property
630
+ def instance_check_template(self) -> str:
631
+ return f"isinstance({{}}, {self.name})"
632
+
633
+ @property
634
+ def serialization_type(self) -> str:
635
+ return self.name
@@ -52,6 +52,7 @@ _PACKAGE_FILES = [
52
52
 
53
53
  _REGENERATE_FILES = {"setup.py", "MANIFEST.in"}
54
54
 
55
+
55
56
  # extract sub folders. For example, source_file_path is like:
56
57
  # "xxx/resource-manager/Microsoft.XX/stable/2023-04-01/examples/Compute/createOrUpdate/AKSCompute.json",
57
58
  # and we want to extract the sub folders after "examples/", which is "compute/create_or_update"
@@ -18,7 +18,6 @@ from ..models import (
18
18
  ModelType,
19
19
  DictionaryType,
20
20
  ListType,
21
- Parameter,
22
21
  RequestBuilder,
23
22
  ParameterLocation,
24
23
  Response,
@@ -422,36 +421,6 @@ class _BuilderBaseSerializer(Generic[BuilderType]): # pylint: disable=abstract-
422
421
  )
423
422
  return template
424
423
 
425
- def _serialize_parameter(self, param: Parameter, kwarg_name: str) -> List[str]:
426
- set_parameter = "_{}['{}'] = {}".format(
427
- kwarg_name,
428
- param.wire_name,
429
- self.parameter_serializer.serialize_parameter(param, self.serializer_name),
430
- )
431
- if not param.optional:
432
- retval = [set_parameter]
433
- else:
434
- retval = [
435
- f"if {param.full_client_name} is not None:",
436
- f" {set_parameter}",
437
- ]
438
- return retval
439
-
440
- @staticmethod
441
- def _serialize_special_handle_header(param: Parameter) -> List[str]:
442
- if param.wire_name.lower() == "repeatability-request-id":
443
- return [
444
- """if "Repeatability-Request-ID" not in _headers:""",
445
- """ _headers["Repeatability-Request-ID"] = str(uuid.uuid4())""",
446
- ]
447
- if param.wire_name.lower() == "repeatability-first-sent":
448
- return [
449
- """if "Repeatability-First-Sent" not in _headers:""",
450
- """ _headers["Repeatability-First-Sent"] = _SERIALIZER.serialize_data(datetime.datetime.now(),
451
- "rfc-1123")""",
452
- ]
453
- raise ValueError(f"Unsupported special header: {param}")
454
-
455
424
  def serialize_path(self, builder: BuilderType) -> List[str]:
456
425
  return self.parameter_serializer.serialize_path(
457
426
  builder.parameters.path, self.serializer_name
@@ -563,24 +532,25 @@ class RequestBuilderSerializer(
563
532
  def serialize_headers(self, builder: RequestBuilderType) -> List[str]:
564
533
  retval = ["# Construct headers"]
565
534
  for parameter in builder.parameters.headers:
566
- if parameter.is_special_handle_header:
567
- retval.extend(self._serialize_special_handle_header(parameter))
568
- else:
569
- retval.extend(
570
- self._serialize_parameter(
571
- parameter,
572
- kwarg_name="headers",
573
- )
535
+ retval.extend(
536
+ self.parameter_serializer.serialize_query_header(
537
+ parameter,
538
+ "headers",
539
+ self.serializer_name,
540
+ self.code_model.is_legacy,
574
541
  )
542
+ )
575
543
  return retval
576
544
 
577
545
  def serialize_query(self, builder: RequestBuilderType) -> List[str]:
578
546
  retval = ["# Construct parameters"]
579
547
  for parameter in builder.parameters.query:
580
548
  retval.extend(
581
- self._serialize_parameter(
549
+ self.parameter_serializer.serialize_query_header(
582
550
  parameter,
583
- kwarg_name="params",
551
+ "params",
552
+ self.serializer_name,
553
+ self.code_model.is_legacy,
584
554
  )
585
555
  )
586
556
  return retval
@@ -1119,9 +1089,17 @@ class _OperationSerializer(
1119
1089
  return retval
1120
1090
 
1121
1091
  def handle_error_response(self, builder: OperationType) -> List[str]:
1092
+ async_await = "await " if self.async_mode else ""
1122
1093
  retval = [
1123
1094
  f"if response.status_code not in {str(builder.success_status_codes)}:"
1124
1095
  ]
1096
+ if not self.code_model.need_request_converter:
1097
+ retval.extend(
1098
+ [
1099
+ " if _stream:",
1100
+ f" {async_await} response.read() # Load the body in memory and close the socket",
1101
+ ]
1102
+ )
1125
1103
  retval.append(
1126
1104
  " map_error(status_code=response.status_code, response=response, error_map=error_map)"
1127
1105
  )
@@ -1283,6 +1261,17 @@ class _OperationSerializer(
1283
1261
  "304: ResourceNotModifiedError"
1284
1262
  )
1285
1263
  retval.append("}")
1264
+ if builder.has_etag:
1265
+ retval.extend(
1266
+ [
1267
+ "if match_condition == MatchConditions.IfNotModified:",
1268
+ " error_map[412] = ResourceModifiedError",
1269
+ "elif match_condition == MatchConditions.IfPresent:",
1270
+ " error_map[412] = ResourceNotFoundError",
1271
+ "elif match_condition == MatchConditions.IfMissing:",
1272
+ " error_map[412] = ResourceExistsError",
1273
+ ]
1274
+ )
1286
1275
  retval.append("error_map.update(kwargs.pop('error_map', {}) or {})")
1287
1276
  return retval
1288
1277
 
@@ -109,9 +109,15 @@ class ClientSerializer:
109
109
 
110
110
  def initialize_pipeline_client(self, async_mode: bool) -> str:
111
111
  pipeline_client_name = self.client.pipeline_class(async_mode)
112
+ params = {
113
+ "base_url": self.host_variable_name,
114
+ "config": "self._config",
115
+ }
116
+ if not self.client.code_model.is_legacy and self.client.request_id_header_name:
117
+ params["request_id_header_name"] = f'"{self.client.request_id_header_name}"'
112
118
  return (
113
- f"self._client: {pipeline_client_name} = {pipeline_client_name}(base_url={self.host_variable_name}, "
114
- "config=self._config, **kwargs)"
119
+ f"self._client: {pipeline_client_name} = {pipeline_client_name}("
120
+ f"{', '.join(f'{k}={v}' for k, v in params.items())}, **kwargs)"
115
121
  )
116
122
 
117
123
  def serializers_and_operation_groups_properties(self) -> List[str]:
@@ -133,6 +133,13 @@ class GeneralSerializer:
133
133
  f"{client.name}Configuration",
134
134
  ImportType.LOCAL,
135
135
  )
136
+ if self.code_model.has_etag:
137
+ file_import.add_submodule_import("typing", "Optional", ImportType.STDLIB)
138
+ file_import.add_submodule_import(
139
+ "azure.core",
140
+ "MatchConditions",
141
+ ImportType.AZURECORE,
142
+ )
136
143
 
137
144
  return template.render(
138
145
  code_model=self.code_model,
@@ -255,13 +255,13 @@ class DpgModelSerializer(_ModelSerializer):
255
255
  args = []
256
256
  if prop.client_name != prop.wire_name or prop.is_discriminator:
257
257
  args.append(f'name="{prop.wire_name}"')
258
- if prop.readonly:
259
- args.append("readonly=True")
260
258
  if prop.visibility:
261
259
  v_list = ", ".join(f'"{x}"' for x in prop.visibility)
262
260
  args.append(f"visibility=[{v_list}]")
263
261
  if prop.client_default_value is not None:
264
262
  args.append(f"default={prop.client_default_value_declaration}")
263
+ if hasattr(prop.type, "format") and prop.type.format: # type: ignore
264
+ args.append(f'format="{prop.type.format}"') # type: ignore
265
265
 
266
266
  field = "rest_discriminator" if prop.is_discriminator else "rest_field"
267
267
  type_ignore = (
@@ -3,10 +3,9 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
- from typing import List, Sequence, Union, Optional
6
+ from typing import List, Sequence, Union, Optional, Dict
7
7
  from enum import Enum, auto
8
8
 
9
-
10
9
  from ..models import (
11
10
  Parameter,
12
11
  ParameterLocation,
@@ -26,6 +25,32 @@ class PopKwargType(Enum):
26
25
  CASE_INSENSITIVE = auto()
27
26
 
28
27
 
28
+ SPECIAL_HEADER_SERIALIZATION: Dict[str, List[str]] = {
29
+ "repeatability-request-id": [
30
+ """if "Repeatability-Request-ID" not in _headers:""",
31
+ """ _headers["Repeatability-Request-ID"] = str(uuid.uuid4())""",
32
+ ],
33
+ "repeatability-first-sent": [
34
+ """if "Repeatability-First-Sent" not in _headers:""",
35
+ """ _headers["Repeatability-First-Sent"] = _SERIALIZER.serialize_data(datetime.datetime.now(),
36
+ "rfc-1123")""",
37
+ ],
38
+ "client-request-id": [],
39
+ "x-ms-client-request-id": [],
40
+ "return-client-request-id": [],
41
+ "etag": [
42
+ """if_match = prep_if_match(etag, match_condition)""",
43
+ """if if_match is not None:""",
44
+ """ _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str")""",
45
+ ],
46
+ "match-condition": [
47
+ """if_none_match = prep_if_none_match(etag, match_condition)""",
48
+ """if if_none_match is not None:""",
49
+ """ _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str")""",
50
+ ],
51
+ }
52
+
53
+
29
54
  class ParameterSerializer:
30
55
  @staticmethod
31
56
  def serialize_parameter(parameter: ParameterType, serializer_name: str) -> str:
@@ -81,8 +106,8 @@ class ParameterSerializer:
81
106
  return f"[{serialize_line} if q is not None else '' for q in {origin_name}]"
82
107
  return serialize_line
83
108
 
109
+ @staticmethod
84
110
  def serialize_path(
85
- self,
86
111
  parameters: Union[
87
112
  List[Parameter],
88
113
  List[RequestBuilderParameter],
@@ -96,7 +121,9 @@ class ParameterSerializer:
96
121
  [
97
122
  ' "{}": {},'.format(
98
123
  path_parameter.wire_name,
99
- self.serialize_parameter(path_parameter, serializer_name),
124
+ ParameterSerializer.serialize_parameter(
125
+ path_parameter, serializer_name
126
+ ),
100
127
  )
101
128
  for path_parameter in parameters
102
129
  ]
@@ -104,6 +131,34 @@ class ParameterSerializer:
104
131
  retval.append("}")
105
132
  return retval
106
133
 
134
+ @staticmethod
135
+ def serialize_query_header(
136
+ param: Parameter,
137
+ kwarg_name: str,
138
+ serializer_name: str,
139
+ is_legacy: bool,
140
+ ) -> List[str]:
141
+ if (
142
+ not is_legacy
143
+ and param.location == ParameterLocation.HEADER
144
+ and param.wire_name.lower() in SPECIAL_HEADER_SERIALIZATION
145
+ ):
146
+ return SPECIAL_HEADER_SERIALIZATION[param.wire_name.lower()]
147
+
148
+ set_parameter = "_{}['{}'] = {}".format(
149
+ kwarg_name,
150
+ param.wire_name,
151
+ ParameterSerializer.serialize_parameter(param, serializer_name),
152
+ )
153
+ if not param.optional:
154
+ retval = [set_parameter]
155
+ else:
156
+ retval = [
157
+ f"if {param.full_client_name} is not None:",
158
+ f" {set_parameter}",
159
+ ]
160
+ return retval
161
+
107
162
  @staticmethod
108
163
  def pop_kwargs_from_signature(
109
164
  parameters: Sequence[_ParameterBase],