@autorest/python 6.9.3 → 6.9.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. package/autorest/_utils.py +35 -20
  2. package/autorest/codegen/__init__.py +8 -1
  3. package/autorest/codegen/models/base.py +3 -4
  4. package/autorest/codegen/models/client.py +5 -6
  5. package/autorest/codegen/models/combined_type.py +1 -1
  6. package/autorest/codegen/models/constant_type.py +1 -1
  7. package/autorest/codegen/models/credential_types.py +2 -2
  8. package/autorest/codegen/models/dictionary_type.py +1 -1
  9. package/autorest/codegen/models/enum_type.py +2 -2
  10. package/autorest/codegen/models/imports.py +6 -5
  11. package/autorest/codegen/models/list_type.py +1 -1
  12. package/autorest/codegen/models/lro_operation.py +3 -0
  13. package/autorest/codegen/models/model_type.py +1 -1
  14. package/autorest/codegen/models/operation.py +4 -4
  15. package/autorest/codegen/models/operation_group.py +2 -2
  16. package/autorest/codegen/models/paging_operation.py +1 -1
  17. package/autorest/codegen/models/parameter.py +1 -1
  18. package/autorest/codegen/models/primitive_types.py +20 -25
  19. package/autorest/codegen/models/property.py +3 -1
  20. package/autorest/codegen/models/request_builder.py +4 -5
  21. package/autorest/codegen/models/response.py +1 -1
  22. package/autorest/codegen/serializers/builder_serializer.py +6 -3
  23. package/autorest/codegen/serializers/client_serializer.py +1 -1
  24. package/autorest/codegen/serializers/general_serializer.py +6 -7
  25. package/autorest/codegen/serializers/model_serializer.py +5 -3
  26. package/autorest/codegen/serializers/operation_groups_serializer.py +1 -1
  27. package/autorest/codegen/serializers/patch_serializer.py +1 -1
  28. package/autorest/codegen/serializers/request_builders_serializer.py +1 -1
  29. package/autorest/codegen/serializers/sample_serializer.py +1 -1
  30. package/autorest/codegen/serializers/types_serializer.py +1 -1
  31. package/autorest/codegen/templates/model_base.py.jinja2 +1 -1
  32. package/autorest/jsonrpc/server.py +5 -1
  33. package/autorest/multiapi/models/imports.py +17 -15
  34. package/package.json +1 -1
@@ -120,24 +120,39 @@ KNOWN_TYPES: Dict[str, Dict[str, Any]] = {
120
120
  JSON_REGEXP = re.compile(r"^(application|text)/(.+\+)?json$")
121
121
 
122
122
 
123
- def build_policies(is_arm: bool, async_mode: bool) -> List[str]:
124
- async_prefix = "Async" if async_mode else ""
125
- policies = [
126
- "policies.RequestIdPolicy(**kwargs)",
127
- "self._config.headers_policy",
128
- "self._config.user_agent_policy",
129
- "self._config.proxy_policy",
130
- "policies.ContentDecodePolicy(**kwargs)",
131
- f"{async_prefix}ARMAutoResourceProviderRegistrationPolicy()"
132
- if is_arm
133
- else None,
134
- "self._config.redirect_policy",
135
- "self._config.retry_policy",
136
- "self._config.authentication_policy",
137
- "self._config.custom_hook_policy",
138
- "self._config.logging_policy",
139
- "policies.DistributedTracingPolicy(**kwargs)",
140
- "policies.SensitiveHeaderCleanupPolicy(**kwargs) if self._config.redirect_policy else None",
141
- "self._config.http_logging_policy",
142
- ]
123
+ def build_policies(
124
+ is_arm: bool, async_mode: bool, unbranded: bool = False
125
+ ) -> List[str]:
126
+ if not unbranded:
127
+ # for Azure
128
+ async_prefix = "Async" if async_mode else ""
129
+ policies = [
130
+ "policies.RequestIdPolicy(**kwargs)",
131
+ "self._config.headers_policy",
132
+ "self._config.user_agent_policy",
133
+ "self._config.proxy_policy",
134
+ "policies.ContentDecodePolicy(**kwargs)",
135
+ f"{async_prefix}ARMAutoResourceProviderRegistrationPolicy()"
136
+ if is_arm
137
+ else None,
138
+ "self._config.redirect_policy",
139
+ "self._config.retry_policy",
140
+ "self._config.authentication_policy",
141
+ "self._config.custom_hook_policy",
142
+ "self._config.logging_policy",
143
+ "policies.DistributedTracingPolicy(**kwargs)",
144
+ "policies.SensitiveHeaderCleanupPolicy(**kwargs) if self._config.redirect_policy else None",
145
+ "self._config.http_logging_policy",
146
+ ]
147
+ else:
148
+ # for non-Azure
149
+ policies = [
150
+ "self._config.headers_policy",
151
+ "self._config.user_agent_policy",
152
+ "self._config.proxy_policy",
153
+ "policies.ContentDecodePolicy(**kwargs)",
154
+ "self._config.retry_policy",
155
+ "self._config.authentication_policy",
156
+ "self._config.logging_policy",
157
+ ]
143
158
  return [p for p in policies if p]
@@ -85,6 +85,11 @@ def _validate_code_model_options(options: Dict[str, Any]) -> None:
85
85
  "Please read https://aka.ms/azsdk/dpcodegen for more details."
86
86
  )
87
87
 
88
+ if options["unbranded"] and options["tracing"]:
89
+ raise ValueError(
90
+ "Can not set --unbranded=true and --tracing=true at the same time."
91
+ )
92
+
88
93
 
89
94
  _LOGGER = logging.getLogger(__name__)
90
95
 
@@ -137,6 +142,7 @@ class CodeGenerator(Plugin):
137
142
  models_mode_default = "dpg"
138
143
 
139
144
  package_name = self.options.get("package-name")
145
+ unbranded = self.options.get("unbranded", False)
140
146
  options: Dict[str, Any] = {
141
147
  "azure_arm": azure_arm,
142
148
  "head_as_boolean": self.options.get("head-as-boolean", True),
@@ -148,7 +154,7 @@ class CodeGenerator(Plugin):
148
154
  "package_name": package_name,
149
155
  "package_version": self.options.get("package-version"),
150
156
  "client_side_validation": self.options.get("client-side-validation", False),
151
- "tracing": self.options.get("tracing", show_operations),
157
+ "tracing": self.options.get("tracing", show_operations and not unbranded),
152
158
  "multiapi": self.options.get("multiapi", False),
153
159
  "polymorphic_examples": self.options.get("polymorphic-examples", 5),
154
160
  "models_mode": self.options.get("models-mode", models_mode_default).lower(),
@@ -177,6 +183,7 @@ class CodeGenerator(Plugin):
177
183
  "generate_sample": self.options.get("generate-sample", False),
178
184
  "default_api_version": self.options.get("default-api-version"),
179
185
  "from_typespec": self.options.get("from-typespec", False),
186
+ "unbranded": unbranded,
180
187
  }
181
188
 
182
189
  if options["builders_visibility"] is None:
@@ -53,14 +53,13 @@ class BaseType(BaseModel, ABC): # pylint: disable=too-many-public-methods
53
53
  return cls(yaml_data=yaml_data, code_model=code_model)
54
54
 
55
55
  def imports(self, **kwargs) -> FileImport: # pylint: disable=unused-argument
56
- return FileImport()
56
+ return FileImport(self.code_model)
57
57
 
58
58
  def imports_for_multiapi(self, **kwargs: Any) -> FileImport:
59
59
  return self.imports(**kwargs)
60
60
 
61
- @staticmethod
62
- def imports_for_sample() -> FileImport:
63
- return FileImport()
61
+ def imports_for_sample(self) -> FileImport:
62
+ return FileImport(self.code_model)
64
63
 
65
64
  @staticmethod
66
65
  def serialize_sample_value(value: Any) -> str:
@@ -195,7 +195,7 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
195
195
  raise KeyError(f"No operation with id {operation_id} found.") from exc
196
196
 
197
197
  def _imports_shared(self, async_mode: bool) -> FileImport:
198
- file_import = FileImport()
198
+ file_import = FileImport(self.code_model)
199
199
 
200
200
  file_import.add_submodule_import(
201
201
  "typing", "Any", ImportType.STDLIB, TypingSection.CONDITIONAL
@@ -225,10 +225,9 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
225
225
  ImportType.LOCAL,
226
226
  )
227
227
  file_import.add_msrest_import(
228
- self.code_model,
229
- ".." if async_mode else ".",
230
- MsrestImportType.SerializerDeserializer,
231
- TypingSection.REGULAR,
228
+ relative_path=".." if async_mode else ".",
229
+ msrest_import_type=MsrestImportType.SerializerDeserializer,
230
+ typing_section=TypingSection.REGULAR,
232
231
  )
233
232
  file_import.add_submodule_import(
234
233
  "azure.core.pipeline", "policies", ImportType.AZURECORE
@@ -406,7 +405,7 @@ class Config(_ClientConfigBase[ConfigGlobalParameterList]):
406
405
  return f"{super().name}Configuration"
407
406
 
408
407
  def _imports_shared(self, async_mode: bool) -> FileImport:
409
- file_import = FileImport()
408
+ file_import = FileImport(self.code_model)
410
409
  file_import.add_submodule_import(
411
410
  "azure.core.pipeline", "policies", ImportType.AZURECORE
412
411
  )
@@ -103,7 +103,7 @@ class CombinedType(BaseType):
103
103
  raise ValueError("You shouldn't do instance checks on a multiple type")
104
104
 
105
105
  def imports(self, **kwargs: Any) -> FileImport:
106
- file_import = FileImport()
106
+ file_import = FileImport(self.code_model)
107
107
  if self.name and not kwargs.get("is_types_file"):
108
108
  file_import.add_submodule_import(
109
109
  kwargs.pop("relative_path"),
@@ -124,7 +124,7 @@ class ConstantType(BaseType):
124
124
  )
125
125
 
126
126
  def _imports_shared(self, **kwargs: Any):
127
- file_import = FileImport()
127
+ file_import = FileImport(self.code_model)
128
128
  file_import.merge(self.value_type.imports(**kwargs))
129
129
  return file_import
130
130
 
@@ -182,7 +182,7 @@ class TokenCredentialType(
182
182
  return "~azure.core.credentials.TokenCredential"
183
183
 
184
184
  def imports(self, **kwargs: Any) -> FileImport:
185
- file_import = FileImport()
185
+ file_import = FileImport(self.code_model)
186
186
  if kwargs.get("async_mode"):
187
187
  file_import.add_submodule_import(
188
188
  "azure.core.credentials_async",
@@ -221,7 +221,7 @@ class AzureKeyCredentialType(
221
221
  return "isinstance({}, AzureKeyCredential)"
222
222
 
223
223
  def imports(self, **kwargs: Any) -> FileImport: # pylint: disable=unused-argument
224
- file_import = FileImport()
224
+ file_import = FileImport(self.code_model)
225
225
  file_import.add_submodule_import(
226
226
  "azure.core.credentials",
227
227
  "AzureKeyCredential",
@@ -123,7 +123,7 @@ class DictionaryType(BaseType):
123
123
  )
124
124
 
125
125
  def imports(self, **kwargs: Any) -> FileImport:
126
- file_import = FileImport()
126
+ file_import = FileImport(self.code_model)
127
127
  file_import.add_submodule_import(
128
128
  "typing", "Dict", ImportType.STDLIB, TypingSection.CONDITIONAL
129
129
  )
@@ -76,7 +76,7 @@ class EnumValue(BaseType):
76
76
  return self.value_type.instance_check_template
77
77
 
78
78
  def imports(self, **kwargs: Any) -> FileImport:
79
- file_import = FileImport()
79
+ file_import = FileImport(self.code_model)
80
80
  file_import.merge(self.value_type.imports(**kwargs))
81
81
  file_import.add_literal_import()
82
82
  file_import.add_submodule_import(
@@ -230,7 +230,7 @@ class EnumType(BaseType):
230
230
 
231
231
  def imports(self, **kwargs: Any) -> FileImport:
232
232
  operation = kwargs.pop("operation", False)
233
- file_import = FileImport()
233
+ file_import = FileImport(self.code_model)
234
234
  if self.code_model.options["models_mode"]:
235
235
  file_import.add_submodule_import(
236
236
  "typing", "Union", ImportType.STDLIB, TypingSection.CONDITIONAL
@@ -88,8 +88,9 @@ class TypeDefinition:
88
88
 
89
89
 
90
90
  class FileImport:
91
- def __init__(self, imports: Optional[List[ImportModel]] = None) -> None:
92
- self.imports = imports or []
91
+ def __init__(self, code_model: "CodeModel") -> None:
92
+ self.imports: List[ImportModel] = []
93
+ self.code_model = code_model
93
94
  # has sync and async type definitions
94
95
  self.type_definitions: Dict[str, TypeDefinition] = {}
95
96
 
@@ -272,12 +273,12 @@ class FileImport:
272
273
 
273
274
  def add_msrest_import(
274
275
  self,
275
- code_model: "CodeModel",
276
+ *,
276
277
  relative_path: str,
277
278
  msrest_import_type: MsrestImportType,
278
279
  typing_section: TypingSection,
279
280
  ):
280
- if code_model.options["client_side_validation"]:
281
+ if self.code_model.options["client_side_validation"]:
281
282
  if msrest_import_type == MsrestImportType.Module:
282
283
  self.add_import(
283
284
  "msrest.serialization", ImportType.AZURECORE, typing_section
@@ -291,7 +292,7 @@ class FileImport:
291
292
  "msrest", "Deserializer", ImportType.THIRDPARTY, typing_section
292
293
  )
293
294
  else:
294
- if code_model.options["multiapi"]:
295
+ if self.code_model.options["multiapi"]:
295
296
  relative_path += "."
296
297
  if msrest_import_type == MsrestImportType.Module:
297
298
  self.add_submodule_import(
@@ -146,7 +146,7 @@ class ListType(BaseType):
146
146
  )
147
147
 
148
148
  def imports(self, **kwargs: Any) -> FileImport:
149
- file_import = FileImport()
149
+ file_import = FileImport(self.code_model)
150
150
  if not (
151
151
  self.code_model.options["version_tolerant"]
152
152
  and self.element_type.is_xml
@@ -104,6 +104,9 @@ class LROOperationBase(OperationBase[LROResponseType]):
104
104
  """We don't want the poller to show up in ClsType, so we call super() on resposne type annotation"""
105
105
  return f"ClsType[{Response.type_annotation(self.responses[0], async_mode=async_mode)}]"
106
106
 
107
+ def get_poller_with_response_type(self, async_mode: bool) -> str:
108
+ return self.response_type_annotation(async_mode=async_mode)
109
+
107
110
  def get_poller(self, async_mode: bool) -> str:
108
111
  return self.responses[0].get_poller(async_mode)
109
112
 
@@ -268,7 +268,7 @@ class JSONModelType(ModelType):
268
268
  return "isinstance({}, MutableMapping)"
269
269
 
270
270
  def imports(self, **kwargs: Any) -> FileImport:
271
- file_import = FileImport()
271
+ file_import = FileImport(self.code_model)
272
272
  file_import.add_submodule_import(
273
273
  "typing", "Any", ImportType.STDLIB, TypingSection.CONDITIONAL
274
274
  )
@@ -230,7 +230,7 @@ class OperationBase( # pylint: disable=too-many-public-methods
230
230
  def _imports_shared(
231
231
  self, async_mode: bool, **kwargs: Any # pylint: disable=unused-argument
232
232
  ) -> FileImport:
233
- file_import = FileImport()
233
+ file_import = FileImport(self.code_model)
234
234
  file_import.add_submodule_import(
235
235
  "typing", "Any", ImportType.STDLIB, TypingSection.CONDITIONAL
236
236
  )
@@ -254,7 +254,7 @@ class OperationBase( # pylint: disable=too-many-public-methods
254
254
 
255
255
  def imports_for_multiapi(self, async_mode: bool, **kwargs: Any) -> FileImport:
256
256
  if self.abstract:
257
- return FileImport()
257
+ return FileImport(self.code_model)
258
258
  file_import = self._imports_shared(async_mode, **kwargs)
259
259
  for param in self.parameters.method:
260
260
  file_import.merge(
@@ -308,7 +308,7 @@ class OperationBase( # pylint: disable=too-many-public-methods
308
308
  async_mode: bool,
309
309
  ) -> FileImport:
310
310
  """Helper method to get a request builder import."""
311
- file_import = FileImport()
311
+ file_import = FileImport(self.code_model)
312
312
  if self.code_model.options["builders_visibility"] != "embedded":
313
313
  group_name = request_builder.group_name
314
314
  rest_import_path = "..." if async_mode else ".."
@@ -338,7 +338,7 @@ class OperationBase( # pylint: disable=too-many-public-methods
338
338
  self, async_mode: bool, **kwargs: Any
339
339
  ) -> FileImport:
340
340
  if self.abstract:
341
- return FileImport()
341
+ return FileImport(self.code_model)
342
342
  file_import = self._imports_shared(async_mode, **kwargs)
343
343
 
344
344
  for param in self.parameters.method:
@@ -48,7 +48,7 @@ class OperationGroup(BaseModel):
48
48
  return ", ".join(base_classes)
49
49
 
50
50
  def imports_for_multiapi(self, async_mode: bool) -> FileImport:
51
- file_import = FileImport()
51
+ file_import = FileImport(self.code_model)
52
52
  relative_path = ".." if async_mode else "."
53
53
  for operation in self.operations:
54
54
  file_import.merge(
@@ -79,7 +79,7 @@ class OperationGroup(BaseModel):
79
79
  return any(o for o in self.operations if o.need_validation)
80
80
 
81
81
  def imports(self, async_mode: bool) -> FileImport:
82
- file_import = FileImport()
82
+ file_import = FileImport(self.code_model)
83
83
 
84
84
  relative_path = ("..." if async_mode else "..") + (
85
85
  "." if self.client.is_subclient else ""
@@ -144,7 +144,7 @@ class PagingOperationBase(OperationBase[PagingResponseType]):
144
144
 
145
145
  def imports(self, async_mode: bool, **kwargs: Any) -> FileImport:
146
146
  if self.abstract:
147
- return FileImport()
147
+ return FileImport(self.code_model)
148
148
  file_import = self._imports_shared(async_mode, **kwargs)
149
149
  file_import.merge(super().imports(async_mode, **kwargs))
150
150
  if self.code_model.options["tracing"] and self.want_tracing:
@@ -168,7 +168,7 @@ class _ParameterBase(
168
168
  return self.type.serialization_type
169
169
 
170
170
  def _imports_shared(self, async_mode: bool, **_: Any) -> FileImport:
171
- file_import = FileImport()
171
+ file_import = FileImport(self.code_model)
172
172
  if self.optional and self.client_default_value is None:
173
173
  file_import.add_submodule_import("typing", "Optional", ImportType.STDLIB)
174
174
  if self.added_on:
@@ -107,7 +107,7 @@ class BinaryType(PrimitiveType):
107
107
  from .combined_type import CombinedType
108
108
  from .operation import OperationBase
109
109
 
110
- file_import = FileImport()
110
+ file_import = FileImport(self.code_model)
111
111
  file_import.add_submodule_import("typing", "IO", ImportType.STDLIB)
112
112
  operation = kwargs.get("operation")
113
113
  if (
@@ -145,7 +145,7 @@ class BinaryIteratorType(PrimitiveType):
145
145
  return self.get_declaration("Iterator[bytes]")
146
146
 
147
147
  def imports(self, **kwargs: Any) -> FileImport:
148
- file_import = FileImport()
148
+ file_import = FileImport(self.code_model)
149
149
  iterator = "AsyncIterator" if kwargs.get("async_mode") else "Iterator"
150
150
  file_import.add_submodule_import("typing", iterator, ImportType.STDLIB)
151
151
  return file_import
@@ -171,7 +171,7 @@ class AnyType(PrimitiveType):
171
171
  return self.get_declaration({})
172
172
 
173
173
  def imports(self, **kwargs: Any) -> FileImport:
174
- file_import = FileImport()
174
+ file_import = FileImport(self.code_model)
175
175
  file_import.add_submodule_import(
176
176
  "typing", "Any", ImportType.STDLIB, TypingSection.CONDITIONAL
177
177
  )
@@ -204,7 +204,7 @@ class AnyObjectType(PrimitiveType):
204
204
  return "isinstance({}, MutableMapping)"
205
205
 
206
206
  def imports(self, **kwargs: Any) -> FileImport:
207
- file_import = FileImport()
207
+ file_import = FileImport(self.code_model)
208
208
  file_import.define_mutable_mapping_type()
209
209
  return file_import
210
210
 
@@ -386,7 +386,7 @@ class DatetimeType(PrimitiveType):
386
386
  return f'"{value}"'
387
387
 
388
388
  def imports(self, **kwargs: Any) -> FileImport:
389
- file_import = FileImport()
389
+ file_import = FileImport(self.code_model)
390
390
  file_import.add_import("datetime", ImportType.STDLIB)
391
391
  return file_import
392
392
 
@@ -398,9 +398,8 @@ class DatetimeType(PrimitiveType):
398
398
  def instance_check_template(self) -> str:
399
399
  return "isinstance({}, datetime.datetime)"
400
400
 
401
- @staticmethod
402
- def imports_for_sample() -> FileImport:
403
- file_import = super(DatetimeType, DatetimeType).imports_for_sample()
401
+ def imports_for_sample(self) -> FileImport:
402
+ file_import = super().imports_for_sample()
404
403
  file_import.add_import("isodate", ImportType.STDLIB)
405
404
  return file_import
406
405
 
@@ -430,7 +429,7 @@ class TimeType(PrimitiveType):
430
429
  return f'"{value}"'
431
430
 
432
431
  def imports(self, **kwargs: Any) -> FileImport:
433
- file_import = FileImport()
432
+ file_import = FileImport(self.code_model)
434
433
  file_import.add_import("datetime", ImportType.STDLIB)
435
434
  return file_import
436
435
 
@@ -442,9 +441,8 @@ class TimeType(PrimitiveType):
442
441
  def instance_check_template(self) -> str:
443
442
  return "isinstance({}, datetime.time)"
444
443
 
445
- @staticmethod
446
- def imports_for_sample() -> FileImport:
447
- file_import = super(TimeType, TimeType).imports_for_sample()
444
+ def imports_for_sample(self) -> FileImport:
445
+ file_import = super().imports_for_sample()
448
446
  file_import.add_import("isodate", ImportType.STDLIB)
449
447
  return file_import
450
448
 
@@ -478,7 +476,7 @@ class UnixTimeType(PrimitiveType):
478
476
  return f'"{value}"'
479
477
 
480
478
  def imports(self, **kwargs: Any) -> FileImport:
481
- file_import = FileImport()
479
+ file_import = FileImport(self.code_model)
482
480
  file_import.add_import("datetime", ImportType.STDLIB)
483
481
  return file_import
484
482
 
@@ -490,9 +488,8 @@ class UnixTimeType(PrimitiveType):
490
488
  def instance_check_template(self) -> str:
491
489
  return "isinstance({}, datetime.time)"
492
490
 
493
- @staticmethod
494
- def imports_for_sample() -> FileImport:
495
- file_import = super(UnixTimeType, UnixTimeType).imports_for_sample()
491
+ def imports_for_sample(self) -> FileImport:
492
+ file_import = super().imports_for_sample()
496
493
  file_import.add_import("datetime", ImportType.STDLIB)
497
494
  return file_import
498
495
 
@@ -522,7 +519,7 @@ class DateType(PrimitiveType):
522
519
  return f'"{value}"'
523
520
 
524
521
  def imports(self, **kwargs: Any) -> FileImport:
525
- file_import = FileImport()
522
+ file_import = FileImport(self.code_model)
526
523
  file_import.add_import("datetime", ImportType.STDLIB)
527
524
  return file_import
528
525
 
@@ -534,9 +531,8 @@ class DateType(PrimitiveType):
534
531
  def instance_check_template(self) -> str:
535
532
  return "isinstance({}, datetime.date)"
536
533
 
537
- @staticmethod
538
- def imports_for_sample() -> FileImport:
539
- file_import = super(DateType, DateType).imports_for_sample()
534
+ def imports_for_sample(self) -> FileImport:
535
+ file_import = super().imports_for_sample()
540
536
  file_import.add_import("isodate", ImportType.STDLIB)
541
537
  return file_import
542
538
 
@@ -566,7 +562,7 @@ class DurationType(PrimitiveType):
566
562
  return f'"{value}"'
567
563
 
568
564
  def imports(self, **kwargs: Any) -> FileImport:
569
- file_import = FileImport()
565
+ file_import = FileImport(self.code_model)
570
566
  file_import.add_import("datetime", ImportType.STDLIB)
571
567
  return file_import
572
568
 
@@ -578,9 +574,8 @@ class DurationType(PrimitiveType):
578
574
  def instance_check_template(self) -> str:
579
575
  return "isinstance({}, datetime.timedelta)"
580
576
 
581
- @staticmethod
582
- def imports_for_sample() -> FileImport:
583
- file_import = super(DurationType, DurationType).imports_for_sample()
577
+ def imports_for_sample(self) -> FileImport:
578
+ file_import = super().imports_for_sample()
584
579
  file_import.add_import("isodate", ImportType.STDLIB)
585
580
  return file_import
586
581
 
@@ -623,7 +618,7 @@ class AzureCoreType(PrimitiveType):
623
618
  return self.name
624
619
 
625
620
  def imports(self, **kwargs: Any) -> FileImport:
626
- file_import = FileImport()
621
+ file_import = FileImport(self.code_model)
627
622
  file_import.add_submodule_import("azure.core", self.name, ImportType.AZURECORE)
628
623
  return file_import
629
624
 
@@ -120,6 +120,8 @@ class Property(BaseModel): # pylint: disable=too-many-instance-attributes
120
120
  )
121
121
  if self.description(is_operation_file=True):
122
122
  description = self.description(is_operation_file=True)
123
+ # make sure there is no \n otherwise the json template will be invalid
124
+ description = (description or "").replace("\n", " ")
123
125
  return self.type.get_json_template_representation(
124
126
  optional=self.optional,
125
127
  client_default_value_declaration=client_default_value_declaration,
@@ -145,7 +147,7 @@ class Property(BaseModel): # pylint: disable=too-many-instance-attributes
145
147
  return retval or None
146
148
 
147
149
  def imports(self, **kwargs) -> FileImport:
148
- file_import = FileImport()
150
+ file_import = FileImport(self.code_model)
149
151
  if self.is_discriminator and isinstance(self.type, EnumType):
150
152
  return file_import
151
153
  file_import.merge(
@@ -81,7 +81,7 @@ class RequestBuilderBase(BaseBuilder[ParameterListType]):
81
81
  return "~azure.core.rest.HttpRequest"
82
82
 
83
83
  def imports(self) -> FileImport:
84
- file_import = FileImport()
84
+ file_import = FileImport(self.code_model)
85
85
  relative_path = ".."
86
86
  if (
87
87
  not self.code_model.options["builders_visibility"] == "embedded"
@@ -111,15 +111,14 @@ class RequestBuilderBase(BaseBuilder[ParameterListType]):
111
111
  "typing", "Any", ImportType.STDLIB, typing_section=TypingSection.CONDITIONAL
112
112
  )
113
113
  file_import.add_msrest_import(
114
- self.code_model,
115
- "..."
114
+ relative_path="..."
116
115
  if (
117
116
  not self.code_model.options["builders_visibility"] == "embedded"
118
117
  and self.group_name
119
118
  )
120
119
  else "..",
121
- MsrestImportType.Serializer,
122
- TypingSection.REGULAR,
120
+ msrest_import_type=MsrestImportType.Serializer,
121
+ typing_section=TypingSection.REGULAR,
123
122
  )
124
123
  if (
125
124
  self.overloads
@@ -108,7 +108,7 @@ class Response(BaseModel):
108
108
  return self.type.docstring_type(**kwargs) if self.type else "None"
109
109
 
110
110
  def _imports_shared(self, **kwargs: Any) -> FileImport:
111
- file_import = FileImport()
111
+ file_import = FileImport(self.code_model)
112
112
  if self.type:
113
113
  file_import.merge(self.type.imports(**kwargs))
114
114
  if self.nullable:
@@ -1565,7 +1565,7 @@ class _LROOperationSerializer(_OperationSerializer[LROOperationType]):
1565
1565
  retval.append("else: polling_method = polling")
1566
1566
  retval.append("if cont_token:")
1567
1567
  retval.append(
1568
- f" return {builder.get_poller(self.async_mode)}.from_continuation_token("
1568
+ f" return {builder.get_poller_with_response_type(self.async_mode)}.from_continuation_token("
1569
1569
  )
1570
1570
  retval.append(" polling_method=polling_method,")
1571
1571
  retval.append(" continuation_token=cont_token,")
@@ -1573,9 +1573,12 @@ class _LROOperationSerializer(_OperationSerializer[LROOperationType]):
1573
1573
  retval.append(" deserialization_callback=get_long_running_output")
1574
1574
  retval.append(" )")
1575
1575
  retval.append(
1576
- f"return {builder.get_poller(self.async_mode)}"
1577
- "(self._client, raw_result, get_long_running_output, polling_method) # type: ignore"
1576
+ f"return {builder.get_poller_with_response_type(self.async_mode)}("
1578
1577
  )
1578
+ retval.append(
1579
+ " self._client, raw_result, get_long_running_output, polling_method # type: ignore"
1580
+ )
1581
+ retval.append(" )")
1579
1582
  return retval
1580
1583
 
1581
1584
  def get_long_running_output(self, builder: LROOperationType) -> List[str]:
@@ -123,7 +123,7 @@ class ClientSerializer:
123
123
  [
124
124
  "_policies = kwargs.pop('policies', None)",
125
125
  "if _policies is None:",
126
- f' _policies = [{",".join(build_policies(self.client.code_model.options["azure_arm"], async_mode))}]', # pylint: disable=line-too-long
126
+ f' _policies = [{",".join(build_policies(self.client.code_model.options["azure_arm"], async_mode, self.client.code_model.options["unbranded"]))}]', # pylint: disable=line-too-long
127
127
  f"self._client: {pipeline_client_name} = {pipeline_client_name}("
128
128
  f"{', '.join(f'{k}={v}' for k, v in params.items())}, **kwargs)",
129
129
  ]
@@ -85,7 +85,7 @@ class GeneralSerializer:
85
85
  def serialize_service_client_file(self, clients: List[Client]) -> str:
86
86
  template = self.env.get_template("client_container.py.jinja2")
87
87
 
88
- imports = FileImport()
88
+ imports = FileImport(self.code_model)
89
89
  for client in clients:
90
90
  imports.merge(client.imports(self.async_mode))
91
91
 
@@ -101,7 +101,7 @@ class GeneralSerializer:
101
101
  template = self.env.get_template("vendor.py.jinja2")
102
102
 
103
103
  # configure imports
104
- file_import = FileImport()
104
+ file_import = FileImport(self.code_model)
105
105
  if self.code_model.need_request_converter:
106
106
  file_import.add_submodule_import(
107
107
  "azure.core.pipeline.transport",
@@ -122,10 +122,9 @@ class GeneralSerializer:
122
122
  TypingSection.TYPING,
123
123
  )
124
124
  file_import.add_msrest_import(
125
- self.code_model,
126
- ".." if self.async_mode else ".",
127
- MsrestImportType.SerializerDeserializer,
128
- TypingSection.TYPING,
125
+ relative_path=".." if self.async_mode else ".",
126
+ msrest_import_type=MsrestImportType.SerializerDeserializer,
127
+ typing_section=TypingSection.TYPING,
129
128
  )
130
129
  for client in clients:
131
130
  file_import.add_submodule_import(
@@ -152,7 +151,7 @@ class GeneralSerializer:
152
151
 
153
152
  def serialize_config_file(self, clients: List[Client]) -> str:
154
153
  template = self.env.get_template("config_container.py.jinja2")
155
- imports = FileImport()
154
+ imports = FileImport(self.code_model)
156
155
  for client in self.code_model.clients:
157
156
  imports.merge(client.config.imports(self.async_mode))
158
157
  return template.render(
@@ -133,9 +133,11 @@ class _ModelSerializer(ABC):
133
133
 
134
134
  class MsrestModelSerializer(_ModelSerializer):
135
135
  def imports(self) -> FileImport:
136
- file_import = FileImport()
136
+ file_import = FileImport(self.code_model)
137
137
  file_import.add_msrest_import(
138
- self.code_model, "..", MsrestImportType.Module, TypingSection.REGULAR
138
+ relative_path="..",
139
+ msrest_import_type=MsrestImportType.Module,
140
+ typing_section=TypingSection.REGULAR,
139
141
  )
140
142
  for model in self.code_model.model_types:
141
143
  file_import.merge(model.imports(is_operation_file=False))
@@ -200,7 +202,7 @@ class MsrestModelSerializer(_ModelSerializer):
200
202
 
201
203
  class DpgModelSerializer(_ModelSerializer):
202
204
  def imports(self) -> FileImport:
203
- file_import = FileImport()
205
+ file_import = FileImport(self.code_model)
204
206
  file_import.add_submodule_import(
205
207
  "..",
206
208
  "_model_base",
@@ -57,7 +57,7 @@ class OperationGroupsSerializer:
57
57
  if self.operation_group
58
58
  else [og for client in self.clients for og in client.operation_groups]
59
59
  )
60
- imports = FileImport()
60
+ imports = FileImport(code_model=self.code_model)
61
61
  for operation_group in operation_groups:
62
62
  imports.merge(
63
63
  operation_group.imports(
@@ -15,7 +15,7 @@ class PatchSerializer:
15
15
 
16
16
  def serialize(self) -> str:
17
17
  template = self.env.get_template("patch.py.jinja2")
18
- imports = FileImport()
18
+ imports = FileImport(code_model=self.code_model)
19
19
  imports.add_submodule_import("typing", "List", ImportType.STDLIB)
20
20
  return template.render(
21
21
  code_model=self.code_model,
@@ -26,7 +26,7 @@ class RequestBuildersSerializer:
26
26
 
27
27
  @property
28
28
  def imports(self) -> FileImport:
29
- file_import = FileImport()
29
+ file_import = FileImport(code_model=self.code_model)
30
30
  for request_builder in self.request_builders:
31
31
  if request_builder.group_name == self.group_name:
32
32
  file_import.merge(request_builder.imports())
@@ -43,7 +43,7 @@ class SampleSerializer:
43
43
  }
44
44
 
45
45
  def _imports(self) -> FileImportSerializer:
46
- imports = FileImport()
46
+ imports = FileImport(self.code_model)
47
47
  namespace_from_package_name = get_namespace_from_package_name(
48
48
  self.code_model.options["package_name"]
49
49
  )
@@ -15,7 +15,7 @@ class TypesSerializer:
15
15
  self.env = env
16
16
 
17
17
  def imports(self) -> FileImport:
18
- file_import = FileImport()
18
+ file_import = FileImport(code_model=self.code_model)
19
19
  if self.code_model.named_unions:
20
20
  file_import.add_submodule_import(
21
21
  "typing",
@@ -553,7 +553,7 @@ class Model(_MyMutableMapping):
553
553
  if exclude_readonly:
554
554
  readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)]
555
555
  for k, v in self.items():
556
- if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false
556
+ if exclude_readonly and k in readonly_props: # pyright: ignore[reportUnboundVariable]
557
557
  continue
558
558
  result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly)
559
559
  return result
@@ -3,12 +3,14 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
+ import typing
6
7
  import contextlib
7
8
  import os
8
9
  import sys
9
10
  import logging
10
11
 
11
12
  from jsonrpc import dispatcher, JSONRPCResponseManager
13
+ from jsonrpc.jsonrpc2 import JSONRPC20Response
12
14
 
13
15
  from .stdstream import read_message, write_message
14
16
 
@@ -99,7 +101,9 @@ def main() -> None:
99
101
  _LOGGER.debug("Trying to read")
100
102
  message = read_message()
101
103
 
102
- response = JSONRPCResponseManager.handle(message, dispatcher).json
104
+ response = typing.cast(
105
+ JSONRPC20Response, JSONRPCResponseManager.handle(message, dispatcher)
106
+ ).json
103
107
  _LOGGER.debug("Produced: %s", response)
104
108
  write_message(response)
105
109
  _LOGGER.debug("Message processed")
@@ -4,7 +4,7 @@
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
6
  from enum import Enum
7
- from typing import Dict, Optional, Set, Union, Tuple
7
+ from typing import Dict, Optional, Set, Union, Tuple, cast
8
8
  from ..utils import convert_list_to_tuple
9
9
 
10
10
 
@@ -108,21 +108,23 @@ class FileImport:
108
108
  ] = None,
109
109
  typing_section: TypingSection = TypingSection.REGULAR,
110
110
  ) -> None:
111
- name_input: Optional[
112
- Union[
113
- str,
114
- Tuple[
111
+ name_input = cast(
112
+ Optional[
113
+ Union[
115
114
  str,
116
- str,
117
- ],
118
- Tuple[
119
- str,
120
- Optional[str],
121
- Tuple[Tuple[Tuple[int, int], str, Optional[str]]],
122
- ],
123
- ]
124
- ] = None
125
- name_input = convert_list_to_tuple(name_import)
115
+ Tuple[
116
+ str,
117
+ str,
118
+ ],
119
+ Tuple[
120
+ str,
121
+ Optional[str],
122
+ Tuple[Tuple[Tuple[int, int], str, Optional[str]]],
123
+ ],
124
+ ]
125
+ ],
126
+ convert_list_to_tuple(name_import),
127
+ )
126
128
  self._imports.setdefault(typing_section, {}).setdefault(
127
129
  import_type, {}
128
130
  ).setdefault(from_section, set()).add(name_input)
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@autorest/python",
3
- "version": "6.9.3",
3
+ "version": "6.9.4",
4
4
  "description": "The Python extension for generators in AutoRest.",
5
5
  "main": "index.js",
6
6
  "repository": {