@autorest/python 6.10.2 → 6.11.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -83,8 +83,9 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
83
83
  self,
84
84
  ) -> List[Union[RequestBuilder, OverloadedRequestBuilder]]:
85
85
  request_builders: List[Union[RequestBuilder, OverloadedRequestBuilder]] = []
86
- for og_group in self.yaml_data["operationGroups"]:
87
- for operation_yaml in og_group["operations"]:
86
+
87
+ def add_og_request_builder(og: Dict[str, Any]):
88
+ for operation_yaml in og["operations"]:
88
89
  request_builder = get_request_builder(
89
90
  operation_yaml,
90
91
  code_model=self.code_model,
@@ -111,6 +112,14 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
111
112
  client=self,
112
113
  )
113
114
  )
115
+
116
+ queue = self.yaml_data["operationGroups"].copy()
117
+ while queue:
118
+ now = queue.pop(0)
119
+ add_og_request_builder(now)
120
+ if now.get("operationGroups"):
121
+ queue.extend(now["operationGroups"])
122
+
114
123
  return request_builders
115
124
 
116
125
  def pipeline_class(self, async_mode: bool) -> str:
@@ -247,7 +256,7 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
247
256
  @property
248
257
  def has_mixin(self) -> bool:
249
258
  """Do we want a mixin ABC class for typing purposes?"""
250
- return any(o for o in self.operation_groups if o.is_mixin)
259
+ return any(og for og in self.operation_groups if og.is_mixin)
251
260
 
252
261
  @property
253
262
  def lro_operations(self) -> List["OperationType"]:
@@ -255,8 +264,7 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
255
264
  return [
256
265
  operation
257
266
  for operation_group in self.operation_groups
258
- for operation in operation_group.operations
259
- if operation.operation_type in ("lro", "lropaging")
267
+ for operation in operation_group.lro_operations
260
268
  ]
261
269
 
262
270
  @property
@@ -267,8 +275,7 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
267
275
  @property
268
276
  def has_operations(self) -> bool:
269
277
  return any(
270
- bool(operation_group.operations)
271
- for operation_group in self.operation_groups
278
+ operation_group.has_operations for operation_group in self.operation_groups
272
279
  )
273
280
 
274
281
  def link_lro_initial_operations(self) -> None:
@@ -77,6 +77,15 @@ class CodeModel: # pylint: disable=too-many-public-methods, disable=too-many-in
77
77
  t for t in self.types_map.values() if isinstance(t, CombinedType) and t.name
78
78
  ]
79
79
 
80
+ @property
81
+ def has_form_data(self) -> bool:
82
+ for client in self.clients:
83
+ for operation_group in client.operation_groups:
84
+ for operation in operation_group.operations:
85
+ if operation.has_form_data_body:
86
+ return True
87
+ return False
88
+
80
89
  @property
81
90
  def has_etag(self) -> bool:
82
91
  return any(client.has_etag for client in self.clients)
@@ -133,7 +142,12 @@ class CodeModel: # pylint: disable=too-many-public-methods, disable=too-many-in
133
142
  return True
134
143
  if async_mode:
135
144
  return self.need_mixin_abc
136
- return self.need_request_converter or self.need_mixin_abc or self.has_etag
145
+ return (
146
+ self.need_request_converter
147
+ or self.need_mixin_abc
148
+ or self.has_etag
149
+ or self.has_form_data
150
+ )
137
151
 
138
152
  @property
139
153
  def need_request_converter(self) -> bool:
@@ -65,8 +65,7 @@ class CombinedType(BaseType):
65
65
 
66
66
  def type_annotation(self, **kwargs: Any) -> str:
67
67
  if self.name:
68
- ret = f"_types.{self.name}"
69
- return ret if kwargs.get("is_operation_file") else f'"{ret}"'
68
+ return f'"_types.{self.name}"'
70
69
  return self.type_definition(**kwargs)
71
70
 
72
71
  def type_definition(self, **kwargs: Any) -> str:
@@ -184,8 +184,7 @@ class FileImport:
184
184
  self._append_import(i)
185
185
  self.type_definitions.update(file_import.type_definitions)
186
186
 
187
- def define_mutable_mapping_type(self) -> None:
188
- """Helper function for defining the mutable mapping type"""
187
+ def add_mutable_mapping_import(self) -> None:
189
188
  self.add_import("sys", ImportType.STDLIB)
190
189
  self.add_submodule_import(
191
190
  "typing",
@@ -195,6 +194,10 @@ class FileImport:
195
194
  None,
196
195
  (((3, 9), "collections.abc", None),),
197
196
  )
197
+
198
+ def define_mutable_mapping_type(self) -> None:
199
+ """Helper function for defining the mutable mapping type"""
200
+ self.add_mutable_mapping_import()
198
201
  self.define_mypy_type(
199
202
  "JSON",
200
203
  "MutableMapping[str, Any] # pylint: disable=unsubscriptable-object",
@@ -92,6 +92,10 @@ class OperationBase( # pylint: disable=too-many-public-methods
92
92
  self.name = "_" + self.name
93
93
  self.has_etag: bool = self.yaml_data.get("hasEtag", False)
94
94
 
95
+ @property
96
+ def has_form_data_body(self):
97
+ return self.parameters.has_form_data_body
98
+
95
99
  @property
96
100
  def expose_stream_keyword(self) -> bool:
97
101
  return self.yaml_data.get("exposeStreamKeyword", False)
@@ -568,13 +572,31 @@ class Operation(OperationBase[Response]):
568
572
  and not self.code_model.options["models_mode"]
569
573
  ):
570
574
  file_import.add_submodule_import("typing", "cast", ImportType.STDLIB)
575
+ relative_path = "..." if async_mode else ".."
571
576
  if self.code_model.options["models_mode"] == "dpg":
572
- relative_path = "..." if async_mode else ".."
573
577
  if self.parameters.has_body:
574
- file_import.add_submodule_import(
575
- f"{relative_path}_model_base", "SdkJSONEncoder", ImportType.LOCAL
576
- )
577
- file_import.add_import("json", ImportType.STDLIB)
578
+ if not self.parameters.body_parameter.is_form_data:
579
+ file_import.add_submodule_import(
580
+ f"{relative_path}_model_base",
581
+ "SdkJSONEncoder",
582
+ ImportType.LOCAL,
583
+ )
584
+ file_import.add_import("json", ImportType.STDLIB)
585
+ else:
586
+ file_import.add_submodule_import(
587
+ relative_path, "_model_base", ImportType.LOCAL
588
+ )
589
+ file_import.add_submodule_import("io", "IOBase", ImportType.STDLIB)
590
+ file_import.add_submodule_import(
591
+ f"{relative_path}_vendor",
592
+ "multipart_form_data_file",
593
+ ImportType.LOCAL,
594
+ )
595
+ file_import.add_submodule_import(
596
+ f"{relative_path}_vendor",
597
+ "handle_multipart_form_data_model",
598
+ ImportType.LOCAL,
599
+ )
578
600
  if self.default_error_deserialization or any(
579
601
  r.type for r in self.responses
580
602
  ):
@@ -11,6 +11,8 @@ from .base import BaseModel
11
11
  from .operation import get_operation
12
12
  from .imports import FileImport, ImportType, TypingSection
13
13
  from .utils import add_to_pylint_disable, NAME_LENGTH_LIMIT
14
+ from .lro_operation import LROOperation
15
+ from .lro_paging_operation import LROPagingOperation
14
16
 
15
17
  if TYPE_CHECKING:
16
18
  from .code_model import CodeModel
@@ -32,9 +34,17 @@ class OperationGroup(BaseModel):
32
34
  super().__init__(yaml_data, code_model)
33
35
  self.client = client
34
36
  self.class_name: str = yaml_data["className"]
37
+ self.identify_name: str = yaml_data["identifyName"]
35
38
  self.property_name: str = yaml_data["propertyName"]
36
39
  self.operations = operations
37
40
  self.api_versions = api_versions
41
+ self.operation_groups: List[OperationGroup] = []
42
+ if self.code_model.options["show_operations"]:
43
+ self.operation_groups = [
44
+ OperationGroup.from_yaml(op_group, code_model, client)
45
+ for op_group in self.yaml_data.get("operationGroups", [])
46
+ ]
47
+ self.link_lro_initial_operations()
38
48
 
39
49
  @property
40
50
  def has_abstract_operations(self) -> bool:
@@ -43,7 +53,7 @@ class OperationGroup(BaseModel):
43
53
  @property
44
54
  def base_class(self) -> str:
45
55
  base_classes: List[str] = []
46
- if self.is_mixin and self.code_model.need_mixin_abc:
56
+ if self.is_mixin:
47
57
  base_classes.append(f"{self.client.name}MixinABC")
48
58
  return ", ".join(base_classes)
49
59
 
@@ -71,6 +81,8 @@ class OperationGroup(BaseModel):
71
81
  retval = add_to_pylint_disable(retval, "too-many-public-methods")
72
82
  if len(self.class_name) > NAME_LENGTH_LIMIT:
73
83
  retval = add_to_pylint_disable(retval, "name-too-long")
84
+ if len(self.operation_groups) > 6:
85
+ retval = add_to_pylint_disable(retval, "too-many-instance-attributes")
74
86
  return retval
75
87
 
76
88
  @property
@@ -88,6 +100,13 @@ class OperationGroup(BaseModel):
88
100
  file_import.merge(
89
101
  operation.imports(async_mode, relative_path=relative_path)
90
102
  )
103
+ if not self.code_model.options["combine_operation_files"]:
104
+ for og in self.operation_groups:
105
+ file_import.add_submodule_import(
106
+ ".",
107
+ og.class_name,
108
+ ImportType.LOCAL,
109
+ )
91
110
  # for multiapi
92
111
  if (
93
112
  (self.code_model.public_model_types)
@@ -124,7 +143,45 @@ class OperationGroup(BaseModel):
124
143
  @property
125
144
  def is_mixin(self) -> bool:
126
145
  """The operation group with no name is the direct client methods."""
127
- return self.property_name == ""
146
+ return self.identify_name == ""
147
+
148
+ def link_lro_initial_operations(self) -> None:
149
+ """Link each LRO operation to its initial operation"""
150
+ for operation_group in self.operation_groups:
151
+ for operation in operation_group.operations:
152
+ if isinstance(operation, (LROOperation, LROPagingOperation)):
153
+ operation.initial_operation = self.lookup_operation(
154
+ id(operation.yaml_data["initialOperation"])
155
+ )
156
+
157
+ def lookup_operation(self, operation_id: int) -> "OperationType":
158
+ try:
159
+ return next(
160
+ o
161
+ for og in self.operation_groups
162
+ for o in og.operations
163
+ if id(o.yaml_data) == operation_id
164
+ )
165
+ except StopIteration as exc:
166
+ raise KeyError(f"No operation with id {operation_id} found.") from exc
167
+
168
+ @property
169
+ def lro_operations(self) -> List["OperationType"]:
170
+ return [
171
+ operation
172
+ for operation in self.operations
173
+ if operation.operation_type in ("lro", "lropaging")
174
+ ] + [
175
+ operation
176
+ for operation_group in self.operation_groups
177
+ for operation in operation_group.lro_operations
178
+ ]
179
+
180
+ @property
181
+ def has_operations(self) -> bool:
182
+ return any(
183
+ operation_group.has_operations for operation_group in self.operation_groups
184
+ ) or bool(self.operations)
128
185
 
129
186
  @classmethod
130
187
  def from_yaml(
@@ -179,7 +179,7 @@ class _ParameterBase(
179
179
  )
180
180
  if isinstance(self.type, CombinedType) and self.type.name:
181
181
  file_import.add_submodule_import(
182
- "..",
182
+ "..." if async_mode else "..",
183
183
  "_types",
184
184
  ImportType.LOCAL,
185
185
  TypingSection.TYPING,
@@ -243,6 +243,10 @@ class _ParameterBase(
243
243
  class BodyParameter(_ParameterBase):
244
244
  """Body parameter."""
245
245
 
246
+ @property
247
+ def is_form_data(self) -> bool:
248
+ return self.default_content_type == "multipart/form-data"
249
+
246
250
  @property
247
251
  def is_partial_body(self) -> bool:
248
252
  """Whether it's part of a bigger body parameter, i.e. a MultipartBodyParameter"""
@@ -133,6 +133,10 @@ class _ParameterListBase(
133
133
  params.append(self.body_parameter)
134
134
  return params
135
135
 
136
+ @property
137
+ def has_form_data_body(self):
138
+ return self.has_body and self.body_parameter.is_form_data
139
+
136
140
  @property
137
141
  def has_body(self) -> bool:
138
142
  """Whether there is a body parameter in the parameter list"""
@@ -57,6 +57,10 @@ class RequestBuilderBase(BaseBuilder[ParameterListType]):
57
57
  self.method: str = yaml_data["method"]
58
58
  self.want_tracing = False
59
59
 
60
+ @property
61
+ def has_form_data_body(self):
62
+ return self.parameters.has_form_data_body
63
+
60
64
  @property
61
65
  def is_lro(self) -> bool:
62
66
  return self.yaml_data.get("discriminator") in ("lro", "lropaging")
@@ -120,8 +120,9 @@ class Response(BaseModel):
120
120
  if self.nullable:
121
121
  file_import.add_submodule_import("typing", "Optional", ImportType.STDLIB)
122
122
  if isinstance(self.type, CombinedType) and self.type.name:
123
+ async_mode = kwargs.get("async_mode", False)
123
124
  file_import.add_submodule_import(
124
- "..",
125
+ "..." if async_mode else "..",
125
126
  "_types",
126
127
  ImportType.LOCAL,
127
128
  TypingSection.TYPING,
@@ -34,6 +34,7 @@ from .utils import (
34
34
  extract_sample_name,
35
35
  get_namespace_from_package_name,
36
36
  get_namespace_config,
37
+ get_all_operation_groups_recursively,
37
38
  )
38
39
 
39
40
  _LOGGER = logging.getLogger(__name__)
@@ -415,14 +416,15 @@ class JinjaSerializer(ReaderAndWriter): # pylint: disable=abstract-method
415
416
  clients=clients,
416
417
  )
417
418
  else:
418
- for client in self.code_model.clients:
419
- for operation_group in client.operation_groups:
420
- self._serialize_and_write_operations_file(
421
- env=env,
422
- namespace_path=namespace_path,
423
- operation_group=operation_group,
424
- clients=self.code_model.clients,
425
- )
419
+ for operation_group in get_all_operation_groups_recursively(
420
+ self.code_model.clients
421
+ ):
422
+ self._serialize_and_write_operations_file(
423
+ env=env,
424
+ namespace_path=namespace_path,
425
+ operation_group=operation_group,
426
+ clients=clients,
427
+ )
426
428
 
427
429
  def _serialize_and_write_version_file(
428
430
  self,
@@ -38,6 +38,7 @@ from ..models import (
38
38
  from .parameter_serializer import ParameterSerializer, PopKwargType
39
39
  from ..models.parameter_list import ParameterType
40
40
  from . import utils
41
+ from ..._utils import JSON_REGEXP
41
42
 
42
43
  T = TypeVar("T")
43
44
  OrderedSet = Dict[T, None]
@@ -59,6 +60,14 @@ OperationType = TypeVar(
59
60
  )
60
61
 
61
62
 
63
+ def _need_type_ignore(builder: OperationType) -> bool:
64
+ for excep in builder.non_default_errors:
65
+ for status_code in excep.status_codes:
66
+ if status_code in (401, 404, 409, 304):
67
+ return True
68
+ return False
69
+
70
+
62
71
  def _xml_config(send_xml: bool, content_types: List[str]) -> str:
63
72
  if not (send_xml and "xml" in str(content_types)):
64
73
  return ""
@@ -547,11 +556,16 @@ class RequestBuilderSerializer(
547
556
  return retval
548
557
 
549
558
  def serialize_headers(self, builder: RequestBuilderType) -> List[str]:
550
- retval = ["# Construct headers"]
551
- for parameter in builder.parameters.headers:
559
+ headers = [
560
+ h
561
+ for h in builder.parameters.headers
562
+ if not builder.has_form_data_body or h.wire_name.lower() != "content-type"
563
+ ]
564
+ retval = ["# Construct headers"] if headers else []
565
+ for header in headers:
552
566
  retval.extend(
553
567
  self.parameter_serializer.serialize_query_header(
554
- parameter,
568
+ header,
555
569
  "headers",
556
570
  self.serializer_name,
557
571
  self.code_model.is_legacy,
@@ -747,8 +761,17 @@ class _OperationSerializer(
747
761
 
748
762
  This function serializes the body params that need to be serialized.
749
763
  """
750
- retval: List[str] = []
751
764
  body_param = cast(BodyParameter, builder.parameters.body_parameter)
765
+ if body_param.is_form_data:
766
+ return [
767
+ f"if isinstance({body_param.client_name}, _model_base.Model):",
768
+ f" _body = handle_multipart_form_data_model({body_param.client_name})",
769
+ "else:",
770
+ f" _body = {body_param.client_name}",
771
+ "_files = {k: multipart_form_data_file(v) for k, v in _body.items() if isinstance(v, (IOBase, bytes))}",
772
+ "_data = {k: v for k, v in _body.items() if not isinstance(v, (IOBase, bytes))}",
773
+ ]
774
+ retval: List[str] = []
752
775
  body_kwarg_name = builder.request_builder.parameters.body_parameter.client_name
753
776
  send_xml = builder.parameters.body_parameter.type.is_xml
754
777
  xml_serialization_ctxt = (
@@ -769,17 +792,20 @@ class _OperationSerializer(
769
792
  f"'{body_param.type.serialization_type}'{is_xml_cmd}{serialization_ctxt_cmd})"
770
793
  )
771
794
  elif self.code_model.options["models_mode"] == "dpg":
772
- if hasattr(body_param.type, "encode") and body_param.type.encode: # type: ignore
773
- create_body_call = (
774
- f"_{body_kwarg_name} = json.dumps({body_param.client_name}, "
775
- "cls=SdkJSONEncoder, exclude_readonly=True, "
776
- f"format='{body_param.type.encode}') # type: ignore" # type: ignore
777
- )
795
+ if JSON_REGEXP.match(body_param.default_content_type):
796
+ if hasattr(body_param.type, "encode") and body_param.type.encode: # type: ignore
797
+ create_body_call = (
798
+ f"_{body_kwarg_name} = json.dumps({body_param.client_name}, "
799
+ "cls=SdkJSONEncoder, exclude_readonly=True, "
800
+ f"format='{body_param.type.encode}') # type: ignore" # type: ignore
801
+ )
802
+ else:
803
+ create_body_call = (
804
+ f"_{body_kwarg_name} = json.dumps({body_param.client_name}, "
805
+ "cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore"
806
+ )
778
807
  else:
779
- create_body_call = (
780
- f"_{body_kwarg_name} = json.dumps({body_param.client_name}, "
781
- "cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore"
782
- )
808
+ create_body_call = f"_{body_kwarg_name} = {body_param.client_name}"
783
809
  else:
784
810
  create_body_call = f"_{body_kwarg_name} = {body_param.client_name}"
785
811
  if body_param.optional:
@@ -959,7 +985,10 @@ class _OperationSerializer(
959
985
  f" {parameter.client_name}={parameter.name_in_high_level_operation},"
960
986
  f"{' # type: ignore' if type_ignore else ''}"
961
987
  )
962
- if request_builder.overloads:
988
+ if request_builder.has_form_data_body:
989
+ retval.append(" data=_data,")
990
+ retval.append(" files=_files,")
991
+ elif request_builder.overloads:
963
992
  seen_body_params = set()
964
993
  for overload in request_builder.overloads:
965
994
  body_param = cast(
@@ -1030,7 +1059,9 @@ class _OperationSerializer(
1030
1059
  builder.parameters.body_parameter, builder.parameters.parameters
1031
1060
  )
1032
1061
  )
1033
- if builder.overloads:
1062
+ if builder.has_form_data_body:
1063
+ retval.extend(self._create_body_parameter(builder))
1064
+ elif builder.overloads:
1034
1065
  # we are only dealing with two overloads. If there are three, we generate an abstract operation
1035
1066
  retval.extend(self._initialize_overloads(builder, is_paging=is_paging))
1036
1067
  elif builder.parameters.has_body:
@@ -1097,12 +1128,17 @@ class _OperationSerializer(
1097
1128
  and response.default_content_type == "application/json"
1098
1129
  else ""
1099
1130
  )
1131
+ response_attr = (
1132
+ "json"
1133
+ if JSON_REGEXP.match(str(response.default_content_type))
1134
+ else "text"
1135
+ )
1100
1136
  deserialize_code.append("deserialized = _deserialize(")
1101
1137
  deserialize_code.append(
1102
1138
  f" {response.type.type_annotation(is_operation_file=True)},{pylint_disable}"
1103
1139
  )
1104
1140
  deserialize_code.append(
1105
- f" response.json(){response.result_property}{format_filed}"
1141
+ f" response.{response_attr}(){response.result_property}{format_filed}"
1106
1142
  )
1107
1143
  deserialize_code.append(")")
1108
1144
 
@@ -1138,8 +1174,9 @@ class _OperationSerializer(
1138
1174
  f" {async_await} response.read() # Load the body in memory and close the socket",
1139
1175
  ]
1140
1176
  )
1177
+ type_ignore = " # type: ignore" if _need_type_ignore(builder) else ""
1141
1178
  retval.append(
1142
- " map_error(status_code=response.status_code, response=response, error_map=error_map)"
1179
+ f" map_error(status_code=response.status_code, response=response, error_map=error_map){type_ignore}"
1143
1180
  )
1144
1181
  error_model = ""
1145
1182
  if (
@@ -48,9 +48,8 @@ class ClientSerializer:
48
48
  @property
49
49
  def class_definition(self) -> str:
50
50
  class_name = self.client.name
51
- has_mixin_og = any(og for og in self.client.operation_groups if og.is_mixin)
52
51
  base_class = ""
53
- if has_mixin_og:
52
+ if self.client.has_mixin:
54
53
  base_class = f"{class_name}OperationsMixin"
55
54
  pylint_disable = self.client.pylint_disable
56
55
  if base_class:
@@ -130,6 +130,14 @@ class GeneralSerializer(BaseSerializer):
130
130
  "MatchConditions",
131
131
  ImportType.SDKCORE,
132
132
  )
133
+ if self.code_model.has_form_data:
134
+ file_import.add_submodule_import("typing", "Union", ImportType.STDLIB)
135
+ file_import.add_submodule_import("typing", "Any", ImportType.STDLIB)
136
+ file_import.add_submodule_import("io", "IOBase", ImportType.STDLIB)
137
+ file_import.add_submodule_import("io", "BytesIO", ImportType.STDLIB)
138
+ file_import.add_import("uuid", ImportType.STDLIB)
139
+ file_import.add_mutable_mapping_import()
140
+ file_import.add_submodule_import("._model_base", "Model", ImportType.LOCAL)
133
141
 
134
142
  return template.render(
135
143
  code_model=self.code_model,
@@ -7,6 +7,7 @@ from typing import Optional, List, Union
7
7
  import functools
8
8
  from jinja2 import Environment
9
9
 
10
+ from .utils import get_all_operation_groups_recursively
10
11
  from ..models import (
11
12
  CodeModel,
12
13
  OperationGroup,
@@ -45,18 +46,18 @@ class OperationGroupsSerializer(BaseSerializer):
45
46
  for client in self.clients
46
47
  for r in client.request_builders
47
48
  if r.client.name == operation_group.client.name
48
- and r.group_name == operation_group.property_name
49
+ and r.group_name == operation_group.identify_name
49
50
  and not r.is_overload
50
51
  and not r.abstract
51
52
  and not r.is_lro # lro has already initial builder
52
53
  ]
53
54
 
54
55
  def serialize(self) -> str:
55
- operation_groups = (
56
- [self.operation_group]
57
- if self.operation_group
58
- else [og for client in self.clients for og in client.operation_groups]
59
- )
56
+ if self.operation_group:
57
+ operation_groups = [self.operation_group]
58
+ else:
59
+ operation_groups = get_all_operation_groups_recursively(self.clients)
60
+
60
61
  imports = FileImport(self.code_model)
61
62
  for operation_group in operation_groups:
62
63
  imports.merge(
@@ -3,9 +3,11 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
- from typing import Optional
6
+ from typing import Optional, List
7
7
  from pathlib import Path
8
8
 
9
+ from ..models import Client, OperationGroup
10
+
9
11
 
10
12
  def method_signature_and_response_type_annotation_template(
11
13
  *,
@@ -30,3 +32,15 @@ def get_namespace_config(namespace: str, multiapi: bool) -> str:
30
32
 
31
33
  def get_namespace_from_package_name(package_name: Optional[str]) -> str:
32
34
  return (package_name or "").replace("-", ".")
35
+
36
+
37
+ def get_all_operation_groups_recursively(clients: List[Client]) -> List[OperationGroup]:
38
+ operation_groups = []
39
+ queue = []
40
+ for client in clients:
41
+ queue.extend(client.operation_groups)
42
+ while queue:
43
+ operation_groups.append(queue.pop(0))
44
+ if operation_groups[-1].operation_groups:
45
+ queue.extend(operation_groups[-1].operation_groups)
46
+ return operation_groups
@@ -726,7 +726,11 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
726
726
  ):
727
727
  if obj is None:
728
728
  return obj
729
- return _deserialize_with_callable(deserializer, obj)
729
+ try:
730
+ return _deserialize_with_callable(deserializer, obj)
731
+ except Exception:
732
+ pass
733
+ return obj
730
734
 
731
735
  if get_deserializer(annotation, rf):
732
736
  return functools.partial(_deserialize_default, get_deserializer(annotation, rf))
@@ -38,6 +38,13 @@ class {{ operation_group.class_name }}: {{ operation_group.pylint_disable }}
38
38
  {% if code_model.options["multiapi"] %}
39
39
  self._api_version = input_args.pop(0) if input_args else kwargs.pop("api_version")
40
40
  {% endif %}
41
+
42
+ {% for og in operation_group.operation_groups %}
43
+ self.{{ og.property_name }} = {{ og.class_name }}(
44
+ self._client, self._config, self._serialize, self._deserialize{{ ", self._api_version" if code_model.options["multiapi"] else "" }}
45
+ )
46
+ {% endfor %}
47
+
41
48
  {{ check_abstract_methods() }}
42
49
  {% elif operation_group.has_abstract_operations %}
43
50
 
@@ -94,7 +94,7 @@ setup(
94
94
  {% elif azure_arm %}
95
95
  "azure-mgmt-core<2.0.0,>=1.3.2",
96
96
  {% else %}
97
- "azure-core<2.0.0,>=1.28.0",
97
+ "azure-core<2.0.0,>=1.29.5",
98
98
  {% endif %}
99
99
  {% if code_model.need_typing_extensions %}
100
100
  "typing-extensions>=4.3.0; python_version<'3.8.0'",
@@ -65,3 +65,33 @@ def prep_if_none_match(etag: Optional[str], match_condition: Optional[MatchCondi
65
65
  return "*"
66
66
  return None
67
67
  {% endif %}
68
+ {% if code_model.has_form_data %}
69
+ class NamedBytesIO(BytesIO):
70
+ def __init__(self, name: str, *args, **kwargs):
71
+ super().__init__(*args, **kwargs)
72
+ self.name = name
73
+
74
+ def multipart_form_data_file(file: Union[IOBase, bytes]) -> IOBase:
75
+ if isinstance(file, IOBase):
76
+ return file
77
+ return NamedBytesIO("auto-name-" + str(uuid.uuid4()), file)
78
+
79
+ def handle_multipart_form_data_model(body: Model) -> MutableMapping[str, Any]: # pylint: disable=unsubscriptable-object
80
+ """handle first layer of model.
81
+ If its value is bytes or IO, replace it with raw value instead of serialized value.
82
+
83
+ :param body: The model to handle.
84
+ :type body: ~payload.multipart._model_base.Model
85
+ :return: The handled model.
86
+ :rtype: MutableMapping[str, Any]
87
+ """
88
+ result = body.as_dict()
89
+ rest_name_attr = {v._rest_name: k for k, v in body._attr_to_rest_field.items()} # pylint: disable=protected-access
90
+ for rest_name in result.keys():
91
+ attr = rest_name_attr.get(rest_name)
92
+ if attr is not None:
93
+ raw_value = getattr(body, attr, None)
94
+ if isinstance(raw_value, (bytes, IOBase)):
95
+ result[rest_name] = raw_value
96
+ return result
97
+ {% endif %}
@@ -124,11 +124,9 @@ def update_description(
124
124
  return description
125
125
 
126
126
 
127
- def update_operation_group_class_name(
128
- yaml_data: Dict[str, Any], class_name: str
129
- ) -> str:
127
+ def update_operation_group_class_name(prefix: str, class_name: str) -> str:
130
128
  if class_name == "":
131
- return yaml_data["name"] + "OperationsMixin"
129
+ return prefix + "OperationsMixin"
132
130
  if class_name == "Operations":
133
131
  return "Operations"
134
132
  return class_name + "Operations"
@@ -173,6 +171,16 @@ def headers_convert(yaml_data: Dict[str, Any], replace_data: Any) -> None:
173
171
  yaml_data[k] = v
174
172
 
175
173
 
174
+ def has_json_content_type(yaml_data: Dict[str, Any]) -> bool:
175
+ return any(ct for ct in yaml_data.get("contentTypes", []) if JSON_REGEXP.match(ct))
176
+
177
+
178
+ def has_multi_part_content_type(yaml_data: Dict[str, Any]) -> bool:
179
+ return any(
180
+ ct for ct in yaml_data.get("contentTypes", []) if ct == "multipart/form-data"
181
+ )
182
+
183
+
176
184
  class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
177
185
  """Add Python naming information."""
178
186
 
@@ -197,13 +205,13 @@ class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
197
205
  code_model: Dict[str, Any],
198
206
  body_parameter: Dict[str, Any],
199
207
  ):
200
- if (
208
+ # only add overload for special content type
209
+ if ( # pylint: disable=too-many-boolean-expressions
201
210
  body_parameter
202
211
  and body_parameter["type"]["type"] in ("model", "dict", "list")
203
- and any(
204
- ct
205
- for ct in body_parameter.get("contentTypes", [])
206
- if JSON_REGEXP.match(ct)
212
+ and (
213
+ has_json_content_type(body_parameter)
214
+ or (self.is_cadl and has_multi_part_content_type(body_parameter))
207
215
  )
208
216
  and not body_parameter["type"].get("xmlMetadata")
209
217
  and not any(t for t in ["flattened", "groupedBy"] if body_parameter.get(t))
@@ -212,8 +220,12 @@ class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
212
220
  is_dpg_model = body_parameter["type"].get("base") == "dpg"
213
221
  body_parameter["type"] = {
214
222
  "type": "combined",
215
- "types": [body_parameter["type"], KNOWN_TYPES["binary"]],
223
+ "types": [body_parameter["type"]],
216
224
  }
225
+ # don't add binary overload for multipart content type
226
+ if not (self.is_cadl and has_multi_part_content_type(body_parameter)):
227
+ body_parameter["type"]["types"].append(KNOWN_TYPES["binary"])
228
+
217
229
  if origin_type == "model" and is_dpg_model and self.models_mode == "dpg":
218
230
  body_parameter["type"]["types"].insert(1, KNOWN_TYPES["any-object"])
219
231
  code_model["types"].append(body_parameter["type"])
@@ -373,9 +385,9 @@ class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
373
385
  for prop, param_name in yaml_data["propertyToParameterName"].items()
374
386
  }
375
387
  wire_name_lower = (yaml_data.get("wireName") or "").lower()
376
- if (
377
- yaml_data["location"] == "header"
378
- and wire_name_lower in HEADERS_HIDE_IN_METHOD
388
+ if yaml_data["location"] == "header" and (
389
+ wire_name_lower in HEADERS_HIDE_IN_METHOD
390
+ or yaml_data.get("clientDefaultValue") == "multipart/form-data"
379
391
  ):
380
392
  yaml_data["hideInMethod"] = True
381
393
  if (
@@ -511,7 +523,13 @@ class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
511
523
  ) -> None:
512
524
  operation_groups_yaml_data = client["operationGroups"]
513
525
  for operation_group in operation_groups_yaml_data:
514
- operation_group["clientName"] = client["name"]
526
+ operation_group["identifyName"] = self.pad_reserved_words(
527
+ operation_group.get("name", operation_group["propertyName"]),
528
+ PadType.OPERATION_GROUP,
529
+ )
530
+ operation_group["identifyName"] = to_snake_case(
531
+ operation_group["identifyName"]
532
+ )
515
533
  operation_group["propertyName"] = self.pad_reserved_words(
516
534
  operation_group["propertyName"], PadType.OPERATION_GROUP
517
535
  )
@@ -519,11 +537,14 @@ class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
519
537
  operation_group["propertyName"]
520
538
  )
521
539
  operation_group["className"] = update_operation_group_class_name(
522
- client, operation_group["className"]
540
+ client["name"], operation_group["className"]
523
541
  )
524
542
  for operation in operation_group["operations"]:
525
543
  self.get_operation_updater(operation)(code_model, operation)
526
544
 
545
+ if operation_group.get("operationGroups"):
546
+ self.update_operation_groups(code_model, operation_group)
547
+
527
548
  def update_yaml(self, yaml_data: Dict[str, Any]) -> None:
528
549
  """Convert in place the YAML str."""
529
550
  self.update_types(yaml_data["types"])
@@ -93,6 +93,7 @@ _always_reserved = [
93
93
  "yield",
94
94
  "async",
95
95
  "await",
96
+ "int",
96
97
  ]
97
98
 
98
99
  RESERVED_WORDS = {
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@autorest/python",
3
- "version": "6.10.2",
3
+ "version": "6.11.1",
4
4
  "description": "The Python extension for generators in AutoRest.",
5
5
  "main": "index.js",
6
6
  "repository": {