@autorest/python 6.2.11 → 6.2.15

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. package/autorest/__init__.py +7 -5
  2. package/autorest/_utils.py +7 -1
  3. package/autorest/black/__init__.py +6 -1
  4. package/autorest/codegen/__init__.py +1 -1
  5. package/autorest/codegen/models/__init__.py +8 -2
  6. package/autorest/codegen/models/base.py +4 -13
  7. package/autorest/codegen/models/client.py +9 -11
  8. package/autorest/codegen/models/code_model.py +3 -3
  9. package/autorest/codegen/models/combined_type.py +4 -3
  10. package/autorest/codegen/models/credential_types.py +7 -14
  11. package/autorest/codegen/models/dictionary_type.py +1 -1
  12. package/autorest/codegen/models/imports.py +3 -3
  13. package/autorest/codegen/models/lro_operation.py +5 -5
  14. package/autorest/codegen/models/model_type.py +89 -47
  15. package/autorest/codegen/models/operation.py +8 -8
  16. package/autorest/codegen/models/operation_group.py +3 -1
  17. package/autorest/codegen/models/paging_operation.py +2 -2
  18. package/autorest/codegen/models/parameter.py +27 -6
  19. package/autorest/codegen/models/parameter_list.py +1 -9
  20. package/autorest/codegen/models/primitive_types.py +1 -1
  21. package/autorest/codegen/models/property.py +15 -3
  22. package/autorest/codegen/models/response.py +2 -2
  23. package/autorest/codegen/serializers/__init__.py +2 -2
  24. package/autorest/codegen/serializers/builder_serializer.py +64 -28
  25. package/autorest/codegen/serializers/client_serializer.py +6 -4
  26. package/autorest/codegen/serializers/general_serializer.py +7 -2
  27. package/autorest/codegen/serializers/model_serializer.py +14 -4
  28. package/autorest/codegen/serializers/sample_serializer.py +2 -6
  29. package/autorest/codegen/templates/config.py.jinja2 +25 -6
  30. package/autorest/codegen/templates/enum.py.jinja2 +2 -2
  31. package/autorest/codegen/templates/metadata.json.jinja2 +18 -9
  32. package/autorest/codegen/templates/model_base.py.jinja2 +74 -63
  33. package/autorest/codegen/templates/model_container.py.jinja2 +2 -2
  34. package/autorest/codegen/templates/model_dpg.py.jinja2 +6 -4
  35. package/autorest/codegen/templates/model_msrest.py.jinja2 +2 -2
  36. package/autorest/codegen/templates/serialization.py.jinja2 +57 -29
  37. package/autorest/codegen/templates/vendor.py.jinja2 +3 -2
  38. package/autorest/jsonrpc/localapi.py +3 -3
  39. package/autorest/jsonrpc/server.py +3 -3
  40. package/autorest/m2r/__init__.py +1 -1
  41. package/autorest/m4reformatter/__init__.py +13 -2
  42. package/autorest/multiapi/models/__init__.py +2 -0
  43. package/autorest/multiapi/models/code_model.py +13 -0
  44. package/autorest/multiapi/models/global_parameter.py +1 -0
  45. package/autorest/multiapi/models/imports.py +3 -3
  46. package/autorest/multiapi/serializers/__init__.py +30 -3
  47. package/autorest/multiapi/templates/multiapi_service_client.py.jinja2 +8 -12
  48. package/autorest/postprocess/get_all.py +3 -1
  49. package/autorest/postprocess/venvtools.py +5 -4
  50. package/autorest/preprocess/__init__.py +19 -7
  51. package/autorest/preprocess/python_mappings.py +1 -0
  52. package/index.js +0 -0
  53. package/package.json +2 -1
  54. package/requirements.txt +6 -6
  55. package/setup.py +0 -1
  56. package/venvtools.py +2 -3
@@ -4,7 +4,7 @@
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
6
  import abc
7
- from enum import Enum, auto
7
+ from enum import Enum
8
8
 
9
9
  from typing import (
10
10
  Dict,
@@ -22,6 +22,8 @@ from .base import BaseModel
22
22
  from .base import BaseType
23
23
  from .constant_type import ConstantType
24
24
  from .utils import add_to_description
25
+ from .combined_type import CombinedType
26
+ from .model_type import JSONModelType
25
27
 
26
28
  if TYPE_CHECKING:
27
29
  from .code_model import CodeModel
@@ -37,10 +39,10 @@ class ParameterLocation(str, Enum):
37
39
  OTHER = "other"
38
40
 
39
41
 
40
- class ParameterMethodLocation(Enum):
41
- POSITIONAL = auto()
42
- KEYWORD_ONLY = auto()
43
- KWARG = auto()
42
+ class ParameterMethodLocation(str, Enum):
43
+ POSITIONAL = "positional"
44
+ KEYWORD_ONLY = "keywordOnly"
45
+ KWARG = "kwarg"
44
46
 
45
47
 
46
48
  class ParameterDelimeter(str, Enum):
@@ -241,6 +243,20 @@ class BodyParameter(_BodyParameterBase):
241
243
  def default_content_type(self) -> str:
242
244
  return self.yaml_data["defaultContentType"]
243
245
 
246
+ @staticmethod
247
+ def _has_json_model_type(t: BaseType) -> bool:
248
+ if isinstance(t, JSONModelType):
249
+ return True
250
+ if isinstance(t, CombinedType):
251
+ for sub_t in t.types:
252
+ if BodyParameter._has_json_model_type(sub_t):
253
+ return True
254
+ return False
255
+
256
+ @property
257
+ def has_json_model_type(self) -> bool:
258
+ return BodyParameter._has_json_model_type(self.type)
259
+
244
260
  @classmethod
245
261
  def from_yaml(
246
262
  cls, yaml_data: Dict[str, Any], code_model: "CodeModel"
@@ -314,6 +330,7 @@ class Parameter(_ParameterBase):
314
330
  self.in_overload: bool = self.yaml_data["inOverload"]
315
331
  self.in_overriden: bool = self.yaml_data.get("inOverriden", False)
316
332
  self.delimiter: Optional[ParameterDelimeter] = self.yaml_data.get("delimiter")
333
+ self.in_flattened_body: bool = self.yaml_data.get("inFlattenedBody", False)
317
334
 
318
335
  @property
319
336
  def in_method_signature(self) -> bool:
@@ -334,9 +351,13 @@ class Parameter(_ParameterBase):
334
351
  return bool(self.rest_api_name) and self.rest_api_name.lower() == "content-type"
335
352
 
336
353
  @property
337
- def method_location(self) -> ParameterMethodLocation:
354
+ def method_location( # pylint: disable=too-many-return-statements
355
+ self,
356
+ ) -> ParameterMethodLocation:
338
357
  if not self.in_method_signature:
339
358
  raise ValueError(f"Parameter '{self.client_name}' is not in the method.")
359
+ if self.code_model.options["models_mode"] == "dpg" and self.in_flattened_body:
360
+ return ParameterMethodLocation.KEYWORD_ONLY
340
361
  if self.grouper:
341
362
  return ParameterMethodLocation.POSITIONAL
342
363
  if self.constant:
@@ -115,7 +115,6 @@ class _ParameterListBase(
115
115
  @abstractmethod
116
116
  def parameter_creator() -> Callable[[Dict[str, Any], "CodeModel"], ParameterType]:
117
117
  """Callable for creating parameters"""
118
- ...
119
118
 
120
119
  @staticmethod
121
120
  @abstractmethod
@@ -123,7 +122,6 @@ class _ParameterListBase(
123
122
  [Dict[str, Any], "CodeModel"], BodyParameterType
124
123
  ]:
125
124
  """Callable for creating body parameters"""
126
- ...
127
125
 
128
126
  @property
129
127
  def grouped(self) -> List[Union[ParameterType, BodyParameterType]]:
@@ -208,7 +206,6 @@ class _ParameterListBase(
208
206
  @abstractmethod
209
207
  def implementation(self) -> str:
210
208
  """Whether this is a client or a method parameter"""
211
- ...
212
209
 
213
210
  @property
214
211
  def unsorted_method_params(self) -> List[Union[ParameterType, BodyParameterType]]:
@@ -339,8 +336,6 @@ class _ParameterList(
339
336
  class ParameterList(_ParameterList):
340
337
  """ParameterList is the parameter list for Operation classes"""
341
338
 
342
- ...
343
-
344
339
 
345
340
  class _RequestBuilderParameterList(
346
341
  _ParameterListBase[ # pylint: disable=unsubscriptable-object
@@ -406,8 +401,6 @@ class _RequestBuilderParameterList(
406
401
  class RequestBuilderParameterList(_RequestBuilderParameterList):
407
402
  """Parameter list for Request Builder"""
408
403
 
409
- ...
410
-
411
404
 
412
405
  class OverloadedRequestBuilderParameterList(_RequestBuilderParameterList):
413
406
  """Parameter list for OverloadedRequestBuilder"""
@@ -423,8 +416,7 @@ class OverloadedRequestBuilderParameterList(_RequestBuilderParameterList):
423
416
  ]
424
417
 
425
418
 
426
- class _ClientGlobalParameterList(
427
- # pylint: disable=unsubscriptable-object
419
+ class _ClientGlobalParameterList( # pylint: disable=abstract-method
428
420
  _ParameterListBase[ParameterType, BodyParameter]
429
421
  ):
430
422
  """Base parameter list for client and config classes"""
@@ -133,7 +133,7 @@ class BinaryIteratorType(PrimitiveType):
133
133
 
134
134
  @property
135
135
  def default_template_representation_declaration(self) -> str:
136
- return self.get_declaration(f"Iterator[bytes]")
136
+ return self.get_declaration("Iterator[bytes]")
137
137
 
138
138
  def imports(self, **kwargs: Any) -> FileImport:
139
139
  file_import = FileImport()
@@ -129,13 +129,25 @@ class Property(BaseModel): # pylint: disable=too-many-instance-attributes
129
129
  retval.update(self.type.validation or {})
130
130
  return retval or None
131
131
 
132
- def imports(self, **kwargs) -> FileImport:
133
- from .model_type import ModelType
132
+ @staticmethod
133
+ def contain_model_type(t: BaseType) -> bool:
134
+ from . import ListType, DictionaryType, ModelType
135
+
136
+ if isinstance(t, ModelType):
137
+ return True
138
+ if isinstance(t, ListType):
139
+ return Property.contain_model_type(t.element_type)
140
+ if isinstance(t, DictionaryType):
141
+ return Property.contain_model_type(t.element_type)
142
+ if isinstance(t, ConstantType):
143
+ return Property.contain_model_type(t.value_type)
144
+ return False
134
145
 
146
+ def imports(self, **kwargs) -> FileImport:
135
147
  file_import = self.type.imports(**kwargs, is_operation_file=False)
136
148
  if self.optional and self.client_default_value is None:
137
149
  file_import.add_submodule_import("typing", "Optional", ImportType.STDLIB)
138
- if isinstance(self.type, ModelType):
150
+ if self.contain_model_type(self.type):
139
151
  file_import.add_submodule_import(
140
152
  "..",
141
153
  "models",
@@ -49,12 +49,12 @@ class Response(BaseModel):
49
49
  yaml_data: Dict[str, Any],
50
50
  code_model: "CodeModel",
51
51
  *,
52
- headers: List[ResponseHeader] = [],
52
+ headers: Optional[List[ResponseHeader]] = None,
53
53
  type: Optional[BaseType] = None,
54
54
  ) -> None:
55
55
  super().__init__(yaml_data=yaml_data, code_model=code_model)
56
56
  self.status_codes: List[Union[int, str]] = yaml_data["statusCodes"]
57
- self.headers = headers
57
+ self.headers = headers or []
58
58
  self.type = type
59
59
  self.nullable = yaml_data.get("nullable")
60
60
 
@@ -462,7 +462,7 @@ class JinjaSerializer(ReaderAndWriter): # pylint: disable=abstract-method
462
462
  if self.code_model.need_vendored_code(async_mode=False):
463
463
  self.write_file(
464
464
  namespace_path / Path("_vendor.py"),
465
- general_serializer.serialize_vendor_file(),
465
+ general_serializer.serialize_vendor_file(clients),
466
466
  )
467
467
 
468
468
  self._serialize_and_write_version_file(namespace_path, general_serializer)
@@ -521,7 +521,7 @@ class JinjaSerializer(ReaderAndWriter): # pylint: disable=abstract-method
521
521
  if self.code_model.need_vendored_code(async_mode=True):
522
522
  self.write_file(
523
523
  aio_path / Path("_vendor.py"),
524
- aio_general_serializer.serialize_vendor_file(),
524
+ aio_general_serializer.serialize_vendor_file(clients),
525
525
  )
526
526
 
527
527
  def _serialize_and_write_metadata(
@@ -31,6 +31,8 @@ from ..models import (
31
31
  MultipartBodyParameter,
32
32
  Property,
33
33
  RequestBuilderType,
34
+ JSONModelType,
35
+ CombinedType,
34
36
  )
35
37
  from .parameter_serializer import ParameterSerializer, PopKwargType
36
38
  from . import utils
@@ -146,6 +148,25 @@ def _serialize_flattened_body(body_parameter: BodyParameter) -> List[str]:
146
148
  return retval
147
149
 
148
150
 
151
+ def _serialize_json_model_body(body_parameter: BodyParameter) -> List[str]:
152
+ retval: List[str] = []
153
+ if not body_parameter.property_to_parameter_name:
154
+ raise ValueError(
155
+ "This method can't be called if the operation doesn't need parameter flattening"
156
+ )
157
+
158
+ retval.append(f"if {body_parameter.client_name} is None:")
159
+ parameter_string = ", \n".join(
160
+ f'"{property_name}": {parameter_name}'
161
+ for property_name, parameter_name in body_parameter.property_to_parameter_name.items()
162
+ )
163
+ model_type = cast(ModelType, body_parameter.type)
164
+ if isinstance(model_type, CombinedType):
165
+ model_type = next(t for t in model_type.types if isinstance(t, JSONModelType))
166
+ retval.append(f" {body_parameter.client_name} = {{{parameter_string}}}")
167
+ return retval
168
+
169
+
149
170
  def _serialize_multipart_body(builder: BuilderType) -> List[str]:
150
171
  retval: List[str] = []
151
172
  body_param = cast(MultipartBodyParameter, builder.parameters.body_parameter)
@@ -184,7 +205,7 @@ def _api_version_validation(builder: OperationType) -> str:
184
205
  retval.append(f" params_added_on={dict(params_added_on)},")
185
206
  if retval:
186
207
  retval_str = "\n".join(retval)
187
- return f"@api_version_validation(\n{retval_str}\n)"
208
+ return f"@api_version_validation(\n{retval_str}\n){builder.pylint_disable}"
188
209
  return ""
189
210
 
190
211
 
@@ -203,23 +224,20 @@ class _BuilderBaseSerializer(Generic[BuilderType]): # pylint: disable=abstract-
203
224
  @abstractmethod
204
225
  def _function_def(self) -> str:
205
226
  """The def keyword for the builder we're serializing, i.e. 'def' or 'async def'"""
206
- ...
207
227
 
208
228
  @property
209
229
  @abstractmethod
210
230
  def _call_method(self) -> str:
211
231
  """How to call network calls. Await if we have to await network calls"""
212
- ...
213
232
 
214
233
  @property
215
234
  @abstractmethod
216
235
  def serializer_name(self) -> str:
217
- ...
236
+ """Name of serializer"""
218
237
 
219
238
  @abstractmethod
220
239
  def response_docstring(self, builder: BuilderType) -> List[str]:
221
240
  """Response portion of the docstring"""
222
- ...
223
241
 
224
242
  def decorators(self, builder: BuilderType) -> List[str]:
225
243
  """Decorators for the method"""
@@ -258,9 +276,7 @@ class _BuilderBaseSerializer(Generic[BuilderType]): # pylint: disable=abstract-
258
276
  )
259
277
  )
260
278
 
261
- def description_and_summary( # pylint: disable=no-self-use
262
- self, builder: BuilderType
263
- ) -> List[str]:
279
+ def description_and_summary(self, builder: BuilderType) -> List[str]:
264
280
  description_list: List[str] = []
265
281
  description_list.append(
266
282
  f"{ builder.summary.strip() if builder.summary else builder.description.strip() }"
@@ -280,9 +296,7 @@ class _BuilderBaseSerializer(Generic[BuilderType]): # pylint: disable=abstract-
280
296
  template += self._json_input_example_template(builder)
281
297
  return template
282
298
 
283
- def param_description( # pylint: disable=no-self-use
284
- self, builder: BuilderType
285
- ) -> List[str]:
299
+ def param_description(self, builder: BuilderType) -> List[str]:
286
300
  description_list: List[str] = []
287
301
  for param in builder.parameters.method:
288
302
  if not param.in_docstring:
@@ -314,9 +328,6 @@ class _BuilderBaseSerializer(Generic[BuilderType]): # pylint: disable=abstract-
314
328
 
315
329
  def _json_input_example_template(self, builder: BuilderType) -> List[str]:
316
330
  template: List[str] = []
317
- if self.code_model.options["models_mode"]:
318
- # No input template if we have models
319
- return template
320
331
  if (
321
332
  not builder.parameters.has_body
322
333
  or builder.parameters.body_parameter.flattened
@@ -331,6 +342,15 @@ class _BuilderBaseSerializer(Generic[BuilderType]): # pylint: disable=abstract-
331
342
  if not isinstance(body_param.type, (ListType, DictionaryType, ModelType)):
332
343
  return template
333
344
 
345
+ if (
346
+ isinstance(body_param.type, (ListType, DictionaryType))
347
+ and self.code_model.options["models_mode"]
348
+ ):
349
+ return template
350
+
351
+ if isinstance(body_param.type, ModelType) and body_param.type.base != "json":
352
+ return template
353
+
334
354
  polymorphic_subtypes: List[ModelType] = []
335
355
  body_param.type.get_polymorphic_subtypes(polymorphic_subtypes)
336
356
  if polymorphic_subtypes:
@@ -440,11 +460,11 @@ class RequestBuilderSerializer(
440
460
 
441
461
  def response_docstring(self, builder: RequestBuilderType) -> List[str]:
442
462
  response_str = (
443
- f":return: Returns an :class:`~azure.core.rest.HttpRequest` that you will pass to the client's "
463
+ ":return: Returns an :class:`~azure.core.rest.HttpRequest` that you will pass to the client's "
444
464
  + "`send_request` method. See https://aka.ms/azsdk/dpcodegen/python/send_request for how to "
445
465
  + "incorporate this response into your code flow."
446
466
  )
447
- rtype_str = f":rtype: ~azure.core.rest.HttpRequest"
467
+ rtype_str = ":rtype: ~azure.core.rest.HttpRequest"
448
468
  return [response_str, rtype_str]
449
469
 
450
470
  def pop_kwargs_from_signature(self, builder: RequestBuilderType) -> List[str]:
@@ -612,9 +632,7 @@ class _OperationSerializer(
612
632
  retval.append(_api_version_validation(builder))
613
633
  return retval
614
634
 
615
- def param_description(
616
- self, builder: OperationType
617
- ) -> List[str]: # pylint: disable=no-self-use
635
+ def param_description(self, builder: OperationType) -> List[str]:
618
636
  description_list = super().param_description(builder)
619
637
  if not self.code_model.options["version_tolerant"]:
620
638
  description_list.append(
@@ -683,7 +701,10 @@ class _OperationSerializer(
683
701
  f"'{body_param.type.serialization_type}'{is_xml_cmd}{serialization_ctxt_cmd})"
684
702
  )
685
703
  elif self.code_model.options["models_mode"] == "dpg":
686
- create_body_call = f"_{body_kwarg_name} = json.dumps({body_param.client_name}, cls=AzureJSONEncoder)"
704
+ create_body_call = (
705
+ f"_{body_kwarg_name} = json.dumps({body_param.client_name}, "
706
+ "cls=AzureJSONEncoder) # type: ignore"
707
+ )
687
708
  else:
688
709
  create_body_call = f"_{body_kwarg_name} = {body_param.client_name}"
689
710
  if body_param.optional:
@@ -885,7 +906,7 @@ class _OperationSerializer(
885
906
  retval.append(f" template_url={template_url},")
886
907
  retval.append(" headers=_headers,")
887
908
  retval.append(" params=_params,")
888
- retval.append(f")")
909
+ retval.append(")")
889
910
  return retval
890
911
 
891
912
  def _postprocess_http_request(
@@ -928,6 +949,12 @@ class _OperationSerializer(
928
949
  if builder.parameters.has_body and builder.parameters.body_parameter.flattened:
929
950
  # unflatten before passing to request builder as well
930
951
  retval.extend(_serialize_flattened_body(builder.parameters.body_parameter))
952
+ if (
953
+ builder.parameters.has_body
954
+ and builder.parameters.body_parameter.has_json_model_type
955
+ and any(p.in_flattened_body for p in builder.parameters.parameters)
956
+ ):
957
+ retval.extend(_serialize_json_model_body(builder.parameters.body_parameter))
931
958
  if builder.overloads:
932
959
  # we are only dealing with two overloads. If there are three, we generate an abstract operation
933
960
  retval.extend(self._initialize_overloads(builder, is_paging=is_paging))
@@ -987,10 +1014,10 @@ class _OperationSerializer(
987
1014
  if response.type.is_xml
988
1015
  else "response.json()"
989
1016
  )
990
- retval.append(f"if response.content:")
1017
+ retval.append("if response.content:")
991
1018
  retval.append(f" deserialized = {deserialized_value}")
992
1019
  retval.append("else:")
993
- retval.append(f" deserialized = None")
1020
+ retval.append(" deserialized = None")
994
1021
  return retval
995
1022
 
996
1023
  def handle_error_response(self, builder: OperationType) -> List[str]:
@@ -1298,11 +1325,20 @@ class _PagingOperationSerializer(
1298
1325
  deserialize_type = f'"{response.serialization_type}"'
1299
1326
  pylint_disable = ""
1300
1327
  deserialized = f"self._deserialize(\n {deserialize_type}, pipeline_response{pylint_disable}\n)"
1328
+ retval.append(f" deserialized = {deserialized}")
1301
1329
  elif self.code_model.options["models_mode"] == "dpg":
1302
- deserialized = (
1303
- f"_deserialize({response.serialization_type}, pipeline_response)"
1330
+ pylint_disable = (
1331
+ " # pylint: disable=protected-access\n"
1332
+ if isinstance(response.type, ModelType) and not response.type.is_public
1333
+ else ""
1334
+ )
1335
+ deserialized = f"_deserialize({response.serialization_type}{pylint_disable}, pipeline_response)"
1336
+ retval.append(
1337
+ f" deserialized: {response.serialization_type} = ({pylint_disable}"
1304
1338
  )
1305
- retval.append(f" deserialized = {deserialized}")
1339
+ retval.append(f" {deserialized})")
1340
+ else:
1341
+ retval.append(f" deserialized = {deserialized}")
1306
1342
  item_name = builder.item_name
1307
1343
  list_of_elem = (
1308
1344
  f".{item_name}"
@@ -1447,7 +1483,7 @@ class _LROOperationSerializer(_OperationSerializer[LROOperationType]):
1447
1483
  f" {lro_options_str}",
1448
1484
  f" {path_format_arguments_str}",
1449
1485
  " **kwargs",
1450
- f"))",
1486
+ "))",
1451
1487
  ]
1452
1488
  )
1453
1489
  retval.append(
@@ -1504,7 +1540,7 @@ class _LROOperationSerializer(_OperationSerializer[LROOperationType]):
1504
1540
  " # type: ignore"
1505
1541
  if builder.lro_response
1506
1542
  and builder.lro_response.type
1507
- and not self.code_model.options["models_mode"]
1543
+ and self.code_model.options["models_mode"] != "msrest"
1508
1544
  else "",
1509
1545
  )
1510
1546
  )
@@ -137,7 +137,9 @@ class ClientSerializer:
137
137
  f" = {_get_client_models_value(model_dict_name)}"
138
138
  )
139
139
  if add_private_models and self.client.code_model.model_types:
140
- update_dict = f"{{k: v for k, v in _models.__dict__.items() if isinstance(v, type)}}"
140
+ update_dict = (
141
+ "{k: v for k, v in _models.__dict__.items() if isinstance(v, type)}"
142
+ )
141
143
  retval.append(f"client_models.update({update_dict})")
142
144
  client_models_str = "client_models" if is_msrest_model else ""
143
145
  retval.append(f"self._serialize = Serializer({client_models_str})")
@@ -221,7 +223,7 @@ class ClientSerializer:
221
223
  def _rest_request_example(self, async_mode: bool) -> List[str]:
222
224
  retval = [">>> from azure.core.rest import HttpRequest"]
223
225
  retval.append('>>> request = HttpRequest("GET", "https://www.example.org/")')
224
- retval.append(f"<HttpRequest [GET], url: 'https://www.example.org/'>")
226
+ retval.append("<HttpRequest [GET], url: 'https://www.example.org/'>")
225
227
  retval.extend(self._example_make_call(async_mode))
226
228
  return retval
227
229
 
@@ -236,9 +238,9 @@ class ClientSerializer:
236
238
  retval.append(
237
239
  "For more information on this code flow, see https://aka.ms/azsdk/dpcodegen/python/send_request"
238
240
  )
239
- retval.append(f"")
241
+ retval.append("")
240
242
  retval.append(":param request: The network request you want to make. Required.")
241
- retval.append(f":type request: ~azure.core.rest.HttpRequest")
243
+ retval.append(":type request: ~azure.core.rest.HttpRequest")
242
244
  retval.append(
243
245
  ":keyword bool stream: Whether the response payload will be streamed. Defaults to False."
244
246
  )
@@ -98,7 +98,7 @@ class GeneralSerializer:
98
98
  imports=FileImportSerializer(imports),
99
99
  )
100
100
 
101
- def serialize_vendor_file(self) -> str:
101
+ def serialize_vendor_file(self, clients: List[Client]) -> str:
102
102
  template = self.env.get_template("vendor.py.jinja2")
103
103
 
104
104
  # configure imports
@@ -110,6 +110,10 @@ class GeneralSerializer:
110
110
  ImportType.AZURECORE,
111
111
  )
112
112
 
113
+ if self.code_model.need_format_url and not self.async_mode:
114
+ file_import.add_submodule_import("typing", "List", ImportType.STDLIB)
115
+ file_import.add_submodule_import("typing", "cast", ImportType.STDLIB)
116
+
113
117
  if self.code_model.need_mixin_abc:
114
118
  file_import.add_submodule_import(
115
119
  "abc",
@@ -128,7 +132,7 @@ class GeneralSerializer:
128
132
  MsrestImportType.SerializerDeserializer,
129
133
  TypingSection.TYPING,
130
134
  )
131
- for client in self.code_model.clients:
135
+ for client in clients:
132
136
  file_import.add_submodule_import(
133
137
  "._configuration",
134
138
  f"{client.name}Configuration",
@@ -141,6 +145,7 @@ class GeneralSerializer:
141
145
  file_import,
142
146
  ),
143
147
  async_mode=self.async_mode,
148
+ clients=clients,
144
149
  )
145
150
 
146
151
  def serialize_config_file(self, clients: List[Client]) -> str:
@@ -208,14 +208,19 @@ class DpgModelSerializer(_ModelSerializer):
208
208
  ImportType.LOCAL,
209
209
  TypingSection.REGULAR,
210
210
  )
211
- file_import.add_submodule_import("typing", "overload", ImportType.STDLIB)
212
- file_import.add_submodule_import("typing", "Mapping", ImportType.STDLIB)
213
- file_import.add_submodule_import("typing", "Any", ImportType.STDLIB)
214
211
 
215
212
  for model in self.code_model.model_types:
216
213
  file_import.merge(model.imports(is_operation_file=False))
217
214
  for prop in model.properties:
218
215
  file_import.merge(prop.imports())
216
+ if model.is_polymorphic:
217
+ file_import.add_submodule_import("typing", "Dict", ImportType.STDLIB)
218
+ if model.is_public and self.init_line(model):
219
+ file_import.add_submodule_import(
220
+ "typing", "overload", ImportType.STDLIB
221
+ )
222
+ file_import.add_submodule_import("typing", "Mapping", ImportType.STDLIB)
223
+ file_import.add_submodule_import("typing", "Any", ImportType.STDLIB)
219
224
  return file_import
220
225
 
221
226
  def declare_model(self, model: ModelType) -> str:
@@ -256,9 +261,14 @@ class DpgModelSerializer(_ModelSerializer):
256
261
  args.append(f"default={prop.client_default_value_declaration}")
257
262
 
258
263
  field = "rest_discriminator" if prop.is_discriminator else "rest_field"
264
+ type_ignore = (
265
+ prop.is_discriminator
266
+ and prop.is_discriminator
267
+ and cast(ConstantType, prop.type).value
268
+ )
259
269
  ret = [
260
270
  f"{prop.client_name}: {prop.type_annotation()} ="
261
- f' {field}({", ".join(args)})'
271
+ f' {field}({", ".join(args)}){" # type: ignore" if type_ignore else ""}'
262
272
  ]
263
273
  comment = prop.description(is_operation_file=False).replace('"', '\\"')
264
274
  if comment:
@@ -65,7 +65,7 @@ class SampleSerializer:
65
65
 
66
66
  def _client_params(self) -> Dict[str, Any]:
67
67
  # client params
68
- special_param = dict()
68
+ special_param = {}
69
69
  credential_type = getattr(self.code_model.clients[0].credential, "type", None)
70
70
  if isinstance(credential_type, TokenCredentialType):
71
71
  special_param.update({"credential": "DefaultAzureCredential()"})
@@ -79,14 +79,10 @@ class SampleSerializer:
79
79
  for p in self.code_model.clients[0].parameters.positional
80
80
  if not (p.optional or p.client_default_value)
81
81
  ]
82
- cls = lambda x: f'"{x}"'
83
82
  client_params = {
84
83
  p.client_name: special_param.get(
85
84
  p.client_name,
86
- cls(
87
- self.sample["parameters"].get(p.rest_api_name)
88
- or p.client_name.upper()
89
- ),
85
+ f'"{self.sample["parameters"].get(p.rest_api_name) or p.client_name.upper()}"',
90
86
  )
91
87
  for p in params_positional
92
88
  }
@@ -22,12 +22,27 @@ class {{ client.name }}Configuration(Configuration): # pylint: disable=too-many
22
22
  {% if serializer.set_constants() %}
23
23
  {{ op_tools.serialize(serializer.set_constants()) | indent(8) -}}
24
24
  {% endif %}
25
- {% if client.credential and client.credential.type.policy.credential_scopes is defined %}
26
- self.credential_scopes = kwargs.pop('credential_scopes', {{ client.credential.type.policy.credential_scopes }})
25
+ {% if client.credential %}
26
+ {% set cred_scopes = client.credential.type if client.credential.type.policy is defined and client.credential.type.policy.credential_scopes is defined %}
27
+ {% if not cred_scopes %}
28
+ {% set cred_scopes = client.credential.type.types | selectattr("policy.credential_scopes") | first if client.credential.type.types is defined %}
29
+ {% endif %}
30
+ {% if cred_scopes %}
31
+ self.credential_scopes = kwargs.pop('credential_scopes', {{ cred_scopes.policy.credential_scopes }})
32
+ {% endif %}
27
33
  {% endif %}
28
34
  kwargs.setdefault('sdk_moniker', '{{ client.config.sdk_moniker }}/{}'.format(VERSION))
29
35
  self._configure(**kwargs)
30
36
 
37
+ {% if client.credential and client.credential.type.types is defined %}
38
+ def _infer_policy(self, **kwargs):
39
+ {% for cred_type in client.credential.type.types %}
40
+ if {{ cred_type.instance_check_template.format("self.credential") }}:
41
+ return {{ cred_type.policy.call(async_mode) }}
42
+ {% endfor %}
43
+ raise TypeError(f"Unsupported credential: {self.credential}")
44
+ {% endif %}
45
+
31
46
  def _configure(
32
47
  self,
33
48
  **kwargs: Any
@@ -41,12 +56,16 @@ class {{ client.name }}Configuration(Configuration): # pylint: disable=too-many
41
56
  self.custom_hook_policy = kwargs.get('custom_hook_policy') or policies.CustomHookPolicy(**kwargs)
42
57
  self.redirect_policy = kwargs.get('redirect_policy') or policies.{{ keywords.async_class }}RedirectPolicy(**kwargs)
43
58
  self.authentication_policy = kwargs.get('authentication_policy')
44
- {% if client.credential %}
45
- {# only adding this if credential_scopes is not passed during code generation #}
46
- {% if client.credential.type.policy.credential_scopes is defined and client.credential.type.policy.credential_scopes | length == 0 %}
59
+ {% if client.credential and client.credential.type.policy is defined %}
60
+ {# only adding this if credential_scopes is not passed during code generation #}
61
+ {% if client.credential.type.policy.credential_scopes is defined and client.credential.type.policy.credential_scopes | length == 0 %}
47
62
  if not self.credential_scopes and not self.authentication_policy:
48
63
  raise ValueError("You must provide either credential_scopes or authentication_policy as kwargs")
49
- {% endif %}
64
+ {% endif %}
50
65
  if self.credential and not self.authentication_policy:
51
66
  self.authentication_policy = {{ client.credential.type.policy.call(async_mode) }}
52
67
  {% endif %}
68
+ {% if client.credential and client.credential.type.types is defined %}
69
+ if self.credential and not self.authentication_policy:
70
+ self.authentication_policy = self._infer_policy(**kwargs)
71
+ {% endif %}
@@ -6,8 +6,8 @@ class {{ enum.name }}({{ enum.value_type.type_annotation(is_operation_file=False
6
6
  {% endif %}
7
7
 
8
8
  {% for value in enum.values %}
9
+ {{ value.name }} = {{ enum.value_type.get_declaration(value.value) }}
9
10
  {% if value.description %}
10
- #: {{ value.description | wordwrap(width=95, break_long_words=False, break_on_hyphens=False, wrapstring='\n #: ') }}
11
+ """{{ value.description | wordwrap(width=95, break_long_words=False, break_on_hyphens=False, wrapstring='\n #: ') }}"""
11
12
  {% endif %}
12
- {{ value.name }} = {{ enum.value_type.get_declaration(value.value) }}
13
13
  {% endfor %}