@autorest/python 6.9.2 → 6.9.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/autorest/_utils.py +48 -20
  2. package/autorest/codegen/__init__.py +8 -1
  3. package/autorest/codegen/models/__init__.py +17 -2
  4. package/autorest/codegen/models/base.py +3 -4
  5. package/autorest/codegen/models/client.py +5 -6
  6. package/autorest/codegen/models/combined_type.py +1 -1
  7. package/autorest/codegen/models/constant_type.py +3 -17
  8. package/autorest/codegen/models/credential_types.py +2 -2
  9. package/autorest/codegen/models/dictionary_type.py +1 -1
  10. package/autorest/codegen/models/enum_type.py +78 -23
  11. package/autorest/codegen/models/imports.py +23 -5
  12. package/autorest/codegen/models/list_type.py +1 -1
  13. package/autorest/codegen/models/lro_operation.py +3 -0
  14. package/autorest/codegen/models/model_type.py +1 -1
  15. package/autorest/codegen/models/operation.py +4 -4
  16. package/autorest/codegen/models/operation_group.py +2 -2
  17. package/autorest/codegen/models/paging_operation.py +1 -1
  18. package/autorest/codegen/models/parameter.py +6 -3
  19. package/autorest/codegen/models/primitive_types.py +20 -25
  20. package/autorest/codegen/models/property.py +25 -4
  21. package/autorest/codegen/models/request_builder.py +4 -5
  22. package/autorest/codegen/models/response.py +1 -1
  23. package/autorest/codegen/serializers/builder_serializer.py +23 -41
  24. package/autorest/codegen/serializers/client_serializer.py +1 -1
  25. package/autorest/codegen/serializers/general_serializer.py +6 -7
  26. package/autorest/codegen/serializers/model_serializer.py +10 -9
  27. package/autorest/codegen/serializers/operation_groups_serializer.py +1 -1
  28. package/autorest/codegen/serializers/patch_serializer.py +1 -1
  29. package/autorest/codegen/serializers/request_builders_serializer.py +1 -1
  30. package/autorest/codegen/serializers/sample_serializer.py +1 -1
  31. package/autorest/codegen/serializers/types_serializer.py +1 -1
  32. package/autorest/codegen/templates/client.py.jinja2 +1 -1
  33. package/autorest/codegen/templates/enum.py.jinja2 +3 -3
  34. package/autorest/codegen/templates/model_base.py.jinja2 +1 -1
  35. package/autorest/codegen/templates/serialization.py.jinja2 +5 -5
  36. package/autorest/jsonrpc/server.py +5 -1
  37. package/autorest/m4reformatter/__init__.py +11 -9
  38. package/autorest/multiapi/models/imports.py +17 -15
  39. package/autorest/preprocess/__init__.py +7 -6
  40. package/package.json +1 -1
@@ -107,7 +107,7 @@ class BinaryType(PrimitiveType):
107
107
  from .combined_type import CombinedType
108
108
  from .operation import OperationBase
109
109
 
110
- file_import = FileImport()
110
+ file_import = FileImport(self.code_model)
111
111
  file_import.add_submodule_import("typing", "IO", ImportType.STDLIB)
112
112
  operation = kwargs.get("operation")
113
113
  if (
@@ -145,7 +145,7 @@ class BinaryIteratorType(PrimitiveType):
145
145
  return self.get_declaration("Iterator[bytes]")
146
146
 
147
147
  def imports(self, **kwargs: Any) -> FileImport:
148
- file_import = FileImport()
148
+ file_import = FileImport(self.code_model)
149
149
  iterator = "AsyncIterator" if kwargs.get("async_mode") else "Iterator"
150
150
  file_import.add_submodule_import("typing", iterator, ImportType.STDLIB)
151
151
  return file_import
@@ -171,7 +171,7 @@ class AnyType(PrimitiveType):
171
171
  return self.get_declaration({})
172
172
 
173
173
  def imports(self, **kwargs: Any) -> FileImport:
174
- file_import = FileImport()
174
+ file_import = FileImport(self.code_model)
175
175
  file_import.add_submodule_import(
176
176
  "typing", "Any", ImportType.STDLIB, TypingSection.CONDITIONAL
177
177
  )
@@ -204,7 +204,7 @@ class AnyObjectType(PrimitiveType):
204
204
  return "isinstance({}, MutableMapping)"
205
205
 
206
206
  def imports(self, **kwargs: Any) -> FileImport:
207
- file_import = FileImport()
207
+ file_import = FileImport(self.code_model)
208
208
  file_import.define_mutable_mapping_type()
209
209
  return file_import
210
210
 
@@ -386,7 +386,7 @@ class DatetimeType(PrimitiveType):
386
386
  return f'"{value}"'
387
387
 
388
388
  def imports(self, **kwargs: Any) -> FileImport:
389
- file_import = FileImport()
389
+ file_import = FileImport(self.code_model)
390
390
  file_import.add_import("datetime", ImportType.STDLIB)
391
391
  return file_import
392
392
 
@@ -398,9 +398,8 @@ class DatetimeType(PrimitiveType):
398
398
  def instance_check_template(self) -> str:
399
399
  return "isinstance({}, datetime.datetime)"
400
400
 
401
- @staticmethod
402
- def imports_for_sample() -> FileImport:
403
- file_import = super(DatetimeType, DatetimeType).imports_for_sample()
401
+ def imports_for_sample(self) -> FileImport:
402
+ file_import = super().imports_for_sample()
404
403
  file_import.add_import("isodate", ImportType.STDLIB)
405
404
  return file_import
406
405
 
@@ -430,7 +429,7 @@ class TimeType(PrimitiveType):
430
429
  return f'"{value}"'
431
430
 
432
431
  def imports(self, **kwargs: Any) -> FileImport:
433
- file_import = FileImport()
432
+ file_import = FileImport(self.code_model)
434
433
  file_import.add_import("datetime", ImportType.STDLIB)
435
434
  return file_import
436
435
 
@@ -442,9 +441,8 @@ class TimeType(PrimitiveType):
442
441
  def instance_check_template(self) -> str:
443
442
  return "isinstance({}, datetime.time)"
444
443
 
445
- @staticmethod
446
- def imports_for_sample() -> FileImport:
447
- file_import = super(TimeType, TimeType).imports_for_sample()
444
+ def imports_for_sample(self) -> FileImport:
445
+ file_import = super().imports_for_sample()
448
446
  file_import.add_import("isodate", ImportType.STDLIB)
449
447
  return file_import
450
448
 
@@ -478,7 +476,7 @@ class UnixTimeType(PrimitiveType):
478
476
  return f'"{value}"'
479
477
 
480
478
  def imports(self, **kwargs: Any) -> FileImport:
481
- file_import = FileImport()
479
+ file_import = FileImport(self.code_model)
482
480
  file_import.add_import("datetime", ImportType.STDLIB)
483
481
  return file_import
484
482
 
@@ -490,9 +488,8 @@ class UnixTimeType(PrimitiveType):
490
488
  def instance_check_template(self) -> str:
491
489
  return "isinstance({}, datetime.time)"
492
490
 
493
- @staticmethod
494
- def imports_for_sample() -> FileImport:
495
- file_import = super(UnixTimeType, UnixTimeType).imports_for_sample()
491
+ def imports_for_sample(self) -> FileImport:
492
+ file_import = super().imports_for_sample()
496
493
  file_import.add_import("datetime", ImportType.STDLIB)
497
494
  return file_import
498
495
 
@@ -522,7 +519,7 @@ class DateType(PrimitiveType):
522
519
  return f'"{value}"'
523
520
 
524
521
  def imports(self, **kwargs: Any) -> FileImport:
525
- file_import = FileImport()
522
+ file_import = FileImport(self.code_model)
526
523
  file_import.add_import("datetime", ImportType.STDLIB)
527
524
  return file_import
528
525
 
@@ -534,9 +531,8 @@ class DateType(PrimitiveType):
534
531
  def instance_check_template(self) -> str:
535
532
  return "isinstance({}, datetime.date)"
536
533
 
537
- @staticmethod
538
- def imports_for_sample() -> FileImport:
539
- file_import = super(DateType, DateType).imports_for_sample()
534
+ def imports_for_sample(self) -> FileImport:
535
+ file_import = super().imports_for_sample()
540
536
  file_import.add_import("isodate", ImportType.STDLIB)
541
537
  return file_import
542
538
 
@@ -566,7 +562,7 @@ class DurationType(PrimitiveType):
566
562
  return f'"{value}"'
567
563
 
568
564
  def imports(self, **kwargs: Any) -> FileImport:
569
- file_import = FileImport()
565
+ file_import = FileImport(self.code_model)
570
566
  file_import.add_import("datetime", ImportType.STDLIB)
571
567
  return file_import
572
568
 
@@ -578,9 +574,8 @@ class DurationType(PrimitiveType):
578
574
  def instance_check_template(self) -> str:
579
575
  return "isinstance({}, datetime.timedelta)"
580
576
 
581
- @staticmethod
582
- def imports_for_sample() -> FileImport:
583
- file_import = super(DurationType, DurationType).imports_for_sample()
577
+ def imports_for_sample(self) -> FileImport:
578
+ file_import = super().imports_for_sample()
584
579
  file_import.add_import("isodate", ImportType.STDLIB)
585
580
  return file_import
586
581
 
@@ -623,7 +618,7 @@ class AzureCoreType(PrimitiveType):
623
618
  return self.name
624
619
 
625
620
  def imports(self, **kwargs: Any) -> FileImport:
626
- file_import = FileImport()
621
+ file_import = FileImport(self.code_model)
627
622
  file_import.add_submodule_import("azure.core", self.name, ImportType.AZURECORE)
628
623
  return file_import
629
624
 
@@ -7,6 +7,7 @@ from typing import Any, Dict, Optional, TYPE_CHECKING, List
7
7
 
8
8
  from .base import BaseModel
9
9
  from .constant_type import ConstantType
10
+ from .enum_type import EnumType
10
11
  from .base import BaseType
11
12
  from .imports import FileImport, ImportType
12
13
  from .utils import add_to_description, add_to_pylint_disable
@@ -61,9 +62,9 @@ class Property(BaseModel): # pylint: disable=too-many-instance-attributes
61
62
  @property
62
63
  def client_default_value_declaration(self) -> str:
63
64
  if self.client_default_value is not None:
64
- return self.type.get_declaration(self.client_default_value)
65
+ return self.get_declaration(self.client_default_value)
65
66
  if self.type.client_default_value is not None:
66
- return self.type.get_declaration(self.type.client_default_value)
67
+ return self.get_declaration(self.type.client_default_value)
67
68
  return "None"
68
69
 
69
70
  @property
@@ -88,11 +89,24 @@ class Property(BaseModel): # pylint: disable=too-many-instance-attributes
88
89
  def msrest_deserialization_key(self) -> str:
89
90
  return self.type.msrest_deserialization_key
90
91
 
92
+ @property
93
+ def is_enum_discriminator(self) -> bool:
94
+ return self.is_discriminator and self.type.type == "enum"
95
+
91
96
  def type_annotation(self, *, is_operation_file: bool = False) -> str:
97
+ if self.is_enum_discriminator:
98
+ # here we are the enum discriminator property on the base model
99
+ return "Literal[None]"
92
100
  if self.optional and self.client_default_value is None:
93
101
  return f"Optional[{self.type.type_annotation(is_operation_file=is_operation_file)}]"
94
102
  return self.type.type_annotation(is_operation_file=is_operation_file)
95
103
 
104
+ def get_declaration(self, value: Any = None) -> Any:
105
+ if self.is_enum_discriminator:
106
+ # here we are the enum discriminator property on the base model
107
+ return None
108
+ return self.type.get_declaration(value)
109
+
96
110
  def get_json_template_representation(
97
111
  self,
98
112
  *,
@@ -101,11 +115,13 @@ class Property(BaseModel): # pylint: disable=too-many-instance-attributes
101
115
  description: Optional[str] = None,
102
116
  ) -> Any:
103
117
  if self.client_default_value:
104
- client_default_value_declaration = self.type.get_declaration(
118
+ client_default_value_declaration = self.get_declaration(
105
119
  self.client_default_value
106
120
  )
107
121
  if self.description(is_operation_file=True):
108
122
  description = self.description(is_operation_file=True)
123
+ # make sure there is no \n otherwise the json template will be invalid
124
+ description = (description or "").replace("\n", " ")
109
125
  return self.type.get_json_template_representation(
110
126
  optional=self.optional,
111
127
  client_default_value_declaration=client_default_value_declaration,
@@ -131,7 +147,12 @@ class Property(BaseModel): # pylint: disable=too-many-instance-attributes
131
147
  return retval or None
132
148
 
133
149
  def imports(self, **kwargs) -> FileImport:
134
- file_import = self.type.imports(**kwargs, relative_path="..", model_typing=True)
150
+ file_import = FileImport(self.code_model)
151
+ if self.is_discriminator and isinstance(self.type, EnumType):
152
+ return file_import
153
+ file_import.merge(
154
+ self.type.imports(**kwargs, relative_path="..", model_typing=True)
155
+ )
135
156
  if self.optional and self.client_default_value is None:
136
157
  file_import.add_submodule_import("typing", "Optional", ImportType.STDLIB)
137
158
  if self.code_model.options["models_mode"] == "dpg":
@@ -81,7 +81,7 @@ class RequestBuilderBase(BaseBuilder[ParameterListType]):
81
81
  return "~azure.core.rest.HttpRequest"
82
82
 
83
83
  def imports(self) -> FileImport:
84
- file_import = FileImport()
84
+ file_import = FileImport(self.code_model)
85
85
  relative_path = ".."
86
86
  if (
87
87
  not self.code_model.options["builders_visibility"] == "embedded"
@@ -111,15 +111,14 @@ class RequestBuilderBase(BaseBuilder[ParameterListType]):
111
111
  "typing", "Any", ImportType.STDLIB, typing_section=TypingSection.CONDITIONAL
112
112
  )
113
113
  file_import.add_msrest_import(
114
- self.code_model,
115
- "..."
114
+ relative_path="..."
116
115
  if (
117
116
  not self.code_model.options["builders_visibility"] == "embedded"
118
117
  and self.group_name
119
118
  )
120
119
  else "..",
121
- MsrestImportType.Serializer,
122
- TypingSection.REGULAR,
120
+ msrest_import_type=MsrestImportType.Serializer,
121
+ typing_section=TypingSection.REGULAR,
123
122
  )
124
123
  if (
125
124
  self.overloads
@@ -108,7 +108,7 @@ class Response(BaseModel):
108
108
  return self.type.docstring_type(**kwargs) if self.type else "None"
109
109
 
110
110
  def _imports_shared(self, **kwargs: Any) -> FileImport:
111
- file_import = FileImport()
111
+ file_import = FileImport(self.code_model)
112
112
  if self.type:
113
113
  file_import.merge(self.type.imports(**kwargs))
114
114
  if self.nullable:
@@ -26,7 +26,6 @@ from ..models import (
26
26
  ParameterMethodLocation,
27
27
  RequestBuilderBodyParameter,
28
28
  OverloadedRequestBuilder,
29
- ConstantType,
30
29
  MultipartBodyParameter,
31
30
  Property,
32
31
  RequestBuilderType,
@@ -458,15 +457,16 @@ class RequestBuilderSerializer(
458
457
  @staticmethod
459
458
  def declare_non_inputtable_constants(builder: RequestBuilderType) -> List[str]:
460
459
  def _get_value(param):
461
- param_type = cast(ConstantType, param.type)
462
460
  if param.location in [ParameterLocation.HEADER, ParameterLocation.QUERY]:
463
461
  kwarg_dict = (
464
462
  "headers"
465
463
  if param.location == ParameterLocation.HEADER
466
464
  else "params"
467
465
  )
468
- return f"_{kwarg_dict}.pop('{param.wire_name}', {param_type.get_declaration()})"
469
- return f"{param_type.get_declaration()}"
466
+ return (
467
+ f"_{kwarg_dict}.pop('{param.wire_name}', {param.get_declaration()})"
468
+ )
469
+ return f"{param.get_declaration()}"
470
470
 
471
471
  return [
472
472
  f"{p.client_name} = {_get_value(p)}"
@@ -645,7 +645,7 @@ class _OperationSerializer(
645
645
  f"_stream = {stream_value}",
646
646
  f"pipeline_response: PipelineResponse = {self._call_method}self._client._pipeline.run( "
647
647
  + f"{'# type: ignore' if type_ignore else ''} # pylint: disable=protected-access",
648
- " request,",
648
+ " _request,",
649
649
  " stream=_stream,",
650
650
  " **kwargs",
651
651
  ")",
@@ -915,7 +915,7 @@ class _OperationSerializer(
915
915
  ("_" + group_name) if group_name else "",
916
916
  request_builder.name,
917
917
  )
918
- retval.append(f"request = {request_path_name}(")
918
+ retval.append(f"_request = {request_path_name}(")
919
919
  for parameter in request_builder.parameters.method:
920
920
  if parameter.location == ParameterLocation.BODY:
921
921
  # going to pass in body later based off of overloads
@@ -979,14 +979,14 @@ class _OperationSerializer(
979
979
  and builder.parameters.body_parameter.client_name == "files"
980
980
  ):
981
981
  pass_files = ", _files"
982
- retval.append(f"request = _convert_request(request{pass_files})")
982
+ retval.append(f"_request = _convert_request(_request{pass_files})")
983
983
  if builder.parameters.path:
984
984
  retval.extend(self.serialize_path(builder))
985
- url_to_format = "request.url"
985
+ url_to_format = "_request.url"
986
986
  if self.code_model.options["version_tolerant"] and template_url:
987
987
  url_to_format = template_url
988
988
  retval.append(
989
- "request.url = self._client.format_url({}{})".format(
989
+ "_request.url = self._client.format_url({}{})".format(
990
990
  url_to_format,
991
991
  ", **path_format_arguments" if builder.parameters.path else "",
992
992
  )
@@ -1168,31 +1168,16 @@ class _OperationSerializer(
1168
1168
  )
1169
1169
  )
1170
1170
  retval.append("")
1171
- type_ignore = (
1172
- builder.has_response_body
1173
- and not builder.has_optional_return_type
1174
- and not (
1175
- self.code_model.options["models_mode"] == "msrest"
1176
- and any(not resp.is_stream_response for resp in builder.responses)
1177
- )
1178
- )
1179
1171
  if builder.has_optional_return_type or self.code_model.options["models_mode"]:
1180
1172
  deserialized = "deserialized"
1181
1173
  else:
1182
1174
  deserialized = f"cast({builder.response_type_annotation(async_mode=self.async_mode)}, deserialized)"
1183
- type_ignore = False
1184
- if (
1185
- not builder.has_optional_return_type
1186
- and len(builder.responses) > 1
1187
- and any(resp.is_stream_response or resp.type for resp in builder.responses)
1188
- ):
1189
- type_ignore = True
1190
1175
  retval.append("if cls:")
1191
1176
  retval.append(
1192
1177
  " return cls(pipeline_response, {}, {}){}".format(
1193
1178
  deserialized if builder.has_response_body else "None",
1194
1179
  "response_headers" if builder.any_response_has_headers else "{}",
1195
- " # type: ignore" if type_ignore else "",
1180
+ " # type: ignore",
1196
1181
  )
1197
1182
  )
1198
1183
  if builder.has_response_body and any(
@@ -1200,9 +1185,7 @@ class _OperationSerializer(
1200
1185
  for response in builder.responses
1201
1186
  ):
1202
1187
  retval.append("")
1203
- retval.append(
1204
- f"return {deserialized}{' # type: ignore' if type_ignore else ''}"
1205
- )
1188
+ retval.append(f"return {deserialized} # type: ignore")
1206
1189
  if (
1207
1190
  builder.request_builder.method == "HEAD"
1208
1191
  and self.code_model.options["head_as_boolean"]
@@ -1370,8 +1353,8 @@ class _PagingOperationSerializer(
1370
1353
  except StopIteration:
1371
1354
  pass
1372
1355
 
1373
- retval.append(f'request = HttpRequest("GET", {next_link_str}{query_str})')
1374
- retval.extend(self._postprocess_http_request(builder, "request.url"))
1356
+ retval.append(f'_request = HttpRequest("GET", {next_link_str}{query_str})')
1357
+ retval.extend(self._postprocess_http_request(builder, "_request.url"))
1375
1358
 
1376
1359
  return retval
1377
1360
 
@@ -1391,10 +1374,10 @@ class _PagingOperationSerializer(
1391
1374
  [f" {line}" for line in self.call_next_link_request_builder(builder)]
1392
1375
  )
1393
1376
  if not builder.next_request_builder and self.code_model.is_legacy:
1394
- retval.append(' request.method = "GET"')
1377
+ retval.append(' _request.method = "GET"')
1395
1378
  else:
1396
1379
  retval.append("")
1397
- retval.append(" return request")
1380
+ retval.append(" return _request")
1398
1381
  return retval
1399
1382
 
1400
1383
  @property
@@ -1453,7 +1436,7 @@ class _PagingOperationSerializer(
1453
1436
 
1454
1437
  def _get_next_callback(self, builder: PagingOperationType) -> List[str]:
1455
1438
  retval = [f"{'async ' if self.async_mode else ''}def get_next(next_link=None):"]
1456
- retval.append(" request = prepare_request(next_link)")
1439
+ retval.append(" _request = prepare_request(next_link)")
1457
1440
  retval.append("")
1458
1441
  retval.extend([f" {l}" for l in self.make_pipeline_call(builder)])
1459
1442
  retval.append(" response = pipeline_response.http_response")
@@ -1582,7 +1565,7 @@ class _LROOperationSerializer(_OperationSerializer[LROOperationType]):
1582
1565
  retval.append("else: polling_method = polling")
1583
1566
  retval.append("if cont_token:")
1584
1567
  retval.append(
1585
- f" return {builder.get_poller(self.async_mode)}.from_continuation_token("
1568
+ f" return {builder.get_poller_with_response_type(self.async_mode)}.from_continuation_token("
1586
1569
  )
1587
1570
  retval.append(" polling_method=polling_method,")
1588
1571
  retval.append(" continuation_token=cont_token,")
@@ -1590,9 +1573,12 @@ class _LROOperationSerializer(_OperationSerializer[LROOperationType]):
1590
1573
  retval.append(" deserialization_callback=get_long_running_output")
1591
1574
  retval.append(" )")
1592
1575
  retval.append(
1593
- f"return {builder.get_poller(self.async_mode)}"
1594
- "(self._client, raw_result, get_long_running_output, polling_method) # type: ignore"
1576
+ f"return {builder.get_poller_with_response_type(self.async_mode)}("
1577
+ )
1578
+ retval.append(
1579
+ " self._client, raw_result, get_long_running_output, polling_method # type: ignore"
1595
1580
  )
1581
+ retval.append(" )")
1596
1582
  return retval
1597
1583
 
1598
1584
  def get_long_running_output(self, builder: LROOperationType) -> List[str]:
@@ -1626,11 +1612,7 @@ class _LROOperationSerializer(_OperationSerializer[LROOperationType]):
1626
1612
  "response_headers"
1627
1613
  if builder.lro_response and builder.lro_response.headers
1628
1614
  else "{}",
1629
- " # type: ignore"
1630
- if builder.lro_response
1631
- and builder.lro_response.type
1632
- and self.code_model.options["models_mode"] != "msrest"
1633
- else "",
1615
+ " # type: ignore",
1634
1616
  )
1635
1617
  )
1636
1618
  if builder.lro_response and builder.lro_response.type:
@@ -123,7 +123,7 @@ class ClientSerializer:
123
123
  [
124
124
  "_policies = kwargs.pop('policies', None)",
125
125
  "if _policies is None:",
126
- f' _policies = [{",".join(build_policies(self.client.code_model.options["azure_arm"], async_mode))}]', # pylint: disable=line-too-long
126
+ f' _policies = [{",".join(build_policies(self.client.code_model.options["azure_arm"], async_mode, self.client.code_model.options["unbranded"]))}]', # pylint: disable=line-too-long
127
127
  f"self._client: {pipeline_client_name} = {pipeline_client_name}("
128
128
  f"{', '.join(f'{k}={v}' for k, v in params.items())}, **kwargs)",
129
129
  ]
@@ -85,7 +85,7 @@ class GeneralSerializer:
85
85
  def serialize_service_client_file(self, clients: List[Client]) -> str:
86
86
  template = self.env.get_template("client_container.py.jinja2")
87
87
 
88
- imports = FileImport()
88
+ imports = FileImport(self.code_model)
89
89
  for client in clients:
90
90
  imports.merge(client.imports(self.async_mode))
91
91
 
@@ -101,7 +101,7 @@ class GeneralSerializer:
101
101
  template = self.env.get_template("vendor.py.jinja2")
102
102
 
103
103
  # configure imports
104
- file_import = FileImport()
104
+ file_import = FileImport(self.code_model)
105
105
  if self.code_model.need_request_converter:
106
106
  file_import.add_submodule_import(
107
107
  "azure.core.pipeline.transport",
@@ -122,10 +122,9 @@ class GeneralSerializer:
122
122
  TypingSection.TYPING,
123
123
  )
124
124
  file_import.add_msrest_import(
125
- self.code_model,
126
- ".." if self.async_mode else ".",
127
- MsrestImportType.SerializerDeserializer,
128
- TypingSection.TYPING,
125
+ relative_path=".." if self.async_mode else ".",
126
+ msrest_import_type=MsrestImportType.SerializerDeserializer,
127
+ typing_section=TypingSection.TYPING,
129
128
  )
130
129
  for client in clients:
131
130
  file_import.add_submodule_import(
@@ -152,7 +151,7 @@ class GeneralSerializer:
152
151
 
153
152
  def serialize_config_file(self, clients: List[Client]) -> str:
154
153
  template = self.env.get_template("config_container.py.jinja2")
155
- imports = FileImport()
154
+ imports = FileImport(self.code_model)
156
155
  for client in self.code_model.clients:
157
156
  imports.merge(client.config.imports(self.async_mode))
158
157
  return template.render(
@@ -3,14 +3,13 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
- from typing import List, cast
6
+ from typing import List
7
7
  from abc import ABC, abstractmethod
8
8
 
9
9
  from jinja2 import Environment
10
- from ..models import ModelType, CodeModel, Property
10
+ from ..models import ModelType, CodeModel, Property, ConstantType, EnumValue
11
11
  from ..models.imports import FileImport, TypingSection, MsrestImportType, ImportType
12
12
  from .import_serializer import FileImportSerializer
13
- from ..models.constant_type import ConstantType
14
13
 
15
14
 
16
15
  def _documentation_string(
@@ -134,9 +133,11 @@ class _ModelSerializer(ABC):
134
133
 
135
134
  class MsrestModelSerializer(_ModelSerializer):
136
135
  def imports(self) -> FileImport:
137
- file_import = FileImport()
136
+ file_import = FileImport(self.code_model)
138
137
  file_import.add_msrest_import(
139
- self.code_model, "..", MsrestImportType.Module, TypingSection.REGULAR
138
+ relative_path="..",
139
+ msrest_import_type=MsrestImportType.Module,
140
+ typing_section=TypingSection.REGULAR,
140
141
  )
141
142
  for model in self.code_model.model_types:
142
143
  file_import.merge(model.imports(is_operation_file=False))
@@ -201,7 +202,7 @@ class MsrestModelSerializer(_ModelSerializer):
201
202
 
202
203
  class DpgModelSerializer(_ModelSerializer):
203
204
  def imports(self) -> FileImport:
204
- file_import = FileImport()
205
+ file_import = FileImport(self.code_model)
205
206
  file_import.add_submodule_import(
206
207
  "..",
207
208
  "_model_base",
@@ -266,8 +267,8 @@ class DpgModelSerializer(_ModelSerializer):
266
267
  field = "rest_discriminator" if prop.is_discriminator else "rest_field"
267
268
  type_ignore = (
268
269
  prop.is_discriminator
269
- and prop.is_discriminator
270
- and cast(ConstantType, prop.type).value
270
+ and isinstance(prop.type, (ConstantType, EnumValue))
271
+ and prop.type.value
271
272
  )
272
273
  return (
273
274
  f"{prop.client_name}: {prop.type_annotation()} ="
@@ -280,7 +281,7 @@ class DpgModelSerializer(_ModelSerializer):
280
281
  if prop.constant or prop.is_discriminator:
281
282
  init_args.append(
282
283
  f"self.{prop.client_name}: {prop.type_annotation()} = "
283
- f"{cast(ConstantType, prop.type).get_declaration()}"
284
+ f"{prop.get_declaration()}"
284
285
  )
285
286
  return init_args
286
287
 
@@ -57,7 +57,7 @@ class OperationGroupsSerializer:
57
57
  if self.operation_group
58
58
  else [og for client in self.clients for og in client.operation_groups]
59
59
  )
60
- imports = FileImport()
60
+ imports = FileImport(code_model=self.code_model)
61
61
  for operation_group in operation_groups:
62
62
  imports.merge(
63
63
  operation_group.imports(
@@ -15,7 +15,7 @@ class PatchSerializer:
15
15
 
16
16
  def serialize(self) -> str:
17
17
  template = self.env.get_template("patch.py.jinja2")
18
- imports = FileImport()
18
+ imports = FileImport(code_model=self.code_model)
19
19
  imports.add_submodule_import("typing", "List", ImportType.STDLIB)
20
20
  return template.render(
21
21
  code_model=self.code_model,
@@ -26,7 +26,7 @@ class RequestBuildersSerializer:
26
26
 
27
27
  @property
28
28
  def imports(self) -> FileImport:
29
- file_import = FileImport()
29
+ file_import = FileImport(code_model=self.code_model)
30
30
  for request_builder in self.request_builders:
31
31
  if request_builder.group_name == self.group_name:
32
32
  file_import.merge(request_builder.imports())
@@ -43,7 +43,7 @@ class SampleSerializer:
43
43
  }
44
44
 
45
45
  def _imports(self) -> FileImportSerializer:
46
- imports = FileImport()
46
+ imports = FileImport(self.code_model)
47
47
  namespace_from_package_name = get_namespace_from_package_name(
48
48
  self.code_model.options["package_name"]
49
49
  )
@@ -15,7 +15,7 @@ class TypesSerializer:
15
15
  self.env = env
16
16
 
17
17
  def imports(self) -> FileImport:
18
- file_import = FileImport()
18
+ file_import = FileImport(code_model=self.code_model)
19
19
  if self.code_model.named_unions:
20
20
  file_import.add_submodule_import(
21
21
  "typing",
@@ -24,7 +24,7 @@
24
24
  {% else %}
25
25
  request_copy.url = self._client.format_url(request_copy.url)
26
26
  {% endif %}
27
- return self._client.send_request(request_copy, **kwargs)
27
+ return self._client.send_request(request_copy, **kwargs) # type: ignore
28
28
 
29
29
  {{ keywords.def }} close(self) -> None:
30
30
  {{ keywords.await }}self._client.close()
@@ -7,7 +7,7 @@ class {{ enum.name }}({{ enum.value_type.type_annotation(is_operation_file=False
7
7
 
8
8
  {% for value in enum.values %}
9
9
  {{ value.name }} = {{ enum.value_type.get_declaration(value.value) }}
10
- {% if value.description %}
11
- """{{ value.description | wordwrap(width=95, break_long_words=False, break_on_hyphens=False, wrapstring='\n #: ') }}"""
10
+ {% if value.description(is_operation_file=False) %}
11
+ """{{ value.description(is_operation_file=False) | wordwrap(width=95, break_long_words=False, break_on_hyphens=False, wrapstring='\n #: ') }}"""
12
12
  {% endif %}
13
- {% endfor %}
13
+ {% endfor %}
@@ -553,7 +553,7 @@ class Model(_MyMutableMapping):
553
553
  if exclude_readonly:
554
554
  readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)]
555
555
  for k, v in self.items():
556
- if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false
556
+ if exclude_readonly and k in readonly_props: # pyright: ignore[reportUnboundVariable]
557
557
  continue
558
558
  result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly)
559
559
  return result