@autorest/python 6.9.6 → 6.9.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. package/autorest/codegen/__init__.py +27 -8
  2. package/autorest/codegen/models/base.py +2 -5
  3. package/autorest/codegen/models/client.py +12 -8
  4. package/autorest/codegen/models/code_model.py +5 -1
  5. package/autorest/codegen/models/combined_type.py +1 -1
  6. package/autorest/codegen/models/constant_type.py +1 -1
  7. package/autorest/codegen/models/credential_types.py +21 -11
  8. package/autorest/codegen/models/dictionary_type.py +1 -1
  9. package/autorest/codegen/models/enum_type.py +2 -2
  10. package/autorest/codegen/models/imports.py +18 -69
  11. package/autorest/codegen/models/list_type.py +1 -1
  12. package/autorest/codegen/models/model_type.py +1 -1
  13. package/autorest/codegen/models/operation.py +11 -15
  14. package/autorest/codegen/models/operation_group.py +2 -2
  15. package/autorest/codegen/models/paging_operation.py +4 -4
  16. package/autorest/codegen/models/parameter.py +1 -1
  17. package/autorest/codegen/models/primitive_types.py +12 -16
  18. package/autorest/codegen/models/property.py +1 -1
  19. package/autorest/codegen/models/request_builder.py +4 -4
  20. package/autorest/codegen/models/response.py +32 -10
  21. package/autorest/codegen/serializers/__init__.py +1 -1
  22. package/autorest/codegen/serializers/builder_serializer.py +15 -11
  23. package/autorest/codegen/serializers/client_serializer.py +4 -7
  24. package/autorest/codegen/serializers/enum_serializer.py +2 -1
  25. package/autorest/codegen/serializers/general_serializer.py +9 -10
  26. package/autorest/codegen/serializers/model_serializer.py +2 -2
  27. package/autorest/codegen/serializers/operation_groups_serializer.py +2 -1
  28. package/autorest/codegen/serializers/patch_serializer.py +2 -2
  29. package/autorest/codegen/serializers/request_builders_serializer.py +1 -1
  30. package/autorest/codegen/serializers/sample_serializer.py +17 -13
  31. package/autorest/codegen/serializers/types_serializer.py +1 -1
  32. package/autorest/codegen/templates/enum_container.py.jinja2 +1 -1
  33. package/autorest/codegen/templates/model_base.py.jinja2 +12 -9
  34. package/autorest/codegen/templates/model_dpg.py.jinja2 +1 -1
  35. package/autorest/codegen/templates/model_msrest.py.jinja2 +1 -1
  36. package/autorest/codegen/templates/packaging_templates/LICENSE.jinja2 +1 -1
  37. package/autorest/codegen/templates/packaging_templates/README.md.jinja2 +1 -1
  38. package/autorest/codegen/templates/packaging_templates/setup.py.jinja2 +4 -4
  39. package/autorest/codegen/templates/patch.py.jinja2 +1 -1
  40. package/autorest/codegen/templates/serialization.py.jinja2 +4 -7
  41. package/autorest/multiapi/models/code_model.py +2 -0
  42. package/autorest/multiapi/serializers/__init__.py +1 -2
  43. package/package.json +1 -1
@@ -107,7 +107,7 @@ class BinaryType(PrimitiveType):
107
107
  from .combined_type import CombinedType
108
108
  from .operation import OperationBase
109
109
 
110
- file_import = self.init_file_import()
110
+ file_import = FileImport(self.code_model)
111
111
  file_import.add_submodule_import("typing", "IO", ImportType.STDLIB)
112
112
  operation = kwargs.get("operation")
113
113
  if (
@@ -145,7 +145,7 @@ class BinaryIteratorType(PrimitiveType):
145
145
  return self.get_declaration("Iterator[bytes]")
146
146
 
147
147
  def imports(self, **kwargs: Any) -> FileImport:
148
- file_import = self.init_file_import()
148
+ file_import = FileImport(self.code_model)
149
149
  iterator = "AsyncIterator" if kwargs.get("async_mode") else "Iterator"
150
150
  file_import.add_submodule_import("typing", iterator, ImportType.STDLIB)
151
151
  return file_import
@@ -171,7 +171,7 @@ class AnyType(PrimitiveType):
171
171
  return self.get_declaration({})
172
172
 
173
173
  def imports(self, **kwargs: Any) -> FileImport:
174
- file_import = self.init_file_import()
174
+ file_import = FileImport(self.code_model)
175
175
  file_import.add_submodule_import(
176
176
  "typing", "Any", ImportType.STDLIB, TypingSection.CONDITIONAL
177
177
  )
@@ -204,7 +204,7 @@ class AnyObjectType(PrimitiveType):
204
204
  return "isinstance({}, MutableMapping)"
205
205
 
206
206
  def imports(self, **kwargs: Any) -> FileImport:
207
- file_import = self.init_file_import()
207
+ file_import = FileImport(self.code_model)
208
208
  file_import.define_mutable_mapping_type()
209
209
  return file_import
210
210
 
@@ -386,7 +386,7 @@ class DatetimeType(PrimitiveType):
386
386
  return f'"{value}"'
387
387
 
388
388
  def imports(self, **kwargs: Any) -> FileImport:
389
- file_import = self.init_file_import()
389
+ file_import = FileImport(self.code_model)
390
390
  file_import.add_import("datetime", ImportType.STDLIB)
391
391
  return file_import
392
392
 
@@ -429,7 +429,7 @@ class TimeType(PrimitiveType):
429
429
  return f'"{value}"'
430
430
 
431
431
  def imports(self, **kwargs: Any) -> FileImport:
432
- file_import = self.init_file_import()
432
+ file_import = FileImport(self.code_model)
433
433
  file_import.add_import("datetime", ImportType.STDLIB)
434
434
  return file_import
435
435
 
@@ -476,7 +476,7 @@ class UnixTimeType(PrimitiveType):
476
476
  return f'"{value}"'
477
477
 
478
478
  def imports(self, **kwargs: Any) -> FileImport:
479
- file_import = self.init_file_import()
479
+ file_import = FileImport(self.code_model)
480
480
  file_import.add_import("datetime", ImportType.STDLIB)
481
481
  return file_import
482
482
 
@@ -519,7 +519,7 @@ class DateType(PrimitiveType):
519
519
  return f'"{value}"'
520
520
 
521
521
  def imports(self, **kwargs: Any) -> FileImport:
522
- file_import = self.init_file_import()
522
+ file_import = FileImport(self.code_model)
523
523
  file_import.add_import("datetime", ImportType.STDLIB)
524
524
  return file_import
525
525
 
@@ -562,7 +562,7 @@ class DurationType(PrimitiveType):
562
562
  return f'"{value}"'
563
563
 
564
564
  def imports(self, **kwargs: Any) -> FileImport:
565
- file_import = self.init_file_import()
565
+ file_import = FileImport(self.code_model)
566
566
  file_import.add_import("datetime", ImportType.STDLIB)
567
567
  return file_import
568
568
 
@@ -612,18 +612,14 @@ class SdkCoreType(PrimitiveType):
612
612
  self.name = yaml_data.get("name", "")
613
613
 
614
614
  def docstring_type(self, **kwargs: Any) -> str:
615
- return f"~{self.init_file_import().import_core}" + self.type_annotation(
616
- **kwargs
617
- )
615
+ return f"~{self.code_model.core_library}.{self.type_annotation(**kwargs)}"
618
616
 
619
617
  def type_annotation(self, **kwargs: Any) -> str:
620
618
  return self.name
621
619
 
622
620
  def imports(self, **kwargs: Any) -> FileImport:
623
- file_import = self.init_file_import()
624
- file_import.add_submodule_import(
625
- file_import.import_core, self.name, ImportType.SDKCORE
626
- )
621
+ file_import = super().imports(**kwargs)
622
+ file_import.add_submodule_import("", self.name, ImportType.SDKCORE)
627
623
  return file_import
628
624
 
629
625
  @property
@@ -147,7 +147,7 @@ class Property(BaseModel): # pylint: disable=too-many-instance-attributes
147
147
  return retval or None
148
148
 
149
149
  def imports(self, **kwargs) -> FileImport:
150
- file_import = self.init_file_import()
150
+ file_import = FileImport(self.code_model)
151
151
  if self.is_discriminator and isinstance(self.type, EnumType):
152
152
  return file_import
153
153
  file_import.merge(
@@ -78,10 +78,10 @@ class RequestBuilderBase(BaseBuilder[ParameterListType]):
78
78
  )
79
79
 
80
80
  def response_docstring_type(self, **kwargs) -> str:
81
- return f"~{self.init_file_import().import_core_rest}.HttpRequest"
81
+ return f"~{self.code_model.core_library}.rest.HttpRequest"
82
82
 
83
83
  def imports(self) -> FileImport:
84
- file_import = self.init_file_import()
84
+ file_import = FileImport(self.code_model)
85
85
  relative_path = ".."
86
86
  if (
87
87
  not self.code_model.options["builders_visibility"] == "embedded"
@@ -98,14 +98,14 @@ class RequestBuilderBase(BaseBuilder[ParameterListType]):
98
98
  )
99
99
 
100
100
  file_import.add_submodule_import(
101
- file_import.import_core_rest,
101
+ "rest",
102
102
  "HttpRequest",
103
103
  ImportType.SDKCORE,
104
104
  )
105
105
 
106
106
  if self.parameters.headers or self.parameters.query:
107
107
  file_import.add_submodule_import(
108
- file_import.import_core_utils,
108
+ "utils",
109
109
  "case_insensitive_dict",
110
110
  ImportType.SDKCORE,
111
111
  )
@@ -114,7 +114,7 @@ class Response(BaseModel):
114
114
  return self.type.docstring_type(**kwargs) if self.type else "None"
115
115
 
116
116
  def _imports_shared(self, **kwargs: Any) -> FileImport:
117
- file_import = self.init_file_import()
117
+ file_import = FileImport(self.code_model)
118
118
  if self.type:
119
119
  file_import.merge(self.type.imports(**kwargs))
120
120
  if self.nullable:
@@ -134,6 +134,16 @@ class Response(BaseModel):
134
134
  def imports_for_multiapi(self, **kwargs: Any) -> FileImport:
135
135
  return self._imports_shared(**kwargs)
136
136
 
137
+ def _get_import_type(self, input_path: str) -> ImportType:
138
+ # helper function to return imports for responses based off
139
+ # of whether we're importing from the core library, or users
140
+ # are customizing responses
141
+ return (
142
+ ImportType.SDKCORE
143
+ if self.code_model.core_library.split(".")[0] in input_path
144
+ else ImportType.THIRDPARTY
145
+ )
146
+
137
147
  @classmethod
138
148
  def from_yaml(
139
149
  cls, yaml_data: Dict[str, Any], code_model: "CodeModel"
@@ -166,11 +176,14 @@ class PagingResponse(Response):
166
176
  self.item_type = self.code_model.lookup_type(id(self.yaml_data["itemType"]))
167
177
  self.pager_sync: str = (
168
178
  self.yaml_data.get("pagerSync")
169
- or f"{self.init_file_import().import_core_paging}.ItemPaged"
179
+ or f"{self.code_model.core_library}.paging.ItemPaged"
180
+ )
181
+ default_paging_submodule = (
182
+ f"{'' if self.code_model.options['unbranded'] else 'async_'}paging"
170
183
  )
171
184
  self.pager_async: str = (
172
185
  self.yaml_data.get("pagerAsync")
173
- or f"{self.init_file_import().import_core_paging_async}.AsyncItemPaged"
186
+ or f"{self.code_model.core_library}.{default_paging_submodule}.AsyncItemPaged"
174
187
  )
175
188
 
176
189
  def get_polymorphic_subtypes(self, polymorphic_subtypes: List["ModelType"]) -> None:
@@ -179,6 +192,9 @@ class PagingResponse(Response):
179
192
  def get_json_template_representation(self) -> Any:
180
193
  return self.item_type.get_json_template_representation()
181
194
 
195
+ def get_pager_import_path(self, async_mode: bool) -> str:
196
+ return ".".join(self.get_pager_path(async_mode).split(".")[:-1])
197
+
182
198
  def get_pager_path(self, async_mode: bool) -> str:
183
199
  return self.pager_async if async_mode else self.pager_sync
184
200
 
@@ -201,10 +217,12 @@ class PagingResponse(Response):
201
217
  def _imports_shared(self, **kwargs: Any) -> FileImport:
202
218
  file_import = super()._imports_shared(**kwargs)
203
219
  async_mode = kwargs.get("async_mode", False)
204
- pager_import_path = ".".join(self.get_pager_path(async_mode).split(".")[:-1])
205
220
  pager = self.get_pager(async_mode)
221
+ pager_path = self.get_pager_import_path(async_mode)
206
222
 
207
- file_import.add_submodule_import(pager_import_path, pager, ImportType.SDKCORE)
223
+ file_import.add_submodule_import(
224
+ pager_path, pager, self._get_import_type(pager_path)
225
+ )
208
226
  return file_import
209
227
 
210
228
  def imports(self, **kwargs: Any) -> FileImport:
@@ -212,7 +230,7 @@ class PagingResponse(Response):
212
230
  async_mode = kwargs.get("async_mode")
213
231
  if async_mode:
214
232
  file_import.add_submodule_import(
215
- file_import.import_core_paging_async,
233
+ f"{'' if self.code_model.options['unbranded'] else 'async_'}paging",
216
234
  "AsyncList",
217
235
  ImportType.SDKCORE,
218
236
  )
@@ -283,7 +301,9 @@ class LROResponse(Response):
283
301
  async_mode = kwargs["async_mode"]
284
302
  poller_import_path = ".".join(self.get_poller_path(async_mode).split(".")[:-1])
285
303
  poller = self.get_poller(async_mode)
286
- file_import.add_submodule_import(poller_import_path, poller, ImportType.SDKCORE)
304
+ file_import.add_submodule_import(
305
+ poller_import_path, poller, self._get_import_type(poller_import_path)
306
+ )
287
307
  return file_import
288
308
 
289
309
  def imports(self, **kwargs: Any) -> FileImport:
@@ -297,7 +317,7 @@ class LROResponse(Response):
297
317
  file_import.add_submodule_import(
298
318
  default_polling_method_import_path,
299
319
  default_polling_method,
300
- ImportType.SDKCORE,
320
+ self._get_import_type(default_polling_method_import_path),
301
321
  )
302
322
  default_no_polling_method_import_path = ".".join(
303
323
  self.get_no_polling_method_path(async_mode).split(".")[:-1]
@@ -306,7 +326,7 @@ class LROResponse(Response):
306
326
  file_import.add_submodule_import(
307
327
  default_no_polling_method_import_path,
308
328
  default_no_polling_method,
309
- ImportType.SDKCORE,
329
+ self._get_import_type(default_no_polling_method_import_path),
310
330
  )
311
331
 
312
332
  base_polling_method_import_path = ".".join(
@@ -314,7 +334,9 @@ class LROResponse(Response):
314
334
  )
315
335
  base_polling_method = self.get_base_polling_method(async_mode)
316
336
  file_import.add_submodule_import(
317
- base_polling_method_import_path, base_polling_method, ImportType.SDKCORE
337
+ base_polling_method_import_path,
338
+ base_polling_method,
339
+ self._get_import_type(base_polling_method_import_path),
318
340
  )
319
341
  return file_import
320
342
 
@@ -238,7 +238,7 @@ class JinjaSerializer(ReaderAndWriter): # pylint: disable=abstract-method
238
238
  else:
239
239
  return
240
240
  serializer = GeneralSerializer(self.code_model, env, async_mode=False)
241
- params = self.code_model.options["package_configuration"] or {}
241
+ params = self.code_model.options["packaging_files_config"] or {}
242
242
  for template_name in package_files:
243
243
  if (
244
244
  self.code_model.options["unbranded"]
@@ -428,9 +428,7 @@ class _BuilderBaseSerializer(Generic[BuilderType]): # pylint: disable=abstract-
428
428
 
429
429
  @property
430
430
  def pipeline_name(self) -> str:
431
- if not self.code_model.options["unbranded"]:
432
- return "_pipeline"
433
- return "pipeline"
431
+ return f"{'' if self.code_model.options['unbranded'] else '_'}pipeline"
434
432
 
435
433
 
436
434
  ############################## REQUEST BUILDERS ##############################
@@ -489,13 +487,13 @@ class RequestBuilderSerializer(
489
487
  return False
490
488
 
491
489
  def response_docstring(self, builder: RequestBuilderType) -> List[str]:
492
- import_core_rest = builder.init_file_import().import_core_rest
490
+ request_full_path = f"{self.code_model.core_library}.rest.HttpRequest"
493
491
  response_str = (
494
- f":return: Returns an :class:`~{import_core_rest}.HttpRequest` that you will pass to the client's "
492
+ f":return: Returns an :class:`~{request_full_path}` that you will pass to the client's "
495
493
  + "`send_request` method. See https://aka.ms/azsdk/dpcodegen/python/send_request for how to "
496
494
  + "incorporate this response into your code flow."
497
495
  )
498
- rtype_str = f":rtype: ~{import_core_rest}.HttpRequest"
496
+ rtype_str = f":rtype: ~{request_full_path}"
499
497
  return [response_str, rtype_str]
500
498
 
501
499
  def pop_kwargs_from_signature(self, builder: RequestBuilderType) -> List[str]:
@@ -645,7 +643,7 @@ class _OperationSerializer(
645
643
  type_ignore = self.async_mode and builder.group_name == "" # is in a mixin
646
644
  stream_value = (
647
645
  f'kwargs.pop("stream", {builder.has_stream_response})'
648
- if builder.expose_stream_keyword
646
+ if builder.expose_stream_keyword and builder.has_response_body
649
647
  else builder.has_stream_response
650
648
  )
651
649
  return [
@@ -730,7 +728,7 @@ class _OperationSerializer(
730
728
  return [
731
729
  response_str,
732
730
  rtype_str,
733
- f":raises ~{builder.init_file_import().import_core_exceptions}.HttpResponseError:",
731
+ f":raises ~{self.code_model.core_library}.exceptions.HttpResponseError:",
734
732
  ]
735
733
 
736
734
  def _serialize_body_parameter(self, builder: OperationType) -> List[str]:
@@ -763,13 +761,13 @@ class _OperationSerializer(
763
761
  if hasattr(body_param.type, "encode") and body_param.type.encode: # type: ignore
764
762
  create_body_call = (
765
763
  f"_{body_kwarg_name} = json.dumps({body_param.client_name}, "
766
- "cls=AzureJSONEncoder, exclude_readonly=True, "
764
+ "cls=SdkJSONEncoder, exclude_readonly=True, "
767
765
  f"format='{body_param.type.encode}') # type: ignore" # type: ignore
768
766
  )
769
767
  else:
770
768
  create_body_call = (
771
769
  f"_{body_kwarg_name} = json.dumps({body_param.client_name}, "
772
- "cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore"
770
+ "cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore"
773
771
  )
774
772
  else:
775
773
  create_body_call = f"_{body_kwarg_name} = {body_param.client_name}"
@@ -1082,12 +1080,18 @@ class _OperationSerializer(
1082
1080
  if builder.has_stream_response:
1083
1081
  deserialize_code.append("deserialized = response.content")
1084
1082
  else:
1083
+ format_filed = (
1084
+ f', format="{response.type.encode}"'
1085
+ if isinstance(response.type, ByteArraySchema)
1086
+ and response.default_content_type == "application/json"
1087
+ else ""
1088
+ )
1085
1089
  deserialize_code.append("deserialized = _deserialize(")
1086
1090
  deserialize_code.append(
1087
1091
  f" {response.type.type_annotation(is_operation_file=True)},{pylint_disable}"
1088
1092
  )
1089
1093
  deserialize_code.append(
1090
- f" response.json(){response.result_property}"
1094
+ f" response.json(){response.result_property}{format_filed}"
1091
1095
  )
1092
1096
  deserialize_code.append(")")
1093
1097
 
@@ -253,7 +253,7 @@ class ClientSerializer:
253
253
 
254
254
  def _rest_request_example(self, async_mode: bool) -> List[str]:
255
255
  retval = [
256
- f">>> from {self.client.init_file_import().import_core_rest} import HttpRequest"
256
+ f">>> from {self.client.code_model.core_library}.rest import HttpRequest"
257
257
  ]
258
258
  retval.append('>>> request = HttpRequest("GET", "https://www.example.org/")')
259
259
  retval.append("<HttpRequest [GET], url: 'https://www.example.org/'>")
@@ -261,6 +261,7 @@ class ClientSerializer:
261
261
  return retval
262
262
 
263
263
  def send_request_description(self, async_mode: bool) -> List[str]:
264
+ rest_library = f"{self.client.code_model.core_library}.rest"
264
265
  retval = ['"""Runs the network request through the client\'s chained policies.']
265
266
  retval.append("")
266
267
  if self.client.code_model.options["builders_visibility"] != "embedded":
@@ -273,9 +274,7 @@ class ClientSerializer:
273
274
  )
274
275
  retval.append("")
275
276
  retval.append(":param request: The network request you want to make. Required.")
276
- retval.append(
277
- f":type request: ~{self.client.init_file_import().import_core_rest}.HttpRequest"
278
- )
277
+ retval.append(f":type request: ~{rest_library}.HttpRequest")
279
278
  retval.append(
280
279
  ":keyword bool stream: Whether the response payload will be streamed. Defaults to False."
281
280
  )
@@ -283,9 +282,7 @@ class ClientSerializer:
283
282
  ":return: The response of your network call. Does not do error handling on your response."
284
283
  )
285
284
  http_response = "AsyncHttpResponse" if async_mode else "HttpResponse"
286
- retval.append(
287
- f":rtype: ~{self.client.init_file_import().import_core_rest}.{http_response}"
288
- )
285
+ retval.append(f":rtype: ~{rest_library}.{http_response}")
289
286
  retval.append('"""')
290
287
  return retval
291
288
 
@@ -5,6 +5,7 @@
5
5
  # --------------------------------------------------------------------------
6
6
 
7
7
  from .base_serializer import BaseSerializer
8
+ from ..models import FileImport
8
9
 
9
10
 
10
11
  class EnumSerializer(BaseSerializer):
@@ -12,5 +13,5 @@ class EnumSerializer(BaseSerializer):
12
13
  # Generate the enum file
13
14
  template = self.env.get_template("enum_container.py.jinja2")
14
15
  return template.render(
15
- code_model=self.code_model, file_import=self.init_file_import()
16
+ code_model=self.code_model, file_import=FileImport(self.code_model)
16
17
  )
@@ -6,7 +6,7 @@
6
6
  from typing import Any, List
7
7
  from jinja2 import Environment
8
8
  from .import_serializer import FileImportSerializer, TypingSection
9
- from ..models.imports import MsrestImportType
9
+ from ..models.imports import MsrestImportType, FileImport
10
10
  from ..models import (
11
11
  ImportType,
12
12
  CodeModel,
@@ -59,7 +59,7 @@ class GeneralSerializer(BaseSerializer):
59
59
  }
60
60
  params.update(self.code_model.options)
61
61
  params.update(kwargs)
62
- return template.render(file_import=self.init_file_import(), **params)
62
+ return template.render(file_import=FileImport(self.code_model), **params)
63
63
 
64
64
  def serialize_pkgutil_init_file(self) -> str:
65
65
  template = self.env.get_template("pkgutil_init.py.jinja2")
@@ -76,7 +76,7 @@ class GeneralSerializer(BaseSerializer):
76
76
  def serialize_service_client_file(self, clients: List[Client]) -> str:
77
77
  template = self.env.get_template("client_container.py.jinja2")
78
78
 
79
- imports = self.init_file_import()
79
+ imports = FileImport(self.code_model)
80
80
  for client in clients:
81
81
  imports.merge(client.imports(self.async_mode))
82
82
 
@@ -92,7 +92,7 @@ class GeneralSerializer(BaseSerializer):
92
92
  template = self.env.get_template("vendor.py.jinja2")
93
93
 
94
94
  # configure imports
95
- file_import = self.init_file_import()
95
+ file_import = FileImport(self.code_model)
96
96
  if self.code_model.need_request_converter:
97
97
  file_import.add_submodule_import(
98
98
  "azure.core.pipeline.transport",
@@ -107,7 +107,7 @@ class GeneralSerializer(BaseSerializer):
107
107
  ImportType.STDLIB,
108
108
  )
109
109
  file_import.add_submodule_import(
110
- file_import.import_core_pipeline_client,
110
+ "runtime" if self.code_model.options["unbranded"] else "",
111
111
  f"{'Async' if self.async_mode else ''}PipelineClient",
112
112
  ImportType.SDKCORE,
113
113
  TypingSection.TYPING,
@@ -126,7 +126,7 @@ class GeneralSerializer(BaseSerializer):
126
126
  if self.code_model.has_etag:
127
127
  file_import.add_submodule_import("typing", "Optional", ImportType.STDLIB)
128
128
  file_import.add_submodule_import(
129
- file_import.import_core,
129
+ "",
130
130
  "MatchConditions",
131
131
  ImportType.SDKCORE,
132
132
  )
@@ -142,7 +142,7 @@ class GeneralSerializer(BaseSerializer):
142
142
 
143
143
  def serialize_config_file(self, clients: List[Client]) -> str:
144
144
  template = self.env.get_template("config_container.py.jinja2")
145
- imports = self.init_file_import()
145
+ imports = FileImport(self.code_model)
146
146
  for client in self.code_model.clients:
147
147
  imports.merge(client.config.imports(self.async_mode))
148
148
  return template.render(
@@ -160,14 +160,13 @@ class GeneralSerializer(BaseSerializer):
160
160
  def serialize_serialization_file(self) -> str:
161
161
  template = self.env.get_template("serialization.py.jinja2")
162
162
  return template.render(
163
- import_core_exceptions=self.init_file_import().import_core_exceptions,
164
- import_core_serialization=self.init_file_import().import_core_serialization,
163
+ code_model=self.code_model,
165
164
  )
166
165
 
167
166
  def serialize_model_base_file(self) -> str:
168
167
  template = self.env.get_template("model_base.py.jinja2")
169
168
  return template.render(
170
- code_model=self.code_model, file_import=self.init_file_import()
169
+ code_model=self.code_model, file_import=FileImport(self.code_model)
171
170
  )
172
171
 
173
172
  def serialize_validation_file(self) -> str:
@@ -129,7 +129,7 @@ class _ModelSerializer(BaseSerializer, ABC):
129
129
 
130
130
  class MsrestModelSerializer(_ModelSerializer):
131
131
  def imports(self) -> FileImport:
132
- file_import = self.init_file_import()
132
+ file_import = FileImport(self.code_model)
133
133
  file_import.add_msrest_import(
134
134
  relative_path="..",
135
135
  msrest_import_type=MsrestImportType.Module,
@@ -198,7 +198,7 @@ class MsrestModelSerializer(_ModelSerializer):
198
198
 
199
199
  class DpgModelSerializer(_ModelSerializer):
200
200
  def imports(self) -> FileImport:
201
- file_import = self.init_file_import()
201
+ file_import = FileImport(self.code_model)
202
202
  file_import.add_submodule_import(
203
203
  "..",
204
204
  "_model_base",
@@ -13,6 +13,7 @@ from ..models import (
13
13
  RequestBuilder,
14
14
  OverloadedRequestBuilder,
15
15
  Client,
16
+ FileImport,
16
17
  )
17
18
  from .import_serializer import FileImportSerializer
18
19
  from .builder_serializer import (
@@ -56,7 +57,7 @@ class OperationGroupsSerializer(BaseSerializer):
56
57
  if self.operation_group
57
58
  else [og for client in self.clients for og in client.operation_groups]
58
59
  )
59
- imports = self.init_file_import()
60
+ imports = FileImport(self.code_model)
60
61
  for operation_group in operation_groups:
61
62
  imports.merge(
62
63
  operation_group.imports(
@@ -4,14 +4,14 @@
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
6
  from .import_serializer import FileImportSerializer
7
- from ..models import ImportType
7
+ from ..models import ImportType, FileImport
8
8
  from .base_serializer import BaseSerializer
9
9
 
10
10
 
11
11
  class PatchSerializer(BaseSerializer):
12
12
  def serialize(self) -> str:
13
13
  template = self.env.get_template("patch.py.jinja2")
14
- imports = self.init_file_import()
14
+ imports = FileImport(self.code_model)
15
15
  imports.add_submodule_import("typing", "List", ImportType.STDLIB)
16
16
  return template.render(
17
17
  code_model=self.code_model,
@@ -26,7 +26,7 @@ class RequestBuildersSerializer(BaseSerializer):
26
26
 
27
27
  @property
28
28
  def imports(self) -> FileImport:
29
- file_import = self.init_file_import()
29
+ file_import = FileImport(self.code_model)
30
30
  for request_builder in self.request_builders:
31
31
  if request_builder.group_name == self.group_name:
32
32
  file_import.merge(request_builder.imports())
@@ -8,15 +8,19 @@ import logging
8
8
  from typing import Dict, Any, Union, Tuple
9
9
  from jinja2 import Environment
10
10
 
11
- from autorest.codegen.models.credential_types import KeyCredentialType
12
- from autorest.codegen.models.credential_types import TokenCredentialType
13
- from autorest.codegen.models.imports import ImportType
14
11
  from autorest.codegen.models.operation import OperationBase
15
- from autorest.codegen.models.operation_group import OperationGroup
16
- from autorest.codegen.models.parameter import Parameter, BodyParameter
17
- from autorest.codegen.serializers.import_serializer import FileImportSerializer
18
- from autorest.codegen.serializers.base_serializer import BaseSerializer
19
- from ..models import CodeModel
12
+ from .import_serializer import FileImportSerializer
13
+ from .base_serializer import BaseSerializer
14
+ from ..models import (
15
+ CodeModel,
16
+ KeyCredentialType,
17
+ TokenCredentialType,
18
+ ImportType,
19
+ OperationGroup,
20
+ Parameter,
21
+ BodyParameter,
22
+ FileImport,
23
+ )
20
24
  from .utils import get_namespace_config, get_namespace_from_package_name
21
25
  from ..._utils import to_snake_case
22
26
 
@@ -43,7 +47,7 @@ class SampleSerializer(BaseSerializer):
43
47
  }
44
48
 
45
49
  def _imports(self) -> FileImportSerializer:
46
- imports = self.init_file_import()
50
+ imports = FileImport(self.code_model)
47
51
  namespace_from_package_name = get_namespace_from_package_name(
48
52
  self.code_model.options["package_name"]
49
53
  )
@@ -57,18 +61,18 @@ class SampleSerializer(BaseSerializer):
57
61
  ) > namespace_from_package_name.count("."):
58
62
  namespace = namespace_config
59
63
  client = self.code_model.clients[0]
60
- imports.add_submodule_import(namespace, client.name, ImportType.THIRDPARTY)
64
+ imports.add_submodule_import(namespace, client.name, ImportType.LOCAL)
61
65
  credential_type = getattr(client.credential, "type", None)
62
66
  if isinstance(credential_type, TokenCredentialType):
63
67
  imports.add_submodule_import(
64
- "azure.identity", "DefaultAzureCredential", ImportType.THIRDPARTY
68
+ "azure.identity", "DefaultAzureCredential", ImportType.SDKCORE
65
69
  )
66
70
  elif isinstance(credential_type, KeyCredentialType):
67
71
  imports.add_import("os", ImportType.STDLIB)
68
72
  imports.add_submodule_import(
69
- f"{imports.import_core}.credentials",
73
+ "credentials",
70
74
  "AzureKeyCredential",
71
- ImportType.THIRDPARTY,
75
+ ImportType.SDKCORE,
72
76
  )
73
77
  for param in self.operation.parameters.positional:
74
78
  if (
@@ -10,7 +10,7 @@ from .base_serializer import BaseSerializer
10
10
 
11
11
  class TypesSerializer(BaseSerializer):
12
12
  def imports(self) -> FileImport:
13
- file_import = self.init_file_import()
13
+ file_import = FileImport(self.code_model)
14
14
  if self.code_model.named_unions:
15
15
  file_import.add_submodule_import(
16
16
  "typing",
@@ -2,7 +2,7 @@
2
2
  {{ code_model.options['license_header'] }}
3
3
 
4
4
  from enum import Enum
5
- from {{ file_import.import_core_case_insensitive_enum }} import CaseInsensitiveEnumMeta
5
+ from {{ code_model.core_library }}{{ ".utils" if code_model.options["unbranded"] else "" }} import CaseInsensitiveEnumMeta
6
6
 
7
7
  {% for enum in code_model.enums | sort %}
8
8
  {% include "enum.py.jinja2" %}