@autorest/python 6.11.0 → 6.11.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -77,6 +77,15 @@ class CodeModel: # pylint: disable=too-many-public-methods, disable=too-many-in
77
77
  t for t in self.types_map.values() if isinstance(t, CombinedType) and t.name
78
78
  ]
79
79
 
80
+ @property
81
+ def has_form_data(self) -> bool:
82
+ for client in self.clients:
83
+ for operation_group in client.operation_groups:
84
+ for operation in operation_group.operations:
85
+ if operation.has_form_data_body:
86
+ return True
87
+ return False
88
+
80
89
  @property
81
90
  def has_etag(self) -> bool:
82
91
  return any(client.has_etag for client in self.clients)
@@ -133,7 +142,12 @@ class CodeModel: # pylint: disable=too-many-public-methods, disable=too-many-in
133
142
  return True
134
143
  if async_mode:
135
144
  return self.need_mixin_abc
136
- return self.need_request_converter or self.need_mixin_abc or self.has_etag
145
+ return (
146
+ self.need_request_converter
147
+ or self.need_mixin_abc
148
+ or self.has_etag
149
+ or self.has_form_data
150
+ )
137
151
 
138
152
  @property
139
153
  def need_request_converter(self) -> bool:
@@ -184,8 +184,7 @@ class FileImport:
184
184
  self._append_import(i)
185
185
  self.type_definitions.update(file_import.type_definitions)
186
186
 
187
- def define_mutable_mapping_type(self) -> None:
188
- """Helper function for defining the mutable mapping type"""
187
+ def add_mutable_mapping_import(self) -> None:
189
188
  self.add_import("sys", ImportType.STDLIB)
190
189
  self.add_submodule_import(
191
190
  "typing",
@@ -195,6 +194,10 @@ class FileImport:
195
194
  None,
196
195
  (((3, 9), "collections.abc", None),),
197
196
  )
197
+
198
+ def define_mutable_mapping_type(self) -> None:
199
+ """Helper function for defining the mutable mapping type"""
200
+ self.add_mutable_mapping_import()
198
201
  self.define_mypy_type(
199
202
  "JSON",
200
203
  "MutableMapping[str, Any] # pylint: disable=unsubscriptable-object",
@@ -92,6 +92,10 @@ class OperationBase( # pylint: disable=too-many-public-methods
92
92
  self.name = "_" + self.name
93
93
  self.has_etag: bool = self.yaml_data.get("hasEtag", False)
94
94
 
95
+ @property
96
+ def has_form_data_body(self):
97
+ return self.parameters.has_form_data_body
98
+
95
99
  @property
96
100
  def expose_stream_keyword(self) -> bool:
97
101
  return self.yaml_data.get("exposeStreamKeyword", False)
@@ -568,13 +572,31 @@ class Operation(OperationBase[Response]):
568
572
  and not self.code_model.options["models_mode"]
569
573
  ):
570
574
  file_import.add_submodule_import("typing", "cast", ImportType.STDLIB)
575
+ relative_path = "..." if async_mode else ".."
571
576
  if self.code_model.options["models_mode"] == "dpg":
572
- relative_path = "..." if async_mode else ".."
573
577
  if self.parameters.has_body:
574
- file_import.add_submodule_import(
575
- f"{relative_path}_model_base", "SdkJSONEncoder", ImportType.LOCAL
576
- )
577
- file_import.add_import("json", ImportType.STDLIB)
578
+ if not self.parameters.body_parameter.is_form_data:
579
+ file_import.add_submodule_import(
580
+ f"{relative_path}_model_base",
581
+ "SdkJSONEncoder",
582
+ ImportType.LOCAL,
583
+ )
584
+ file_import.add_import("json", ImportType.STDLIB)
585
+ else:
586
+ file_import.add_submodule_import(
587
+ relative_path, "_model_base", ImportType.LOCAL
588
+ )
589
+ file_import.add_submodule_import("io", "IOBase", ImportType.STDLIB)
590
+ file_import.add_submodule_import(
591
+ f"{relative_path}_vendor",
592
+ "multipart_form_data_file",
593
+ ImportType.LOCAL,
594
+ )
595
+ file_import.add_submodule_import(
596
+ f"{relative_path}_vendor",
597
+ "handle_multipart_form_data_model",
598
+ ImportType.LOCAL,
599
+ )
578
600
  if self.default_error_deserialization or any(
579
601
  r.type for r in self.responses
580
602
  ):
@@ -243,6 +243,10 @@ class _ParameterBase(
243
243
  class BodyParameter(_ParameterBase):
244
244
  """Body parameter."""
245
245
 
246
+ @property
247
+ def is_form_data(self) -> bool:
248
+ return self.default_content_type == "multipart/form-data"
249
+
246
250
  @property
247
251
  def is_partial_body(self) -> bool:
248
252
  """Whether it's part of a bigger body parameter, i.e. a MultipartBodyParameter"""
@@ -133,6 +133,10 @@ class _ParameterListBase(
133
133
  params.append(self.body_parameter)
134
134
  return params
135
135
 
136
+ @property
137
+ def has_form_data_body(self):
138
+ return self.has_body and self.body_parameter.is_form_data
139
+
136
140
  @property
137
141
  def has_body(self) -> bool:
138
142
  """Whether there is a body parameter in the parameter list"""
@@ -57,6 +57,10 @@ class RequestBuilderBase(BaseBuilder[ParameterListType]):
57
57
  self.method: str = yaml_data["method"]
58
58
  self.want_tracing = False
59
59
 
60
+ @property
61
+ def has_form_data_body(self):
62
+ return self.parameters.has_form_data_body
63
+
60
64
  @property
61
65
  def is_lro(self) -> bool:
62
66
  return self.yaml_data.get("discriminator") in ("lro", "lropaging")
@@ -38,6 +38,7 @@ from ..models import (
38
38
  from .parameter_serializer import ParameterSerializer, PopKwargType
39
39
  from ..models.parameter_list import ParameterType
40
40
  from . import utils
41
+ from ..._utils import JSON_REGEXP
41
42
 
42
43
  T = TypeVar("T")
43
44
  OrderedSet = Dict[T, None]
@@ -59,6 +60,14 @@ OperationType = TypeVar(
59
60
  )
60
61
 
61
62
 
63
+ def _need_type_ignore(builder: OperationType) -> bool:
64
+ for excep in builder.non_default_errors:
65
+ for status_code in excep.status_codes:
66
+ if status_code in (401, 404, 409, 304):
67
+ return True
68
+ return False
69
+
70
+
62
71
  def _xml_config(send_xml: bool, content_types: List[str]) -> str:
63
72
  if not (send_xml and "xml" in str(content_types)):
64
73
  return ""
@@ -547,11 +556,16 @@ class RequestBuilderSerializer(
547
556
  return retval
548
557
 
549
558
  def serialize_headers(self, builder: RequestBuilderType) -> List[str]:
550
- retval = ["# Construct headers"]
551
- for parameter in builder.parameters.headers:
559
+ headers = [
560
+ h
561
+ for h in builder.parameters.headers
562
+ if not builder.has_form_data_body or h.wire_name.lower() != "content-type"
563
+ ]
564
+ retval = ["# Construct headers"] if headers else []
565
+ for header in headers:
552
566
  retval.extend(
553
567
  self.parameter_serializer.serialize_query_header(
554
- parameter,
568
+ header,
555
569
  "headers",
556
570
  self.serializer_name,
557
571
  self.code_model.is_legacy,
@@ -747,8 +761,17 @@ class _OperationSerializer(
747
761
 
748
762
  This function serializes the body params that need to be serialized.
749
763
  """
750
- retval: List[str] = []
751
764
  body_param = cast(BodyParameter, builder.parameters.body_parameter)
765
+ if body_param.is_form_data:
766
+ return [
767
+ f"if isinstance({body_param.client_name}, _model_base.Model):",
768
+ f" _body = handle_multipart_form_data_model({body_param.client_name})",
769
+ "else:",
770
+ f" _body = {body_param.client_name}",
771
+ "_files = {k: multipart_form_data_file(v) for k, v in _body.items() if isinstance(v, (IOBase, bytes))}",
772
+ "_data = {k: v for k, v in _body.items() if not isinstance(v, (IOBase, bytes))}",
773
+ ]
774
+ retval: List[str] = []
752
775
  body_kwarg_name = builder.request_builder.parameters.body_parameter.client_name
753
776
  send_xml = builder.parameters.body_parameter.type.is_xml
754
777
  xml_serialization_ctxt = (
@@ -769,17 +792,20 @@ class _OperationSerializer(
769
792
  f"'{body_param.type.serialization_type}'{is_xml_cmd}{serialization_ctxt_cmd})"
770
793
  )
771
794
  elif self.code_model.options["models_mode"] == "dpg":
772
- if hasattr(body_param.type, "encode") and body_param.type.encode: # type: ignore
773
- create_body_call = (
774
- f"_{body_kwarg_name} = json.dumps({body_param.client_name}, "
775
- "cls=SdkJSONEncoder, exclude_readonly=True, "
776
- f"format='{body_param.type.encode}') # type: ignore" # type: ignore
777
- )
795
+ if JSON_REGEXP.match(body_param.default_content_type):
796
+ if hasattr(body_param.type, "encode") and body_param.type.encode: # type: ignore
797
+ create_body_call = (
798
+ f"_{body_kwarg_name} = json.dumps({body_param.client_name}, "
799
+ "cls=SdkJSONEncoder, exclude_readonly=True, "
800
+ f"format='{body_param.type.encode}') # type: ignore" # type: ignore
801
+ )
802
+ else:
803
+ create_body_call = (
804
+ f"_{body_kwarg_name} = json.dumps({body_param.client_name}, "
805
+ "cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore"
806
+ )
778
807
  else:
779
- create_body_call = (
780
- f"_{body_kwarg_name} = json.dumps({body_param.client_name}, "
781
- "cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore"
782
- )
808
+ create_body_call = f"_{body_kwarg_name} = {body_param.client_name}"
783
809
  else:
784
810
  create_body_call = f"_{body_kwarg_name} = {body_param.client_name}"
785
811
  if body_param.optional:
@@ -959,7 +985,10 @@ class _OperationSerializer(
959
985
  f" {parameter.client_name}={parameter.name_in_high_level_operation},"
960
986
  f"{' # type: ignore' if type_ignore else ''}"
961
987
  )
962
- if request_builder.overloads:
988
+ if request_builder.has_form_data_body:
989
+ retval.append(" data=_data,")
990
+ retval.append(" files=_files,")
991
+ elif request_builder.overloads:
963
992
  seen_body_params = set()
964
993
  for overload in request_builder.overloads:
965
994
  body_param = cast(
@@ -1030,7 +1059,9 @@ class _OperationSerializer(
1030
1059
  builder.parameters.body_parameter, builder.parameters.parameters
1031
1060
  )
1032
1061
  )
1033
- if builder.overloads:
1062
+ if builder.has_form_data_body:
1063
+ retval.extend(self._create_body_parameter(builder))
1064
+ elif builder.overloads:
1034
1065
  # we are only dealing with two overloads. If there are three, we generate an abstract operation
1035
1066
  retval.extend(self._initialize_overloads(builder, is_paging=is_paging))
1036
1067
  elif builder.parameters.has_body:
@@ -1097,12 +1128,17 @@ class _OperationSerializer(
1097
1128
  and response.default_content_type == "application/json"
1098
1129
  else ""
1099
1130
  )
1131
+ response_attr = (
1132
+ "json"
1133
+ if JSON_REGEXP.match(str(response.default_content_type))
1134
+ else "text"
1135
+ )
1100
1136
  deserialize_code.append("deserialized = _deserialize(")
1101
1137
  deserialize_code.append(
1102
1138
  f" {response.type.type_annotation(is_operation_file=True)},{pylint_disable}"
1103
1139
  )
1104
1140
  deserialize_code.append(
1105
- f" response.json(){response.result_property}{format_filed}"
1141
+ f" response.{response_attr}(){response.result_property}{format_filed}"
1106
1142
  )
1107
1143
  deserialize_code.append(")")
1108
1144
 
@@ -1138,8 +1174,9 @@ class _OperationSerializer(
1138
1174
  f" {async_await} response.read() # Load the body in memory and close the socket",
1139
1175
  ]
1140
1176
  )
1177
+ type_ignore = " # type: ignore" if _need_type_ignore(builder) else ""
1141
1178
  retval.append(
1142
- " map_error(status_code=response.status_code, response=response, error_map=error_map)"
1179
+ f" map_error(status_code=response.status_code, response=response, error_map=error_map){type_ignore}"
1143
1180
  )
1144
1181
  error_model = ""
1145
1182
  if (
@@ -130,6 +130,14 @@ class GeneralSerializer(BaseSerializer):
130
130
  "MatchConditions",
131
131
  ImportType.SDKCORE,
132
132
  )
133
+ if self.code_model.has_form_data:
134
+ file_import.add_submodule_import("typing", "Union", ImportType.STDLIB)
135
+ file_import.add_submodule_import("typing", "Any", ImportType.STDLIB)
136
+ file_import.add_submodule_import("io", "IOBase", ImportType.STDLIB)
137
+ file_import.add_submodule_import("io", "BytesIO", ImportType.STDLIB)
138
+ file_import.add_import("uuid", ImportType.STDLIB)
139
+ file_import.add_mutable_mapping_import()
140
+ file_import.add_submodule_import("._model_base", "Model", ImportType.LOCAL)
133
141
 
134
142
  return template.render(
135
143
  code_model=self.code_model,
@@ -94,7 +94,7 @@ setup(
94
94
  {% elif azure_arm %}
95
95
  "azure-mgmt-core<2.0.0,>=1.3.2",
96
96
  {% else %}
97
- "azure-core<2.0.0,>=1.28.0",
97
+ "azure-core<2.0.0,>=1.29.5",
98
98
  {% endif %}
99
99
  {% if code_model.need_typing_extensions %}
100
100
  "typing-extensions>=4.3.0; python_version<'3.8.0'",
@@ -65,3 +65,33 @@ def prep_if_none_match(etag: Optional[str], match_condition: Optional[MatchCondi
65
65
  return "*"
66
66
  return None
67
67
  {% endif %}
68
+ {% if code_model.has_form_data %}
69
+ class NamedBytesIO(BytesIO):
70
+ def __init__(self, name: str, *args, **kwargs):
71
+ super().__init__(*args, **kwargs)
72
+ self.name = name
73
+
74
+ def multipart_form_data_file(file: Union[IOBase, bytes]) -> IOBase:
75
+ if isinstance(file, IOBase):
76
+ return file
77
+ return NamedBytesIO("auto-name-" + str(uuid.uuid4()), file)
78
+
79
+ def handle_multipart_form_data_model(body: Model) -> MutableMapping[str, Any]: # pylint: disable=unsubscriptable-object
80
+ """handle first layer of model.
81
+ If its value is bytes or IO, replace it with raw value instead of serialized value.
82
+
83
+ :param body: The model to handle.
84
+ :type body: ~payload.multipart._model_base.Model
85
+ :return: The handled model.
86
+ :rtype: MutableMapping[str, Any]
87
+ """
88
+ result = body.as_dict()
89
+ rest_name_attr = {v._rest_name: k for k, v in body._attr_to_rest_field.items()} # pylint: disable=protected-access
90
+ for rest_name in result.keys():
91
+ attr = rest_name_attr.get(rest_name)
92
+ if attr is not None:
93
+ raw_value = getattr(body, attr, None)
94
+ if isinstance(raw_value, (bytes, IOBase)):
95
+ result[rest_name] = raw_value
96
+ return result
97
+ {% endif %}
@@ -171,6 +171,16 @@ def headers_convert(yaml_data: Dict[str, Any], replace_data: Any) -> None:
171
171
  yaml_data[k] = v
172
172
 
173
173
 
174
+ def has_json_content_type(yaml_data: Dict[str, Any]) -> bool:
175
+ return any(ct for ct in yaml_data.get("contentTypes", []) if JSON_REGEXP.match(ct))
176
+
177
+
178
+ def has_multi_part_content_type(yaml_data: Dict[str, Any]) -> bool:
179
+ return any(
180
+ ct for ct in yaml_data.get("contentTypes", []) if ct == "multipart/form-data"
181
+ )
182
+
183
+
174
184
  class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
175
185
  """Add Python naming information."""
176
186
 
@@ -195,13 +205,13 @@ class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
195
205
  code_model: Dict[str, Any],
196
206
  body_parameter: Dict[str, Any],
197
207
  ):
198
- if (
208
+ # only add overload for special content type
209
+ if ( # pylint: disable=too-many-boolean-expressions
199
210
  body_parameter
200
211
  and body_parameter["type"]["type"] in ("model", "dict", "list")
201
- and any(
202
- ct
203
- for ct in body_parameter.get("contentTypes", [])
204
- if JSON_REGEXP.match(ct)
212
+ and (
213
+ has_json_content_type(body_parameter)
214
+ or (self.is_cadl and has_multi_part_content_type(body_parameter))
205
215
  )
206
216
  and not body_parameter["type"].get("xmlMetadata")
207
217
  and not any(t for t in ["flattened", "groupedBy"] if body_parameter.get(t))
@@ -210,8 +220,12 @@ class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
210
220
  is_dpg_model = body_parameter["type"].get("base") == "dpg"
211
221
  body_parameter["type"] = {
212
222
  "type": "combined",
213
- "types": [body_parameter["type"], KNOWN_TYPES["binary"]],
223
+ "types": [body_parameter["type"]],
214
224
  }
225
+ # don't add binary overload for multipart content type
226
+ if not (self.is_cadl and has_multi_part_content_type(body_parameter)):
227
+ body_parameter["type"]["types"].append(KNOWN_TYPES["binary"])
228
+
215
229
  if origin_type == "model" and is_dpg_model and self.models_mode == "dpg":
216
230
  body_parameter["type"]["types"].insert(1, KNOWN_TYPES["any-object"])
217
231
  code_model["types"].append(body_parameter["type"])
@@ -371,9 +385,9 @@ class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
371
385
  for prop, param_name in yaml_data["propertyToParameterName"].items()
372
386
  }
373
387
  wire_name_lower = (yaml_data.get("wireName") or "").lower()
374
- if (
375
- yaml_data["location"] == "header"
376
- and wire_name_lower in HEADERS_HIDE_IN_METHOD
388
+ if yaml_data["location"] == "header" and (
389
+ wire_name_lower in HEADERS_HIDE_IN_METHOD
390
+ or yaml_data.get("clientDefaultValue") == "multipart/form-data"
377
391
  ):
378
392
  yaml_data["hideInMethod"] = True
379
393
  if (
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@autorest/python",
3
- "version": "6.11.0",
3
+ "version": "6.11.1",
4
4
  "description": "The Python extension for generators in AutoRest.",
5
5
  "main": "index.js",
6
6
  "repository": {