@autorest/python 6.2.4 → 6.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. package/autorest/black/__init__.py +3 -2
  2. package/autorest/codegen/__init__.py +1 -1
  3. package/autorest/codegen/models/base.py +14 -2
  4. package/autorest/codegen/models/client.py +20 -22
  5. package/autorest/codegen/models/combined_type.py +1 -1
  6. package/autorest/codegen/models/imports.py +2 -2
  7. package/autorest/codegen/models/model_type.py +4 -0
  8. package/autorest/codegen/models/operation.py +6 -6
  9. package/autorest/codegen/models/operation_group.py +12 -11
  10. package/autorest/codegen/models/primitive_types.py +50 -0
  11. package/autorest/codegen/models/property.py +4 -0
  12. package/autorest/codegen/models/request_builder.py +9 -7
  13. package/autorest/codegen/serializers/__init__.py +7 -6
  14. package/autorest/codegen/serializers/builder_serializer.py +67 -30
  15. package/autorest/codegen/serializers/client_serializer.py +22 -8
  16. package/autorest/codegen/serializers/metadata_serializer.py +7 -1
  17. package/autorest/codegen/serializers/model_serializer.py +12 -13
  18. package/autorest/codegen/serializers/operation_groups_serializer.py +1 -0
  19. package/autorest/codegen/serializers/parameter_serializer.py +3 -3
  20. package/autorest/codegen/serializers/patch_serializer.py +2 -4
  21. package/autorest/codegen/serializers/sample_serializer.py +23 -14
  22. package/autorest/codegen/serializers/utils.py +6 -0
  23. package/autorest/codegen/templates/client.py.jinja2 +3 -12
  24. package/autorest/codegen/templates/config.py.jinja2 +2 -5
  25. package/autorest/codegen/templates/keywords.jinja2 +2 -2
  26. package/autorest/codegen/templates/metadata.json.jinja2 +2 -2
  27. package/autorest/codegen/templates/model_base.py.jinja2 +171 -130
  28. package/autorest/codegen/templates/model_dpg.py.jinja2 +1 -1
  29. package/autorest/codegen/templates/packaging_templates/setup.py.jinja2 +1 -0
  30. package/autorest/codegen/templates/request_builder.py.jinja2 +1 -1
  31. package/autorest/codegen/templates/serialization.py.jinja2 +286 -325
  32. package/autorest/jsonrpc/__init__.py +3 -1
  33. package/autorest/jsonrpc/localapi.py +3 -1
  34. package/autorest/jsonrpc/stdstream.py +1 -1
  35. package/autorest/m2r/__init__.py +2 -2
  36. package/autorest/multiapi/models/imports.py +34 -22
  37. package/autorest/multiapi/serializers/import_serializer.py +1 -1
  38. package/autorest/multiapi/templates/multiapi_config.py.jinja2 +2 -8
  39. package/autorest/multiapi/templates/multiapi_service_client.py.jinja2 +1 -1
  40. package/autorest/postprocess/__init__.py +5 -4
  41. package/autorest/preprocess/__init__.py +7 -1
  42. package/autorest/preprocess/helpers.py +14 -2
  43. package/autorest/preprocess/python_mappings.py +27 -0
  44. package/package.json +2 -2
  45. package/setup.py +3 -0
@@ -9,7 +9,6 @@ from abc import abstractmethod
9
9
  from collections import defaultdict
10
10
  from typing import Any, Generic, List, Type, TypeVar, Dict, Union, Optional, cast
11
11
 
12
-
13
12
  from ..models import (
14
13
  Operation,
15
14
  PagingOperation,
@@ -584,8 +583,10 @@ class _OperationSerializer(
584
583
  return retval
585
584
 
586
585
  def make_pipeline_call(self, builder: OperationType) -> List[str]:
586
+ type_ignore = self.async_mode and builder.group_name == "" # is in a mixin
587
587
  return [
588
- f"pipeline_response = {self._call_method}self._client._pipeline.run( # type: ignore # pylint: disable=protected-access",
588
+ f"pipeline_response: PipelineResponse = {self._call_method}self._client._pipeline.run( "
589
+ + f"{'# type: ignore' if type_ignore else ''} # pylint: disable=protected-access",
589
590
  " request,",
590
591
  f" stream={builder.has_stream_response},",
591
592
  " **kwargs",
@@ -628,24 +629,22 @@ class _OperationSerializer(
628
629
  check_kwarg_dict=True,
629
630
  pop_headers_kwarg=PopKwargType.CASE_INSENSITIVE
630
631
  if builder.has_kwargs_to_pop_with_default(
631
- kwargs_to_pop, ParameterLocation.HEADER
632
+ kwargs_to_pop, ParameterLocation.HEADER # type: ignore
632
633
  )
633
634
  else PopKwargType.SIMPLE,
634
635
  pop_params_kwarg=PopKwargType.CASE_INSENSITIVE
635
636
  if builder.has_kwargs_to_pop_with_default(
636
- kwargs_to_pop, ParameterLocation.QUERY
637
+ kwargs_to_pop, ParameterLocation.QUERY # type: ignore
637
638
  )
638
639
  else PopKwargType.SIMPLE,
639
640
  check_client_input=not self.code_model.options["multiapi"],
640
641
  )
641
- kwargs.append(
642
- f"cls = kwargs.pop('cls', None) {self.cls_type_annotation(builder)}"
643
- )
642
+ cls_annotation = builder.cls_type_annotation(async_mode=self.async_mode)
643
+ kwargs.append(f"cls: {cls_annotation} = kwargs.pop('cls', None)")
644
+ if any(x.startswith("_") for x in cls_annotation.split(".")):
645
+ kwargs[-1] += " # pylint: disable=protected-access"
644
646
  return kwargs
645
647
 
646
- def cls_type_annotation(self, builder: OperationType) -> str:
647
- return f"# type: {builder.cls_type_annotation(async_mode=self.async_mode)}"
648
-
649
648
  def response_docstring(self, builder: OperationType) -> List[str]:
650
649
  response_str = (
651
650
  f":return: {builder.response_docstring_text(async_mode=self.async_mode)}"
@@ -848,8 +847,18 @@ class _OperationSerializer(
848
847
  # in paging operations with a single swagger operation definition,
849
848
  # we skip passing query params when building the next request
850
849
  continue
850
+ type_ignore = (
851
+ parameter.grouped_by
852
+ and parameter.client_default_value is not None
853
+ and next(
854
+ p
855
+ for p in builder.parameters
856
+ if p.grouper and p.client_name == parameter.grouped_by
857
+ ).optional
858
+ )
851
859
  retval.append(
852
860
  f" {parameter.client_name}={parameter.name_in_high_level_operation},"
861
+ f"{' # type: ignore' if type_ignore else ''}"
853
862
  )
854
863
  if request_builder.overloads:
855
864
  seen_body_params = set()
@@ -897,7 +906,7 @@ class _OperationSerializer(
897
906
  if self.code_model.options["version_tolerant"] and template_url:
898
907
  url_to_format = template_url
899
908
  retval.append(
900
- "request.url = self._client.format_url({}{}) # type: ignore".format(
909
+ "request.url = self._client.format_url({}{})".format(
901
910
  url_to_format,
902
911
  ", **path_format_arguments" if builder.parameters.path else "",
903
912
  )
@@ -967,11 +976,10 @@ class _OperationSerializer(
967
976
  retval.append(
968
977
  f"deserialized = self._deserialize('{response.serialization_type}', pipeline_response)"
969
978
  )
970
- elif self.code_model.options["models_mode"] == "dpg" and isinstance(
971
- response.type, ModelType
972
- ):
979
+ elif self.code_model.options["models_mode"] == "dpg":
973
980
  retval.append(
974
- f"deserialized = _deserialize({response.serialization_type}, response.json())"
981
+ f"deserialized = _deserialize({response.type.type_annotation(is_operation_file=True)}"
982
+ ", response.json())"
975
983
  )
976
984
  else:
977
985
  deserialized_value = (
@@ -1046,20 +1054,41 @@ class _OperationSerializer(
1046
1054
  self.response_headers_and_deserialization(builder.responses[0])
1047
1055
  )
1048
1056
  retval.append("")
1057
+ type_ignore = (
1058
+ builder.has_response_body
1059
+ and not builder.has_optional_return_type
1060
+ and not (
1061
+ self.code_model.options["models_mode"] == "msrest"
1062
+ and any(not resp.is_stream_response for resp in builder.responses)
1063
+ )
1064
+ )
1049
1065
  if builder.has_optional_return_type or self.code_model.options["models_mode"]:
1050
1066
  deserialized = "deserialized"
1051
1067
  else:
1052
1068
  deserialized = f"cast({builder.response_type_annotation(async_mode=self.async_mode)}, deserialized)"
1069
+ type_ignore = False
1070
+ if (
1071
+ not builder.has_optional_return_type
1072
+ and len(builder.responses) > 1
1073
+ and any(resp.is_stream_response or resp.type for resp in builder.responses)
1074
+ ):
1075
+ type_ignore = True
1053
1076
  retval.append("if cls:")
1054
1077
  retval.append(
1055
- " return cls(pipeline_response, {}, {})".format(
1078
+ " return cls(pipeline_response, {}, {}){}".format(
1056
1079
  deserialized if builder.has_response_body else "None",
1057
1080
  "response_headers" if builder.any_response_has_headers else "{}",
1081
+ " # type: ignore" if type_ignore else "",
1058
1082
  )
1059
1083
  )
1060
- if builder.has_response_body:
1084
+ if builder.has_response_body and any(
1085
+ response.is_stream_response or response.type
1086
+ for response in builder.responses
1087
+ ):
1061
1088
  retval.append("")
1062
- retval.append(f"return {deserialized}")
1089
+ retval.append(
1090
+ f"return {deserialized}{' # type: ignore' if type_ignore else ''}"
1091
+ )
1063
1092
  if (
1064
1093
  builder.request_builder.method == "HEAD"
1065
1094
  and self.code_model.options["head_as_boolean"]
@@ -1133,7 +1162,7 @@ class _OperationSerializer(
1133
1162
  @staticmethod
1134
1163
  def get_metadata_url(builder: OperationType) -> str:
1135
1164
  url = _escape_str(builder.request_builder.url)
1136
- return f"{builder.name}.metadata = {{'url': { url }}} # type: ignore"
1165
+ return f"{builder.name}.metadata = {{'url': { url }}}"
1137
1166
 
1138
1167
  @property
1139
1168
  def _call_method(self) -> str:
@@ -1282,7 +1311,7 @@ class _PagingOperationSerializer(
1282
1311
  )
1283
1312
  retval.append(f" list_of_elem = deserialized{list_of_elem}")
1284
1313
  retval.append(" if cls:")
1285
- retval.append(" list_of_elem = cls(list_of_elem)")
1314
+ retval.append(" list_of_elem = cls(list_of_elem) # type: ignore")
1286
1315
 
1287
1316
  continuation_token_name = builder.continuation_token_name
1288
1317
  if not continuation_token_name:
@@ -1290,7 +1319,9 @@ class _PagingOperationSerializer(
1290
1319
  elif self.code_model.options["models_mode"]:
1291
1320
  cont_token_property = f"deserialized.{continuation_token_name} or None"
1292
1321
  else:
1293
- cont_token_property = f'deserialized.get("{continuation_token_name}", None)'
1322
+ cont_token_property = (
1323
+ f'deserialized.get("{continuation_token_name}") or None'
1324
+ )
1294
1325
  list_type = "AsyncList" if self.async_mode else "iter"
1295
1326
  retval.append(f" return {cont_token_property}, {list_type}(list_of_elem)")
1296
1327
  return retval
@@ -1364,19 +1395,19 @@ class _LROOperationSerializer(_OperationSerializer[LROOperationType]):
1364
1395
 
1365
1396
  def initial_call(self, builder: LROOperationType) -> List[str]:
1366
1397
  retval = [
1367
- "polling = kwargs.pop('polling', True) # type: Union[bool, "
1368
- f"{builder.get_base_polling_method(self.async_mode)}]"
1398
+ f"polling: Union[bool, {builder.get_base_polling_method(self.async_mode)}] = kwargs.pop('polling', True)",
1369
1399
  ]
1370
1400
  retval.append("lro_delay = kwargs.pop(")
1371
1401
  retval.append(" 'polling_interval',")
1372
1402
  retval.append(" self._config.polling_interval")
1373
1403
  retval.append(")")
1374
1404
  retval.append(
1375
- "cont_token = kwargs.pop('continuation_token', None) # type: Optional[str]"
1405
+ "cont_token: Optional[str] = kwargs.pop('continuation_token', None)"
1376
1406
  )
1377
1407
  retval.append("if cont_token is None:")
1378
1408
  retval.append(
1379
- f" raw_result = {self._call_method}self.{builder.initial_operation.name}( # type: ignore"
1409
+ f" raw_result = {self._call_method}self.{builder.initial_operation.name}("
1410
+ f"{'' if builder.lro_response and builder.lro_response.type else ' # type: ignore'}"
1380
1411
  )
1381
1412
  retval.extend(
1382
1413
  [
@@ -1409,13 +1440,14 @@ class _LROOperationSerializer(_OperationSerializer[LROOperationType]):
1409
1440
  retval.extend(
1410
1441
  [
1411
1442
  "if polling is True:",
1412
- f" polling_method = cast({builder.get_base_polling_method(self.async_mode)}, "
1443
+ f" polling_method: {builder.get_base_polling_method(self.async_mode)} "
1444
+ + f"= cast({builder.get_base_polling_method(self.async_mode)}, "
1413
1445
  f"{builder.get_polling_method(self.async_mode)}(",
1414
1446
  " lro_delay,",
1415
1447
  f" {lro_options_str}",
1416
1448
  f" {path_format_arguments_str}",
1417
1449
  " **kwargs",
1418
- f")) # type: {builder.get_base_polling_method(self.async_mode)}",
1450
+ f"))",
1419
1451
  ]
1420
1452
  )
1421
1453
  retval.append(
@@ -1434,7 +1466,7 @@ class _LROOperationSerializer(_OperationSerializer[LROOperationType]):
1434
1466
  retval.append(" )")
1435
1467
  retval.append(
1436
1468
  f"return {builder.get_poller(self.async_mode)}"
1437
- "(self._client, raw_result, get_long_running_output, polling_method)"
1469
+ "(self._client, raw_result, get_long_running_output, polling_method) # type: ignore"
1438
1470
  )
1439
1471
  return retval
1440
1472
 
@@ -1462,13 +1494,18 @@ class _LROOperationSerializer(_OperationSerializer[LROOperationType]):
1462
1494
  )
1463
1495
  retval.append(" if cls:")
1464
1496
  retval.append(
1465
- " return cls(pipeline_response, {}, {})".format(
1497
+ " return cls(pipeline_response, {}, {}){}".format(
1466
1498
  "deserialized"
1467
1499
  if builder.lro_response and builder.lro_response.type
1468
1500
  else "None",
1469
1501
  "response_headers"
1470
1502
  if builder.lro_response and builder.lro_response.headers
1471
1503
  else "{}",
1504
+ " # type: ignore"
1505
+ if builder.lro_response
1506
+ and builder.lro_response.type
1507
+ and not self.code_model.options["models_mode"]
1508
+ else "",
1472
1509
  )
1473
1510
  )
1474
1511
  if builder.lro_response and builder.lro_response.type:
@@ -1507,7 +1544,7 @@ class LROPagingOperationSerializer(
1507
1544
  retval.append(" )")
1508
1545
  return retval
1509
1546
 
1510
- def decorators(self, builder: LROPagingOperation) -> List[str]: # type: ignore
1547
+ def decorators(self, builder: LROPagingOperation) -> List[str]:
1511
1548
  """Decorators for the method"""
1512
1549
  return _LROOperationSerializer.decorators(self, builder) # type: ignore
1513
1550
 
@@ -116,15 +116,29 @@ class ClientSerializer:
116
116
 
117
117
  def serializers_and_operation_groups_properties(self) -> List[str]:
118
118
  retval = []
119
- if self.client.code_model.model_types:
120
- client_models_value = (
121
- "{k: v for k, v in models.__dict__.items() if isinstance(v, type)}"
122
- )
123
- else:
124
- client_models_value = "{} # type: Dict[str, Any]"
119
+
120
+ def _get_client_models_value(models_dict_name: str) -> str:
121
+ if self.client.code_model.model_types:
122
+ return f"{{k: v for k, v in {models_dict_name}.__dict__.items() if isinstance(v, type)}}"
123
+ return "{}"
124
+
125
125
  is_msrest_model = self.client.code_model.options["models_mode"] == "msrest"
126
126
  if is_msrest_model:
127
- retval.append(f"client_models = {client_models_value}")
127
+ add_private_models = len(self.client.code_model.model_types) != len(
128
+ self.client.code_model.public_model_types
129
+ )
130
+ model_dict_name = (
131
+ f"_models.{self.client.code_model.models_filename}"
132
+ if add_private_models
133
+ else "_models"
134
+ )
135
+ retval.append(
136
+ f"client_models{': Dict[str, Any]' if not self.client.code_model.model_types else ''}"
137
+ f" = {_get_client_models_value(model_dict_name)}"
138
+ )
139
+ if add_private_models and self.client.code_model.model_types:
140
+ update_dict = f"{{k: v for k, v in _models.__dict__.items() if isinstance(v, type)}}"
141
+ retval.append(f"client_models.update({update_dict})")
128
142
  client_models_str = "client_models" if is_msrest_model else ""
129
143
  retval.append(f"self._serialize = Serializer({client_models_str})")
130
144
  retval.append(f"self._deserialize = Deserializer({client_models_str})")
@@ -136,7 +150,7 @@ class ClientSerializer:
136
150
  for og in operation_groups:
137
151
  retval.extend(
138
152
  [
139
- f"self.{og.property_name} = {og.class_name}({og.mypy_ignore}{og.pylint_disable}",
153
+ f"self.{og.property_name} = {og.class_name}({og.pylint_disable}",
140
154
  " self._client, self._config, self._serialize, self._deserialize",
141
155
  ")",
142
156
  ]
@@ -58,7 +58,13 @@ def _json_serialize_imports(
58
58
  name_import_ordered_list = []
59
59
  if name_imports:
60
60
  name_import_ordered_list = list(name_imports)
61
- name_import_ordered_list.sort()
61
+ name_import_ordered_list.sort(
62
+ key=lambda e: "".join(e) # type: ignore
63
+ if isinstance(e, (list, tuple))
64
+ else e
65
+ if isinstance(e, str)
66
+ else ""
67
+ )
62
68
  json_package_name_dictionary[package_name] = name_import_ordered_list
63
69
  json_import_type_dictionary[import_type_key] = json_package_name_dictionary
64
70
  json_serialize_imports[typing_section_key] = json_import_type_dictionary
@@ -3,7 +3,7 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
- from typing import cast, List
6
+ from typing import List, cast
7
7
  from abc import ABC, abstractmethod
8
8
 
9
9
  from jinja2 import Environment
@@ -77,7 +77,7 @@ class _ModelSerializer(ABC):
77
77
  typing = "Optional[str]"
78
78
  else:
79
79
  typing = "str"
80
- return f"self.{prop.client_name} = {discriminator_value} # type: {typing}"
80
+ return f"self.{prop.client_name}: {typing} = {discriminator_value}"
81
81
 
82
82
  @staticmethod
83
83
  def initialize_standard_property(prop: Property):
@@ -152,7 +152,7 @@ class MsrestModelSerializer(_ModelSerializer):
152
152
  else "_serialization.Model"
153
153
  )
154
154
  if model.parents:
155
- basename = ", ".join([cast(ModelType, m).name for m in model.parents])
155
+ basename = ", ".join([m.name for m in model.parents])
156
156
  return f"class {model.name}({basename}):{model.pylint_disable}"
157
157
 
158
158
  @staticmethod
@@ -163,9 +163,7 @@ class MsrestModelSerializer(_ModelSerializer):
163
163
  p.client_name: p
164
164
  for bm in model.parents
165
165
  for p in model.properties
166
- if p not in cast(ModelType, bm).properties
167
- or p.is_discriminator
168
- or p.constant
166
+ if p not in bm.properties or p.is_discriminator or p.constant
169
167
  }.values()
170
168
  )
171
169
  else:
@@ -195,7 +193,10 @@ class MsrestModelSerializer(_ModelSerializer):
195
193
  xml_metadata = f", 'xml': {{{prop.type.xml_serialization_ctxt}}}"
196
194
  else:
197
195
  xml_metadata = ""
198
- return f'"{prop.client_name}": {{"key": "{attribute_key}", "type": "{prop.serialization_type}"{xml_metadata}}},'
196
+ return (
197
+ f'"{prop.client_name}": {{"key": "{attribute_key}",'
198
+ f' "type": "{prop.msrest_deserialization_key}"{xml_metadata}}},'
199
+ )
199
200
 
200
201
 
201
202
  class DpgModelSerializer(_ModelSerializer):
@@ -220,7 +221,7 @@ class DpgModelSerializer(_ModelSerializer):
220
221
  def declare_model(self, model: ModelType) -> str:
221
222
  basename = "_model_base.Model"
222
223
  if model.parents:
223
- basename = ", ".join([cast(ModelType, m).name for m in model.parents])
224
+ basename = ", ".join([m.name for m in model.parents])
224
225
  if model.discriminator_value:
225
226
  basename += f", discriminator='{model.discriminator_value}'"
226
227
  return f"class {model.name}({basename}):{model.pylint_disable}"
@@ -228,9 +229,7 @@ class DpgModelSerializer(_ModelSerializer):
228
229
  @staticmethod
229
230
  def get_properties_to_declare(model: ModelType) -> List[Property]:
230
231
  if model.parents:
231
- parent_properties = [
232
- p for bm in model.parents for p in cast(ModelType, bm).properties
233
- ]
232
+ parent_properties = [p for bm in model.parents for p in bm.properties]
234
233
  properties_to_declare = [
235
234
  p
236
235
  for p in model.properties
@@ -271,7 +270,7 @@ class DpgModelSerializer(_ModelSerializer):
271
270
  for prop in self.get_properties_to_declare(model):
272
271
  if prop.constant or prop.is_discriminator:
273
272
  init_args.append(
274
- f"self.{prop.client_name} = {cast(ConstantType, prop.type).get_declaration()} "
275
- f"# type: {prop.type_annotation()}"
273
+ f"self.{prop.client_name}: {prop.type_annotation()} = "
274
+ f"{cast(ConstantType, prop.type).get_declaration()}"
276
275
  )
277
276
  return init_args
@@ -64,6 +64,7 @@ class OperationGroupsSerializer:
64
64
  template = self.env.get_or_select_template(
65
65
  "operation_groups_container.py.jinja2"
66
66
  )
67
+
67
68
  return template.render(
68
69
  code_model=self.code_model,
69
70
  operation_groups=operation_groups,
@@ -145,13 +145,13 @@ class ParameterSerializer:
145
145
  f"_{kwarg_dict}.pop('{kwarg.rest_api_name}', {default_value})"
146
146
  )
147
147
  retval.append(
148
- f"{kwarg.client_name} = kwargs.pop('{kwarg.client_name}', "
149
- + f"{default_value}) # type: {kwarg.type_annotation()}"
148
+ f"{kwarg.client_name}: {kwarg.type_annotation()} = kwargs.pop('{kwarg.client_name}', "
149
+ + f"{default_value})"
150
150
  )
151
151
  else:
152
152
  type_annot = kwarg.type_annotation()
153
153
  retval.append(
154
- f"{kwarg.client_name} = kwargs.pop('{kwarg.client_name}') # type: {type_annot}"
154
+ f"{kwarg.client_name}: {type_annot} = kwargs.pop('{kwarg.client_name}')"
155
155
  )
156
156
  return retval
157
157
 
@@ -5,7 +5,7 @@
5
5
  # --------------------------------------------------------------------------
6
6
  from jinja2 import Environment
7
7
  from .import_serializer import FileImportSerializer
8
- from ..models import CodeModel, FileImport, ImportType, TypingSection
8
+ from ..models import CodeModel, FileImport, ImportType
9
9
 
10
10
 
11
11
  class PatchSerializer:
@@ -16,9 +16,7 @@ class PatchSerializer:
16
16
  def serialize(self) -> str:
17
17
  template = self.env.get_template("patch.py.jinja2")
18
18
  imports = FileImport()
19
- imports.add_submodule_import(
20
- "typing", "List", ImportType.STDLIB, TypingSection.CONDITIONAL
21
- )
19
+ imports.add_submodule_import("typing", "List", ImportType.STDLIB)
22
20
  return template.render(
23
21
  code_model=self.code_model,
24
22
  imports=FileImportSerializer(imports),
@@ -5,7 +5,7 @@
5
5
  # license information.
6
6
  # --------------------------------------------------------------------------
7
7
  import logging
8
- from typing import Dict, Any
8
+ from typing import Dict, Any, Union
9
9
  from jinja2 import Environment
10
10
 
11
11
  from autorest.codegen.models.credential_types import AzureKeyCredentialType
@@ -13,6 +13,7 @@ from autorest.codegen.models.credential_types import TokenCredentialType
13
13
  from autorest.codegen.models.imports import FileImport, ImportType
14
14
  from autorest.codegen.models.operation import OperationBase
15
15
  from autorest.codegen.models.operation_group import OperationGroup
16
+ from autorest.codegen.models.parameter import Parameter, BodyParameter
16
17
  from autorest.codegen.serializers.import_serializer import FileImportSerializer
17
18
  from ..models import CodeModel
18
19
 
@@ -28,7 +29,6 @@ class SampleSerializer:
28
29
  operation: OperationBase[Any],
29
30
  sample: Dict[str, Any],
30
31
  file_name: str,
31
- sample_origin_name: str,
32
32
  ) -> None:
33
33
  self.code_model = code_model
34
34
  self.env = env
@@ -36,7 +36,6 @@ class SampleSerializer:
36
36
  self.operation = operation
37
37
  self.sample = sample
38
38
  self.file_name = file_name
39
- self.sample_origin_name = sample_origin_name
40
39
 
41
40
  def _imports(self) -> FileImportSerializer:
42
41
  imports = FileImport()
@@ -55,6 +54,13 @@ class SampleSerializer:
55
54
  imports.add_submodule_import(
56
55
  "azure.core.credentials", "AzureKeyCredential", ImportType.THIRDPARTY
57
56
  )
57
+ for param in self.operation.parameters.positional:
58
+ if (
59
+ not param.client_default_value
60
+ and not param.optional
61
+ and param.rest_api_name in self.sample["parameters"]
62
+ ):
63
+ imports.merge(param.type.imports_for_sample())
58
64
  return FileImportSerializer(imports, True)
59
65
 
60
66
  def _client_params(self) -> Dict[str, Any]:
@@ -88,13 +94,12 @@ class SampleSerializer:
88
94
  return client_params
89
95
 
90
96
  @staticmethod
91
- def handle_param(param: Any) -> str:
92
- if isinstance(param, str):
93
- if any(i in param for i in '\r\n"'):
94
- return f'"""{param}"""'
95
- return f'"{param}"'
97
+ def handle_param(param: Union[Parameter, BodyParameter], param_value: Any) -> str:
98
+ if isinstance(param_value, str):
99
+ if any(i in param_value for i in '\r\n"'):
100
+ return f'"""{param_value}"""'
96
101
 
97
- return str(param)
102
+ return param.type.serialize_sample_value(param_value)
98
103
 
99
104
  # prepare operation parameters
100
105
  def _operation_params(self) -> Dict[str, Any]:
@@ -103,15 +108,17 @@ class SampleSerializer:
103
108
  for p in self.operation.parameters.positional
104
109
  if not p.client_default_value
105
110
  ]
106
- failure_info = "fail to find required param named {} in example file {}"
111
+ failure_info = "fail to find required param named {}"
107
112
  operation_params = {}
108
113
  for param in params_positional:
109
114
  name = param.rest_api_name
110
115
  param_value = self.sample["parameters"].get(name)
111
116
  if not param.optional:
112
117
  if not param_value:
113
- raise Exception(failure_info.format(name, self.sample_origin_name))
114
- operation_params[param.client_name] = self.handle_param(param_value)
118
+ raise Exception(failure_info.format(name))
119
+ operation_params[param.client_name] = self.handle_param(
120
+ param, param_value
121
+ )
115
122
  return operation_params
116
123
 
117
124
  def _operation_group_name(self) -> str:
@@ -135,8 +142,10 @@ class SampleSerializer:
135
142
  return f".{self.operation.name}"
136
143
 
137
144
  def _origin_file(self) -> str:
138
- name = self.sample.get("x-ms-original-file", "").split("specification")[-1]
139
- return "specification" + name if name else name
145
+ name = self.sample.get("x-ms-original-file", "")
146
+ if "specification" in name:
147
+ return "specification" + name.split("specification")[-1]
148
+ return ""
140
149
 
141
150
  def serialize(self) -> str:
142
151
  return self.env.get_template("sample.py.jinja2").render(
@@ -3,6 +3,7 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
+ from pathlib import Path
6
7
 
7
8
 
8
9
  def method_signature_and_response_type_annotation_template(
@@ -11,3 +12,8 @@ def method_signature_and_response_type_annotation_template(
11
12
  response_type_annotation: str,
12
13
  ) -> str:
13
14
  return f"{method_signature} -> {response_type_annotation}:"
15
+
16
+
17
+ def extract_sample_name(file_path: str) -> str:
18
+ file = file_path.split("specification")[-1]
19
+ return Path(file).parts[-1].replace(".json", "")
@@ -26,21 +26,12 @@
26
26
  {% endif %}
27
27
  return self._client.send_request(request_copy, **kwargs)
28
28
 
29
- {{ keywords.def }} close(self){{ " -> None" if async_mode else "" }}:
30
- {% if not async_mode %}
31
- # type: () -> None
32
- {% endif %}
29
+ {{ keywords.def }} close(self) -> None:
33
30
  {{ keywords.await }}self._client.close()
34
31
 
35
- {{ keywords.def }} __{{ keywords.async_prefix }}enter__(self){{ (" -> \"" + client.name + "\"") if async_mode else "" }}:
36
- {% if not async_mode %}
37
- # type: () -> {{ client.name }}
38
- {% endif %}
32
+ {{ keywords.def }} __{{ keywords.async_prefix }}enter__(self){{ " -> \"" + client.name + "\"" }}:
39
33
  {{ keywords.await }}self._client.__{{ keywords.async_prefix }}enter__()
40
34
  return self
41
35
 
42
- {{ keywords.def }} __{{ keywords.async_prefix }}exit__(self, *exc_details){{ " -> None" if async_mode else "" }}:
43
- {% if not async_mode %}
44
- # type: (Any) -> None
45
- {% endif %}
36
+ {{ keywords.def }} __{{ keywords.async_prefix }}exit__(self, *exc_details) -> None:
46
37
  {{ keywords.await }}self._client.__{{ keywords.async_prefix }}exit__(*exc_details)
@@ -30,11 +30,8 @@ class {{ client.name }}Configuration(Configuration): # pylint: disable=too-many
30
30
 
31
31
  def _configure(
32
32
  self,
33
- **kwargs{{": Any" if async_mode else " # type: Any"}}
34
- ){{ " -> None" if async_mode else "" }}:
35
- {% if not async_mode %}
36
- # type: (...) -> None
37
- {% endif %}
33
+ **kwargs: Any
34
+ ) -> None:
38
35
  self.user_agent_policy = kwargs.get('user_agent_policy') or policies.UserAgentPolicy(**kwargs)
39
36
  self.headers_policy = kwargs.get('headers_policy') or policies.HeadersPolicy(**kwargs)
40
37
  self.proxy_policy = kwargs.get('proxy_policy') or policies.ProxyPolicy(**kwargs)
@@ -3,7 +3,7 @@
3
3
  {% set await = "await " if async_mode else "" %}
4
4
  {% set async_class = "Async" if async_mode else "" %}
5
5
  {% macro escape_str(s) %}'{{ s|replace("'", "\\'") }}'{% endmacro %}
6
- {% set kwargs_declaration = "**kwargs: Any" if async_mode else "**kwargs # type: Any" %}
6
+ {% set kwargs_declaration = "**kwargs: Any" %}
7
7
  {% set extend_all = "__all__.extend([p for p in _patch_all if p not in __all__])" %}
8
8
  {% macro patch_imports(try_except=False) %}
9
9
  {% set indentation = " " if try_except else "" %}
@@ -11,7 +11,7 @@
11
11
  try:
12
12
  {% endif %}
13
13
  {{ indentation }}from ._patch import __all__ as _patch_all
14
- {{ indentation }}from ._patch import * # type: ignore # pylint: disable=unused-wildcard-import
14
+ {{ indentation }}from ._patch import * # pylint: disable=unused-wildcard-import
15
15
  {% if try_except %}
16
16
  except ImportError:
17
17
  _patch_all = []
@@ -45,7 +45,7 @@
45
45
  "service_client_specific": {
46
46
  "sync": {
47
47
  "api_version": {
48
- "signature": "api_version=None, # type: Optional[str]",
48
+ "signature": "api_version: Optional[str]=None,",
49
49
  "description": "API version to use if no profile is provided, or if missing in profile.",
50
50
  "docstring_type": "str",
51
51
  "required": false
@@ -59,7 +59,7 @@
59
59
  },
60
60
  {% endif %}
61
61
  "profile": {
62
- "signature": "profile=KnownProfiles.default, # type: KnownProfiles",
62
+ "signature": "profile: KnownProfiles=KnownProfiles.default,",
63
63
  "description": "A profile definition, from KnownProfiles to dict.",
64
64
  "docstring_type": "azure.profiles.KnownProfiles",
65
65
  "required": false