@autorest/python 6.9.2 → 6.9.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/autorest/_utils.py +48 -20
  2. package/autorest/codegen/__init__.py +8 -1
  3. package/autorest/codegen/models/__init__.py +17 -2
  4. package/autorest/codegen/models/base.py +3 -4
  5. package/autorest/codegen/models/client.py +5 -6
  6. package/autorest/codegen/models/combined_type.py +1 -1
  7. package/autorest/codegen/models/constant_type.py +3 -17
  8. package/autorest/codegen/models/credential_types.py +2 -2
  9. package/autorest/codegen/models/dictionary_type.py +1 -1
  10. package/autorest/codegen/models/enum_type.py +78 -23
  11. package/autorest/codegen/models/imports.py +23 -5
  12. package/autorest/codegen/models/list_type.py +1 -1
  13. package/autorest/codegen/models/lro_operation.py +3 -0
  14. package/autorest/codegen/models/model_type.py +1 -1
  15. package/autorest/codegen/models/operation.py +4 -4
  16. package/autorest/codegen/models/operation_group.py +2 -2
  17. package/autorest/codegen/models/paging_operation.py +1 -1
  18. package/autorest/codegen/models/parameter.py +6 -3
  19. package/autorest/codegen/models/primitive_types.py +20 -25
  20. package/autorest/codegen/models/property.py +25 -4
  21. package/autorest/codegen/models/request_builder.py +4 -5
  22. package/autorest/codegen/models/response.py +1 -1
  23. package/autorest/codegen/serializers/builder_serializer.py +23 -41
  24. package/autorest/codegen/serializers/client_serializer.py +1 -1
  25. package/autorest/codegen/serializers/general_serializer.py +6 -7
  26. package/autorest/codegen/serializers/model_serializer.py +10 -9
  27. package/autorest/codegen/serializers/operation_groups_serializer.py +1 -1
  28. package/autorest/codegen/serializers/patch_serializer.py +1 -1
  29. package/autorest/codegen/serializers/request_builders_serializer.py +1 -1
  30. package/autorest/codegen/serializers/sample_serializer.py +1 -1
  31. package/autorest/codegen/serializers/types_serializer.py +1 -1
  32. package/autorest/codegen/templates/client.py.jinja2 +1 -1
  33. package/autorest/codegen/templates/enum.py.jinja2 +3 -3
  34. package/autorest/codegen/templates/model_base.py.jinja2 +1 -1
  35. package/autorest/codegen/templates/serialization.py.jinja2 +5 -5
  36. package/autorest/jsonrpc/server.py +5 -1
  37. package/autorest/m4reformatter/__init__.py +11 -9
  38. package/autorest/multiapi/models/imports.py +17 -15
  39. package/autorest/preprocess/__init__.py +7 -6
  40. package/package.json +1 -1
@@ -8,6 +8,19 @@ import re
8
8
  import argparse
9
9
 
10
10
 
11
+ def update_enum_value(
12
+ name: str, value: Any, description: str, enum_type: Dict[str, Any]
13
+ ) -> Dict[str, Any]:
14
+ return {
15
+ "name": name,
16
+ "type": "enumvalue",
17
+ "value": value,
18
+ "description": description,
19
+ "enumType": enum_type,
20
+ "valueType": enum_type["valueType"],
21
+ }
22
+
23
+
11
24
  def to_snake_case(name: str) -> str:
12
25
  def replace_upper_characters(m) -> str:
13
26
  match_str = m.group().lower()
@@ -107,24 +120,39 @@ KNOWN_TYPES: Dict[str, Dict[str, Any]] = {
107
120
  JSON_REGEXP = re.compile(r"^(application|text)/(.+\+)?json$")
108
121
 
109
122
 
110
- def build_policies(is_arm: bool, async_mode: bool) -> List[str]:
111
- async_prefix = "Async" if async_mode else ""
112
- policies = [
113
- "policies.RequestIdPolicy(**kwargs)",
114
- "self._config.headers_policy",
115
- "self._config.user_agent_policy",
116
- "self._config.proxy_policy",
117
- "policies.ContentDecodePolicy(**kwargs)",
118
- f"{async_prefix}ARMAutoResourceProviderRegistrationPolicy()"
119
- if is_arm
120
- else None,
121
- "self._config.redirect_policy",
122
- "self._config.retry_policy",
123
- "self._config.authentication_policy",
124
- "self._config.custom_hook_policy",
125
- "self._config.logging_policy",
126
- "policies.DistributedTracingPolicy(**kwargs)",
127
- "policies.SensitiveHeaderCleanupPolicy(**kwargs) if self._config.redirect_policy else None",
128
- "self._config.http_logging_policy",
129
- ]
123
+ def build_policies(
124
+ is_arm: bool, async_mode: bool, unbranded: bool = False
125
+ ) -> List[str]:
126
+ if not unbranded:
127
+ # for Azure
128
+ async_prefix = "Async" if async_mode else ""
129
+ policies = [
130
+ "policies.RequestIdPolicy(**kwargs)",
131
+ "self._config.headers_policy",
132
+ "self._config.user_agent_policy",
133
+ "self._config.proxy_policy",
134
+ "policies.ContentDecodePolicy(**kwargs)",
135
+ f"{async_prefix}ARMAutoResourceProviderRegistrationPolicy()"
136
+ if is_arm
137
+ else None,
138
+ "self._config.redirect_policy",
139
+ "self._config.retry_policy",
140
+ "self._config.authentication_policy",
141
+ "self._config.custom_hook_policy",
142
+ "self._config.logging_policy",
143
+ "policies.DistributedTracingPolicy(**kwargs)",
144
+ "policies.SensitiveHeaderCleanupPolicy(**kwargs) if self._config.redirect_policy else None",
145
+ "self._config.http_logging_policy",
146
+ ]
147
+ else:
148
+ # for non-Azure
149
+ policies = [
150
+ "self._config.headers_policy",
151
+ "self._config.user_agent_policy",
152
+ "self._config.proxy_policy",
153
+ "policies.ContentDecodePolicy(**kwargs)",
154
+ "self._config.retry_policy",
155
+ "self._config.authentication_policy",
156
+ "self._config.logging_policy",
157
+ ]
130
158
  return [p for p in policies if p]
@@ -85,6 +85,11 @@ def _validate_code_model_options(options: Dict[str, Any]) -> None:
85
85
  "Please read https://aka.ms/azsdk/dpcodegen for more details."
86
86
  )
87
87
 
88
+ if options["unbranded"] and options["tracing"]:
89
+ raise ValueError(
90
+ "Can not set --unbranded=true and --tracing=true at the same time."
91
+ )
92
+
88
93
 
89
94
  _LOGGER = logging.getLogger(__name__)
90
95
 
@@ -137,6 +142,7 @@ class CodeGenerator(Plugin):
137
142
  models_mode_default = "dpg"
138
143
 
139
144
  package_name = self.options.get("package-name")
145
+ unbranded = self.options.get("unbranded", False)
140
146
  options: Dict[str, Any] = {
141
147
  "azure_arm": azure_arm,
142
148
  "head_as_boolean": self.options.get("head-as-boolean", True),
@@ -148,7 +154,7 @@ class CodeGenerator(Plugin):
148
154
  "package_name": package_name,
149
155
  "package_version": self.options.get("package-version"),
150
156
  "client_side_validation": self.options.get("client-side-validation", False),
151
- "tracing": self.options.get("tracing", show_operations),
157
+ "tracing": self.options.get("tracing", show_operations and not unbranded),
152
158
  "multiapi": self.options.get("multiapi", False),
153
159
  "polymorphic_examples": self.options.get("polymorphic-examples", 5),
154
160
  "models_mode": self.options.get("models-mode", models_mode_default).lower(),
@@ -177,6 +183,7 @@ class CodeGenerator(Plugin):
177
183
  "generate_sample": self.options.get("generate-sample", False),
178
184
  "default_api_version": self.options.get("default-api-version"),
179
185
  "from_typespec": self.options.get("from-typespec", False),
186
+ "unbranded": unbranded,
180
187
  }
181
188
 
182
189
  if options["builders_visibility"] is None:
@@ -4,7 +4,7 @@
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
6
  import logging
7
- from typing import Any, Dict, Union
7
+ from typing import Any, Dict, Union, Optional
8
8
  from .base import BaseModel
9
9
  from .base_builder import BaseBuilder, ParameterListType
10
10
  from .code_model import CodeModel
@@ -30,7 +30,7 @@ from .primitive_types import (
30
30
  UnixTimeType,
31
31
  AzureCoreType,
32
32
  )
33
- from .enum_type import EnumType
33
+ from .enum_type import EnumType, EnumValue
34
34
  from .base import BaseType
35
35
  from .constant_type import ConstantType
36
36
  from .imports import FileImport, ImportType, TypingSection
@@ -86,6 +86,7 @@ __all__ = [
86
86
  "DictionaryType",
87
87
  "ListType",
88
88
  "EnumType",
89
+ "EnumValue",
89
90
  "FileImport",
90
91
  "ImportType",
91
92
  "TypingSection",
@@ -128,6 +129,7 @@ TYPE_TO_OBJECT = {
128
129
  "dict": DictionaryType,
129
130
  "constant": ConstantType,
130
131
  "enum": EnumType,
132
+ "enumvalue": EnumValue,
131
133
  "binary": BinaryType,
132
134
  "any": AnyType,
133
135
  "datetime": DatetimeType,
@@ -157,6 +159,7 @@ def build_type(yaml_data: Dict[str, Any], code_model: CodeModel) -> BaseType:
157
159
  except KeyError:
158
160
  # Not created yet, let's create it and add it to the index
159
161
  pass
162
+ response: Optional[BaseType] = None
160
163
  if yaml_data["type"] == "model":
161
164
  # need to special case model to avoid recursion
162
165
  if yaml_data["base"] == "json" or not code_model.options["models_mode"]:
@@ -168,6 +171,16 @@ def build_type(yaml_data: Dict[str, Any], code_model: CodeModel) -> BaseType:
168
171
  response = model_type(yaml_data, code_model)
169
172
  code_model.types_map[yaml_id] = response
170
173
  response.fill_instance_from_yaml(yaml_data, code_model)
174
+ elif yaml_data["type"] == "enum":
175
+ # avoid recursion because we add the parent enum type to the enum value
176
+ response = EnumType(
177
+ yaml_data,
178
+ code_model,
179
+ values=[],
180
+ value_type=build_type(yaml_data["valueType"], code_model),
181
+ )
182
+ code_model.types_map[yaml_id] = response
183
+ response.fill_instance_from_yaml(yaml_data, code_model)
171
184
  else:
172
185
  object_type = yaml_data.get("type")
173
186
  if object_type not in TYPE_TO_OBJECT:
@@ -177,6 +190,8 @@ def build_type(yaml_data: Dict[str, Any], code_model: CodeModel) -> BaseType:
177
190
  )
178
191
  object_type = "string"
179
192
  response = TYPE_TO_OBJECT[object_type].from_yaml(yaml_data, code_model) # type: ignore
193
+ if response is None:
194
+ raise ValueError("response can not be None")
180
195
  code_model.types_map[yaml_id] = response
181
196
  return response
182
197
 
@@ -53,14 +53,13 @@ class BaseType(BaseModel, ABC): # pylint: disable=too-many-public-methods
53
53
  return cls(yaml_data=yaml_data, code_model=code_model)
54
54
 
55
55
  def imports(self, **kwargs) -> FileImport: # pylint: disable=unused-argument
56
- return FileImport()
56
+ return FileImport(self.code_model)
57
57
 
58
58
  def imports_for_multiapi(self, **kwargs: Any) -> FileImport:
59
59
  return self.imports(**kwargs)
60
60
 
61
- @staticmethod
62
- def imports_for_sample() -> FileImport:
63
- return FileImport()
61
+ def imports_for_sample(self) -> FileImport:
62
+ return FileImport(self.code_model)
64
63
 
65
64
  @staticmethod
66
65
  def serialize_sample_value(value: Any) -> str:
@@ -195,7 +195,7 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
195
195
  raise KeyError(f"No operation with id {operation_id} found.") from exc
196
196
 
197
197
  def _imports_shared(self, async_mode: bool) -> FileImport:
198
- file_import = FileImport()
198
+ file_import = FileImport(self.code_model)
199
199
 
200
200
  file_import.add_submodule_import(
201
201
  "typing", "Any", ImportType.STDLIB, TypingSection.CONDITIONAL
@@ -225,10 +225,9 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
225
225
  ImportType.LOCAL,
226
226
  )
227
227
  file_import.add_msrest_import(
228
- self.code_model,
229
- ".." if async_mode else ".",
230
- MsrestImportType.SerializerDeserializer,
231
- TypingSection.REGULAR,
228
+ relative_path=".." if async_mode else ".",
229
+ msrest_import_type=MsrestImportType.SerializerDeserializer,
230
+ typing_section=TypingSection.REGULAR,
232
231
  )
233
232
  file_import.add_submodule_import(
234
233
  "azure.core.pipeline", "policies", ImportType.AZURECORE
@@ -406,7 +405,7 @@ class Config(_ClientConfigBase[ConfigGlobalParameterList]):
406
405
  return f"{super().name}Configuration"
407
406
 
408
407
  def _imports_shared(self, async_mode: bool) -> FileImport:
409
- file_import = FileImport()
408
+ file_import = FileImport(self.code_model)
410
409
  file_import.add_submodule_import(
411
410
  "azure.core.pipeline", "policies", ImportType.AZURECORE
412
411
  )
@@ -103,7 +103,7 @@ class CombinedType(BaseType):
103
103
  raise ValueError("You shouldn't do instance checks on a multiple type")
104
104
 
105
105
  def imports(self, **kwargs: Any) -> FileImport:
106
- file_import = FileImport()
106
+ file_import = FileImport(self.code_model)
107
107
  if self.name and not kwargs.get("is_types_file"):
108
108
  file_import.add_submodule_import(
109
109
  kwargs.pop("relative_path"),
@@ -6,7 +6,7 @@
6
6
  import logging
7
7
  from typing import Dict, Any, Optional, TYPE_CHECKING
8
8
  from .base import BaseType
9
- from .imports import FileImport, ImportType, TypingSection
9
+ from .imports import FileImport
10
10
  from .primitive_types import IntegerType, BinaryType, StringType, BooleanType
11
11
  from .utils import add_to_description
12
12
 
@@ -124,7 +124,7 @@ class ConstantType(BaseType):
124
124
  )
125
125
 
126
126
  def _imports_shared(self, **kwargs: Any):
127
- file_import = FileImport()
127
+ file_import = FileImport(self.code_model)
128
128
  file_import.merge(self.value_type.imports(**kwargs))
129
129
  return file_import
130
130
 
@@ -134,21 +134,7 @@ class ConstantType(BaseType):
134
134
  def imports(self, **kwargs: Any) -> FileImport:
135
135
  file_import = self._imports_shared(**kwargs)
136
136
  if self._is_literal:
137
- file_import.add_import("sys", ImportType.STDLIB)
138
- file_import.add_submodule_import(
139
- "typing_extensions",
140
- "Literal",
141
- ImportType.BYVERSION,
142
- TypingSection.REGULAR,
143
- None,
144
- (
145
- (
146
- (3, 8),
147
- "typing",
148
- "pylint: disable=no-name-in-module, ungrouped-imports",
149
- ),
150
- ),
151
- )
137
+ file_import.add_literal_import()
152
138
  return file_import
153
139
 
154
140
  @property
@@ -182,7 +182,7 @@ class TokenCredentialType(
182
182
  return "~azure.core.credentials.TokenCredential"
183
183
 
184
184
  def imports(self, **kwargs: Any) -> FileImport:
185
- file_import = FileImport()
185
+ file_import = FileImport(self.code_model)
186
186
  if kwargs.get("async_mode"):
187
187
  file_import.add_submodule_import(
188
188
  "azure.core.credentials_async",
@@ -221,7 +221,7 @@ class AzureKeyCredentialType(
221
221
  return "isinstance({}, AzureKeyCredential)"
222
222
 
223
223
  def imports(self, **kwargs: Any) -> FileImport: # pylint: disable=unused-argument
224
- file_import = FileImport()
224
+ file_import = FileImport(self.code_model)
225
225
  file_import.add_submodule_import(
226
226
  "azure.core.credentials",
227
227
  "AzureKeyCredential",
@@ -123,7 +123,7 @@ class DictionaryType(BaseType):
123
123
  )
124
124
 
125
125
  def imports(self, **kwargs: Any) -> FileImport:
126
- file_import = FileImport()
126
+ file_import = FileImport(self.code_model)
127
127
  file_import.add_submodule_import(
128
128
  "typing", "Dict", ImportType.STDLIB, TypingSection.CONDITIONAL
129
129
  )
@@ -3,17 +3,16 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
- from typing import Any, Dict, List, TYPE_CHECKING, Optional
6
+ from typing import Any, Dict, List, TYPE_CHECKING, Optional, cast
7
7
 
8
8
  from .base import BaseType
9
9
  from .imports import FileImport, ImportType, TypingSection
10
- from .base import BaseModel
11
10
 
12
11
  if TYPE_CHECKING:
13
12
  from .code_model import CodeModel
14
13
 
15
14
 
16
- class EnumValue(BaseModel):
15
+ class EnumValue(BaseType):
17
16
  """Model containing necessary information for a single value of an enum.
18
17
 
19
18
  :param str name: The name of this enum value
@@ -21,11 +20,70 @@ class EnumValue(BaseModel):
21
20
  :param str description: Optional. The description for this enum value
22
21
  """
23
22
 
24
- def __init__(self, yaml_data: Dict[str, Any], code_model: "CodeModel") -> None:
23
+ def __init__(
24
+ self,
25
+ yaml_data: Dict[str, Any],
26
+ code_model: "CodeModel",
27
+ enum_type: "EnumType",
28
+ value_type: BaseType,
29
+ ) -> None:
25
30
  super().__init__(yaml_data=yaml_data, code_model=code_model)
26
31
  self.name: str = self.yaml_data["name"]
27
32
  self.value: str = self.yaml_data["value"]
28
- self.description: Optional[str] = self.yaml_data.get("description")
33
+ self.enum_type = enum_type
34
+ self.value_type = value_type
35
+
36
+ def description(self, *, is_operation_file: bool) -> str:
37
+ return self.yaml_data.get("description", "")
38
+
39
+ def type_annotation(self, **kwargs: Any) -> str:
40
+ """The python type used for type annotation"""
41
+ return f"Literal[{self.enum_type.name}.{self.name}]"
42
+
43
+ def get_declaration(self, value=None):
44
+ return self.enum_type.name + "." + self.name
45
+
46
+ def docstring_text(self, **kwargs: Any) -> str:
47
+ return self.enum_type.name + "." + self.name
48
+
49
+ def docstring_type(self, **kwargs: Any) -> str:
50
+ """The python type used for RST syntax input and type annotation."""
51
+
52
+ type_annotation = self.value_type.type_annotation(**kwargs)
53
+ enum_type_annotation = f"{self.code_model.namespace}.models.{self.name}"
54
+ return f"{type_annotation} or ~{enum_type_annotation}"
55
+
56
+ def get_json_template_representation(
57
+ self,
58
+ *,
59
+ optional: bool = True,
60
+ client_default_value_declaration: Optional[str] = None,
61
+ description: Optional[str] = None,
62
+ ) -> Any:
63
+ # for better display effect, use the only value instead of var type
64
+ return self.value_type.get_json_template_representation(
65
+ optional=optional,
66
+ client_default_value_declaration=client_default_value_declaration,
67
+ description=description,
68
+ )
69
+
70
+ @property
71
+ def serialization_type(self) -> str:
72
+ return self.value_type.serialization_type
73
+
74
+ @property
75
+ def instance_check_template(self) -> str:
76
+ return self.value_type.instance_check_template
77
+
78
+ def imports(self, **kwargs: Any) -> FileImport:
79
+ file_import = FileImport(self.code_model)
80
+ file_import.merge(self.value_type.imports(**kwargs))
81
+ file_import.add_literal_import()
82
+ file_import.add_submodule_import(
83
+ "._enums", self.enum_type.name, ImportType.LOCAL, TypingSection.REGULAR
84
+ )
85
+
86
+ return file_import
29
87
 
30
88
  @classmethod
31
89
  def from_yaml(
@@ -39,9 +97,14 @@ class EnumValue(BaseModel):
39
97
  :return: A created EnumValue
40
98
  :rtype: ~autorest.models.EnumValue
41
99
  """
100
+ from . import build_type
101
+
102
+ enum_type = cast(EnumType, code_model.lookup_type(id(yaml_data["enumType"])))
42
103
  return cls(
43
104
  yaml_data=yaml_data,
44
105
  code_model=code_model,
106
+ enum_type=enum_type,
107
+ value_type=build_type(yaml_data["valueType"], code_model),
45
108
  )
46
109
 
47
110
 
@@ -150,32 +213,24 @@ class EnumType(BaseType):
150
213
  def instance_check_template(self) -> str:
151
214
  return self.value_type.instance_check_template
152
215
 
216
+ def fill_instance_from_yaml(
217
+ self, yaml_data: Dict[str, Any], code_model: "CodeModel"
218
+ ) -> None:
219
+ for value in yaml_data["values"]:
220
+ self.values.append(EnumValue.from_yaml(value, code_model))
221
+
153
222
  @classmethod
154
223
  def from_yaml(
155
224
  cls, yaml_data: Dict[str, Any], code_model: "CodeModel"
156
225
  ) -> "EnumType":
157
- """Constructs an EnumType from yaml data.
158
-
159
- :param yaml_data: the yaml data from which we will construct this schema
160
- :type yaml_data: dict[str, Any]
161
-
162
- :return: A created EnumType
163
- :rtype: ~autorest.models.EnumType
164
- """
165
- from . import build_type
166
-
167
- return cls(
168
- yaml_data=yaml_data,
169
- code_model=code_model,
170
- value_type=build_type(yaml_data["valueType"], code_model),
171
- values=[
172
- EnumValue.from_yaml(value, code_model) for value in yaml_data["values"]
173
- ],
226
+ raise ValueError(
227
+ "You shouldn't call from_yaml for EnumType to avoid recursion. "
228
+ "Please initial a blank EnumType, then call .fill_instance_from_yaml on the created type."
174
229
  )
175
230
 
176
231
  def imports(self, **kwargs: Any) -> FileImport:
177
232
  operation = kwargs.pop("operation", False)
178
- file_import = FileImport()
233
+ file_import = FileImport(self.code_model)
179
234
  if self.code_model.options["models_mode"]:
180
235
  file_import.add_submodule_import(
181
236
  "typing", "Union", ImportType.STDLIB, TypingSection.CONDITIONAL
@@ -88,8 +88,9 @@ class TypeDefinition:
88
88
 
89
89
 
90
90
  class FileImport:
91
- def __init__(self, imports: Optional[List[ImportModel]] = None) -> None:
92
- self.imports = imports or []
91
+ def __init__(self, code_model: "CodeModel") -> None:
92
+ self.imports: List[ImportModel] = []
93
+ self.code_model = code_model
93
94
  # has sync and async type definitions
94
95
  self.type_definitions: Dict[str, TypeDefinition] = {}
95
96
 
@@ -183,6 +184,23 @@ class FileImport:
183
184
  )
184
185
  self.add_submodule_import("typing", "Any", ImportType.STDLIB)
185
186
 
187
+ def add_literal_import(self) -> None:
188
+ self.add_import("sys", ImportType.STDLIB)
189
+ self.add_submodule_import(
190
+ "typing_extensions",
191
+ "Literal",
192
+ ImportType.BYVERSION,
193
+ TypingSection.REGULAR,
194
+ None,
195
+ (
196
+ (
197
+ (3, 8),
198
+ "typing",
199
+ "pylint: disable=no-name-in-module, ungrouped-imports",
200
+ ),
201
+ ),
202
+ )
203
+
186
204
  def to_dict(
187
205
  self,
188
206
  ) -> Dict[
@@ -255,12 +273,12 @@ class FileImport:
255
273
 
256
274
  def add_msrest_import(
257
275
  self,
258
- code_model: "CodeModel",
276
+ *,
259
277
  relative_path: str,
260
278
  msrest_import_type: MsrestImportType,
261
279
  typing_section: TypingSection,
262
280
  ):
263
- if code_model.options["client_side_validation"]:
281
+ if self.code_model.options["client_side_validation"]:
264
282
  if msrest_import_type == MsrestImportType.Module:
265
283
  self.add_import(
266
284
  "msrest.serialization", ImportType.AZURECORE, typing_section
@@ -274,7 +292,7 @@ class FileImport:
274
292
  "msrest", "Deserializer", ImportType.THIRDPARTY, typing_section
275
293
  )
276
294
  else:
277
- if code_model.options["multiapi"]:
295
+ if self.code_model.options["multiapi"]:
278
296
  relative_path += "."
279
297
  if msrest_import_type == MsrestImportType.Module:
280
298
  self.add_submodule_import(
@@ -146,7 +146,7 @@ class ListType(BaseType):
146
146
  )
147
147
 
148
148
  def imports(self, **kwargs: Any) -> FileImport:
149
- file_import = FileImport()
149
+ file_import = FileImport(self.code_model)
150
150
  if not (
151
151
  self.code_model.options["version_tolerant"]
152
152
  and self.element_type.is_xml
@@ -104,6 +104,9 @@ class LROOperationBase(OperationBase[LROResponseType]):
104
104
  """We don't want the poller to show up in ClsType, so we call super() on resposne type annotation"""
105
105
  return f"ClsType[{Response.type_annotation(self.responses[0], async_mode=async_mode)}]"
106
106
 
107
+ def get_poller_with_response_type(self, async_mode: bool) -> str:
108
+ return self.response_type_annotation(async_mode=async_mode)
109
+
107
110
  def get_poller(self, async_mode: bool) -> str:
108
111
  return self.responses[0].get_poller(async_mode)
109
112
 
@@ -268,7 +268,7 @@ class JSONModelType(ModelType):
268
268
  return "isinstance({}, MutableMapping)"
269
269
 
270
270
  def imports(self, **kwargs: Any) -> FileImport:
271
- file_import = FileImport()
271
+ file_import = FileImport(self.code_model)
272
272
  file_import.add_submodule_import(
273
273
  "typing", "Any", ImportType.STDLIB, TypingSection.CONDITIONAL
274
274
  )
@@ -230,7 +230,7 @@ class OperationBase( # pylint: disable=too-many-public-methods
230
230
  def _imports_shared(
231
231
  self, async_mode: bool, **kwargs: Any # pylint: disable=unused-argument
232
232
  ) -> FileImport:
233
- file_import = FileImport()
233
+ file_import = FileImport(self.code_model)
234
234
  file_import.add_submodule_import(
235
235
  "typing", "Any", ImportType.STDLIB, TypingSection.CONDITIONAL
236
236
  )
@@ -254,7 +254,7 @@ class OperationBase( # pylint: disable=too-many-public-methods
254
254
 
255
255
  def imports_for_multiapi(self, async_mode: bool, **kwargs: Any) -> FileImport:
256
256
  if self.abstract:
257
- return FileImport()
257
+ return FileImport(self.code_model)
258
258
  file_import = self._imports_shared(async_mode, **kwargs)
259
259
  for param in self.parameters.method:
260
260
  file_import.merge(
@@ -308,7 +308,7 @@ class OperationBase( # pylint: disable=too-many-public-methods
308
308
  async_mode: bool,
309
309
  ) -> FileImport:
310
310
  """Helper method to get a request builder import."""
311
- file_import = FileImport()
311
+ file_import = FileImport(self.code_model)
312
312
  if self.code_model.options["builders_visibility"] != "embedded":
313
313
  group_name = request_builder.group_name
314
314
  rest_import_path = "..." if async_mode else ".."
@@ -338,7 +338,7 @@ class OperationBase( # pylint: disable=too-many-public-methods
338
338
  self, async_mode: bool, **kwargs: Any
339
339
  ) -> FileImport:
340
340
  if self.abstract:
341
- return FileImport()
341
+ return FileImport(self.code_model)
342
342
  file_import = self._imports_shared(async_mode, **kwargs)
343
343
 
344
344
  for param in self.parameters.method:
@@ -48,7 +48,7 @@ class OperationGroup(BaseModel):
48
48
  return ", ".join(base_classes)
49
49
 
50
50
  def imports_for_multiapi(self, async_mode: bool) -> FileImport:
51
- file_import = FileImport()
51
+ file_import = FileImport(self.code_model)
52
52
  relative_path = ".." if async_mode else "."
53
53
  for operation in self.operations:
54
54
  file_import.merge(
@@ -79,7 +79,7 @@ class OperationGroup(BaseModel):
79
79
  return any(o for o in self.operations if o.need_validation)
80
80
 
81
81
  def imports(self, async_mode: bool) -> FileImport:
82
- file_import = FileImport()
82
+ file_import = FileImport(self.code_model)
83
83
 
84
84
  relative_path = ("..." if async_mode else "..") + (
85
85
  "." if self.client.is_subclient else ""
@@ -144,7 +144,7 @@ class PagingOperationBase(OperationBase[PagingResponseType]):
144
144
 
145
145
  def imports(self, async_mode: bool, **kwargs: Any) -> FileImport:
146
146
  if self.abstract:
147
- return FileImport()
147
+ return FileImport(self.code_model)
148
148
  file_import = self._imports_shared(async_mode, **kwargs)
149
149
  file_import.merge(super().imports(async_mode, **kwargs))
150
150
  if self.code_model.options["tracing"] and self.want_tracing:
@@ -92,6 +92,9 @@ class _ParameterBase(
92
92
  )
93
93
  self.hide_in_method: bool = self.yaml_data.get("hideInMethod", False)
94
94
 
95
+ def get_declaration(self, value: Any = None) -> Any:
96
+ return self.type.get_declaration(value)
97
+
95
98
  @property
96
99
  def hide_in_operation_signature(self) -> bool:
97
100
  return False
@@ -115,7 +118,7 @@ class _ParameterBase(
115
118
  if self.optional and isinstance(self.type, ConstantType):
116
119
  base_description = add_to_description(
117
120
  base_description,
118
- f"Known values are {self.type.get_declaration()} and None.",
121
+ f"Known values are {self.get_declaration()} and None.",
119
122
  )
120
123
  if not (self.optional or self.client_default_value):
121
124
  base_description = add_to_description(base_description, "Required.")
@@ -141,7 +144,7 @@ class _ParameterBase(
141
144
  """Declaration of parameter's client default value"""
142
145
  if self.client_default_value is None:
143
146
  return None
144
- return self.type.get_declaration(self.client_default_value)
147
+ return self.get_declaration(self.client_default_value)
145
148
 
146
149
  def type_annotation(self, **kwargs: Any) -> str:
147
150
  kwargs["is_operation_file"] = True
@@ -165,7 +168,7 @@ class _ParameterBase(
165
168
  return self.type.serialization_type
166
169
 
167
170
  def _imports_shared(self, async_mode: bool, **_: Any) -> FileImport:
168
- file_import = FileImport()
171
+ file_import = FileImport(self.code_model)
169
172
  if self.optional and self.client_default_value is None:
170
173
  file_import.add_submodule_import("typing", "Optional", ImportType.STDLIB)
171
174
  if self.added_on: