@autorest/python 6.12.3 → 6.13.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -73,7 +73,7 @@ def parse_args(
73
73
  dest="debug",
74
74
  help="Debug mode",
75
75
  required=False,
76
- action="store_true",
76
+ action="store",
77
77
  )
78
78
  args, unknown_args = parser.parse_known_args()
79
79
 
@@ -43,7 +43,6 @@ from .parameter import (
43
43
  ParameterLocation,
44
44
  BodyParameter,
45
45
  ParameterDelimeter,
46
- MultipartBodyParameter,
47
46
  ClientParameter,
48
47
  ConfigParameter,
49
48
  )
@@ -115,7 +114,6 @@ __all__ = [
115
114
  "BodyParameter",
116
115
  "RequestBuilderBodyParameter",
117
116
  "ParameterDelimeter",
118
- "MultipartBodyParameter",
119
117
  "CredentialType",
120
118
  "ClientParameter",
121
119
  "ConfigParameter",
@@ -135,9 +133,9 @@ TYPE_TO_OBJECT = {
135
133
  "binary": BinaryType,
136
134
  "any": AnyType,
137
135
  "datetime": DatetimeType,
138
- "time": TimeType,
136
+ "plainTime": TimeType,
139
137
  "duration": DurationType,
140
- "date": DateType,
138
+ "plainDate": DateType,
141
139
  "bytes": ByteArraySchema,
142
140
  "boolean": BooleanType,
143
141
  "combined": CombinedType,
@@ -186,3 +186,7 @@ class BaseType(BaseModel, ABC): # pylint: disable=too-many-public-methods
186
186
  @property
187
187
  def type_description(self) -> str:
188
188
  return self.type_annotation()
189
+
190
+ @property
191
+ def is_form_data(self) -> bool:
192
+ return False
@@ -4,7 +4,17 @@
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
6
  import logging
7
- from typing import List, Dict, Any, Generic, TypeVar, Optional, Union, TYPE_CHECKING
7
+ from typing import (
8
+ List,
9
+ Dict,
10
+ Any,
11
+ Generic,
12
+ TypeVar,
13
+ Optional,
14
+ Union,
15
+ TYPE_CHECKING,
16
+ cast,
17
+ )
8
18
  from abc import abstractmethod
9
19
 
10
20
  from .base import BaseModel
@@ -28,11 +38,16 @@ if TYPE_CHECKING:
28
38
  from .operation import Operation
29
39
  from .request_builder import RequestBuilder
30
40
 
41
+
42
+ OverloadListType = TypeVar(
43
+ "OverloadListType", bound=Union[List["Operation"], List["RequestBuilder"]]
44
+ )
45
+
31
46
  _LOGGER = logging.getLogger(__name__)
32
47
 
33
48
 
34
49
  class BaseBuilder(
35
- Generic[ParameterListType], BaseModel
50
+ Generic[ParameterListType, OverloadListType], BaseModel
36
51
  ): # pylint: disable=too-many-instance-attributes
37
52
  """Base class for Operations and Request Builders"""
38
53
 
@@ -44,16 +59,14 @@ class BaseBuilder(
44
59
  name: str,
45
60
  parameters: ParameterListType,
46
61
  *,
47
- overloads=None,
62
+ overloads: Optional[OverloadListType] = None,
48
63
  ) -> None:
49
64
  super().__init__(yaml_data=yaml_data, code_model=code_model)
50
65
  self.client = client
51
66
  self.name = name
52
67
  self._description: str = yaml_data.get("description", "")
53
68
  self.parameters = parameters
54
- self.overloads: Union[List["Operation"], List["RequestBuilder"]] = (
55
- overloads or []
56
- )
69
+ self.overloads = overloads or cast(OverloadListType, [])
57
70
  self._summary: str = yaml_data.get("summary", "")
58
71
  self.want_tracing: bool = yaml_data.get("wantTracing", True)
59
72
  self.group_name: str = yaml_data[
@@ -11,7 +11,6 @@ from .model_type import ModelType
11
11
  from .combined_type import CombinedType
12
12
  from .client import Client
13
13
  from .request_builder import RequestBuilder, OverloadedRequestBuilder
14
- from .constant_type import ConstantType
15
14
 
16
15
 
17
16
  def _is_legacy(options) -> bool:
@@ -279,25 +278,6 @@ class CodeModel: # pylint: disable=too-many-public-methods, disable=too-many-in
279
278
 
280
279
  @property
281
280
  def need_typing_extensions(self) -> bool:
282
- if self.options["models_mode"] and any(
283
- isinstance(p.type, ConstantType)
284
- and (p.optional or self.options["models_mode"] == "dpg")
285
- for model in self.model_types
286
- for p in model.properties
287
- ):
288
- return True
289
- if any(
290
- isinstance(parameter.type, ConstantType)
291
- for client in self.clients
292
- for og in client.operation_groups
293
- for op in og.operations
294
- for parameter in op.parameters.method
295
- ):
296
- return True
297
- if any(
298
- isinstance(parameter.type, ConstantType)
299
- for client in self.clients
300
- for parameter in client.config.parameters.kwargs_to_pop
301
- ):
281
+ if self.options["models_mode"] == "dpg":
302
282
  return True
303
283
  return False
@@ -79,6 +79,10 @@ class CombinedType(BaseType):
79
79
  pattern = re.compile(r"Union\[.*\]")
80
80
  return f'Union[{", ".join(map(lambda x: x[6: -1] if pattern.match(x) else x, inside_types))}]'
81
81
 
82
+ @property
83
+ def is_form_data(self) -> bool:
84
+ return any(t.is_form_data for t in self.types)
85
+
82
86
  def get_json_template_representation(
83
87
  self,
84
88
  *,
@@ -76,6 +76,10 @@ class ModelType( # pylint: disable=abstract-method
76
76
  self.snake_case_name: str = self.yaml_data["snakeCaseName"]
77
77
  self.page_result_model: bool = self.yaml_data.get("pageResultModel", False)
78
78
 
79
+ @property
80
+ def is_form_data(self) -> bool:
81
+ return any(p.is_multipart_file_input for p in self.properties)
82
+
79
83
  @property
80
84
  def is_xml(self) -> bool:
81
85
  return self.yaml_data.get("isXml", False)
@@ -314,6 +318,15 @@ class GeneratedModelType(ModelType): # pylint: disable=abstract-method
314
318
  if kwargs.get("model_typing")
315
319
  else TypingSection.REGULAR,
316
320
  )
321
+ if self.is_form_data:
322
+ file_import.add_submodule_import(
323
+ relative_path,
324
+ "_model_base",
325
+ ImportType.LOCAL,
326
+ typing_section=TypingSection.TYPING
327
+ if kwargs.get("model_typing")
328
+ else TypingSection.REGULAR,
329
+ )
317
330
  return file_import
318
331
 
319
332
 
@@ -14,6 +14,7 @@ from typing import (
14
14
  Generic,
15
15
  TypeVar,
16
16
  cast,
17
+ Sequence,
17
18
  )
18
19
 
19
20
  from .request_builder_parameter import RequestBuilderParameter
@@ -30,7 +31,6 @@ from .response import (
30
31
  )
31
32
  from .parameter import (
32
33
  BodyParameter,
33
- MultipartBodyParameter,
34
34
  Parameter,
35
35
  ParameterLocation,
36
36
  )
@@ -55,7 +55,7 @@ def is_internal(target: Optional[BaseType]) -> bool:
55
55
 
56
56
 
57
57
  class OperationBase( # pylint: disable=too-many-public-methods
58
- Generic[ResponseType], BaseBuilder[ParameterList]
58
+ Generic[ResponseType], BaseBuilder[ParameterList, List["Operation"]]
59
59
  ):
60
60
  def __init__(
61
61
  self,
@@ -290,7 +290,6 @@ class OperationBase( # pylint: disable=too-many-public-methods
290
290
  Parameter,
291
291
  RequestBuilderParameter,
292
292
  BodyParameter,
293
- MultipartBodyParameter,
294
293
  ]
295
294
  ],
296
295
  location: ParameterLocation,
@@ -492,7 +491,7 @@ class OperationBase( # pylint: disable=too-many-public-methods
492
491
  ) from exc
493
492
 
494
493
  @property
495
- def success_status_codes(self) -> List[Union[str, int]]:
494
+ def success_status_codes(self) -> Sequence[Union[str, int]]:
496
495
  """The list of all successfull status code."""
497
496
  return sorted(
498
497
  [code for response in self.responses for code in response.status_codes]
@@ -575,33 +574,13 @@ class Operation(OperationBase[Response]):
575
574
  relative_path = "..." if async_mode else ".."
576
575
  if self.code_model.options["models_mode"] == "dpg":
577
576
  if self.parameters.has_body:
578
- if not self.parameters.body_parameter.is_form_data:
577
+ if not self.has_form_data_body:
579
578
  file_import.add_submodule_import(
580
579
  f"{relative_path}_model_base",
581
580
  "SdkJSONEncoder",
582
581
  ImportType.LOCAL,
583
582
  )
584
583
  file_import.add_import("json", ImportType.STDLIB)
585
- else:
586
- file_import.add_submodule_import(
587
- relative_path, "_model_base", ImportType.LOCAL
588
- )
589
- file_import.add_submodule_import("io", "IOBase", ImportType.STDLIB)
590
- file_import.add_submodule_import(
591
- f"{relative_path}_vendor",
592
- "multipart_file",
593
- ImportType.LOCAL,
594
- )
595
- file_import.add_submodule_import(
596
- f"{relative_path}_vendor",
597
- "multipart_data",
598
- ImportType.LOCAL,
599
- )
600
- file_import.add_submodule_import(
601
- f"{relative_path}_vendor",
602
- "handle_multipart_form_data_model",
603
- ImportType.LOCAL,
604
- )
605
584
  if self.default_error_deserialization or any(
606
585
  r.type for r in self.responses
607
586
  ):
@@ -14,7 +14,6 @@ from typing import (
14
14
  Optional,
15
15
  TypeVar,
16
16
  Union,
17
- Generic,
18
17
  )
19
18
 
20
19
  from .imports import FileImport, ImportType, TypingSection
@@ -243,9 +242,18 @@ class _ParameterBase(
243
242
  class BodyParameter(_ParameterBase):
244
243
  """Body parameter."""
245
244
 
245
+ @property
246
+ def entries(self) -> List["BodyParameter"]:
247
+ return [
248
+ BodyParameter.from_yaml(e, self.code_model)
249
+ for e in self.yaml_data.get("entries", [])
250
+ ]
251
+
246
252
  @property
247
253
  def is_form_data(self) -> bool:
248
- return self.default_content_type == "multipart/form-data"
254
+ # hacky, but rn in legacy, there is no formdata model type, it's just a dict
255
+ # with all of the entries splatted out
256
+ return self.type.is_form_data or bool(self.entries)
249
257
 
250
258
  @property
251
259
  def is_partial_body(self) -> bool:
@@ -262,6 +270,10 @@ class BodyParameter(_ParameterBase):
262
270
 
263
271
  @property
264
272
  def in_method_signature(self) -> bool:
273
+ if self.yaml_data.get("entries"):
274
+ # Right now, only legacy generates with multipart bodies and entries
275
+ # and legacy generates with the multipart body arguments splatted out
276
+ return False
265
277
  return not (self.flattened or self.grouped_by)
266
278
 
267
279
  @property
@@ -278,6 +290,18 @@ class BodyParameter(_ParameterBase):
278
290
  return self.type.target_model_subtype((JSONModelType,)) is not None
279
291
  return isinstance(self.type, JSONModelType)
280
292
 
293
+ def imports(self, async_mode: bool, **kwargs: Any) -> FileImport:
294
+ file_import = super().imports(async_mode, **kwargs)
295
+ if self.is_form_data:
296
+ relative_path = "..." if async_mode else ".."
297
+ file_import.add_submodule_import(
298
+ f"{relative_path}_vendor",
299
+ "prepare_multipart_form_data",
300
+ ImportType.LOCAL,
301
+ )
302
+ file_import.add_submodule_import("typing", "List", ImportType.STDLIB)
303
+ return file_import
304
+
281
305
  @classmethod
282
306
  def from_yaml(
283
307
  cls, yaml_data: Dict[str, Any], code_model: "CodeModel"
@@ -294,46 +318,6 @@ EntryBodyParameterType = TypeVar(
294
318
  )
295
319
 
296
320
 
297
- class _MultipartBodyParameter(Generic[EntryBodyParameterType], BodyParameter):
298
- """Base class for MultipartBodyParameter and RequestBuilderMultipartBodyParameter"""
299
-
300
- def __init__(
301
- self,
302
- yaml_data: Dict[str, Any],
303
- code_model: "CodeModel",
304
- type: BaseType,
305
- entries: List[EntryBodyParameterType],
306
- ) -> None:
307
- super().__init__(yaml_data, code_model, type)
308
- self.entries = entries
309
-
310
- @property
311
- def in_method_signature(self) -> bool:
312
- # Right now, only legacy generates with multipart bodies
313
- # and legacy generates with the multipart body arguments splatted out
314
- return False
315
-
316
-
317
- class MultipartBodyParameter(
318
- _MultipartBodyParameter[BodyParameter] # pylint: disable=unsubscriptable-object
319
- ):
320
- """Multipart body parameter for Operation. Used for files and data input."""
321
-
322
- @classmethod
323
- def from_yaml(
324
- cls, yaml_data: Dict[str, Any], code_model: "CodeModel"
325
- ) -> "MultipartBodyParameter":
326
- return cls(
327
- yaml_data=yaml_data,
328
- code_model=code_model,
329
- type=code_model.lookup_type(id(yaml_data["type"])),
330
- entries=[
331
- BodyParameter.from_yaml(entry, code_model)
332
- for entry in yaml_data["entries"]
333
- ],
334
- )
335
-
336
-
337
321
  class Parameter(_ParameterBase):
338
322
  """Basic Parameter class"""
339
323
 
@@ -455,12 +439,3 @@ class ConfigParameter(Parameter):
455
439
  if self.constant:
456
440
  return ParameterMethodLocation.KWARG
457
441
  return ParameterMethodLocation.POSITIONAL
458
-
459
-
460
- def get_body_parameter(
461
- yaml_data: Dict[str, Any], code_model: "CodeModel"
462
- ) -> Union[BodyParameter, MultipartBodyParameter]:
463
- """Creates a regular body parameter or Multipart body parameter"""
464
- if yaml_data.get("entries"):
465
- return MultipartBodyParameter.from_yaml(yaml_data, code_model)
466
- return BodyParameter.from_yaml(yaml_data, code_model)
@@ -22,19 +22,15 @@ from enum import Enum
22
22
 
23
23
  from .request_builder_parameter import (
24
24
  RequestBuilderBodyParameter,
25
- RequestBuilderMultipartBodyParameter,
26
25
  RequestBuilderParameter,
27
- get_request_body_parameter,
28
26
  )
29
27
  from .parameter import (
30
- MultipartBodyParameter,
31
28
  ParameterLocation,
32
29
  BodyParameter,
33
30
  Parameter,
34
31
  ParameterMethodLocation,
35
32
  ClientParameter,
36
33
  ConfigParameter,
37
- get_body_parameter,
38
34
  )
39
35
 
40
36
  ParameterType = TypeVar(
@@ -43,10 +39,6 @@ ParameterType = TypeVar(
43
39
  BodyParameterType = TypeVar(
44
40
  "BodyParameterType", bound=Union[BodyParameter, RequestBuilderBodyParameter]
45
41
  )
46
- RequestBuilderBodyParameterType = Union[
47
- RequestBuilderBodyParameter, RequestBuilderMultipartBodyParameter
48
- ]
49
-
50
42
 
51
43
  if TYPE_CHECKING:
52
44
  from .code_model import CodeModel
@@ -314,7 +306,7 @@ class _ParameterListBase(
314
306
 
315
307
  class _ParameterList(
316
308
  _ParameterListBase[ # pylint: disable=unsubscriptable-object
317
- Parameter, Union[MultipartBodyParameter, BodyParameter]
309
+ Parameter, BodyParameter
318
310
  ]
319
311
  ):
320
312
  """Base Parameter class for the two operation ParameterLists"""
@@ -325,11 +317,9 @@ class _ParameterList(
325
317
 
326
318
  @staticmethod
327
319
  def body_parameter_creator() -> (
328
- Callable[
329
- [Dict[str, Any], "CodeModel"], Union[MultipartBodyParameter, BodyParameter]
330
- ]
320
+ Callable[[Dict[str, Any], "CodeModel"], BodyParameter]
331
321
  ):
332
- return get_body_parameter
322
+ return BodyParameter.from_yaml
333
323
 
334
324
  @property
335
325
  def implementation(self) -> str:
@@ -348,7 +338,7 @@ class ParameterList(_ParameterList):
348
338
 
349
339
  class _RequestBuilderParameterList(
350
340
  _ParameterListBase[ # pylint: disable=unsubscriptable-object
351
- RequestBuilderParameter, RequestBuilderBodyParameterType
341
+ RequestBuilderParameter, RequestBuilderBodyParameter
352
342
  ]
353
343
  ):
354
344
  """_RequestBuilderParameterList is base parameter list for RequestBuilder classes"""
@@ -361,9 +351,9 @@ class _RequestBuilderParameterList(
361
351
 
362
352
  @staticmethod
363
353
  def body_parameter_creator() -> (
364
- Callable[[Dict[str, Any], "CodeModel"], RequestBuilderBodyParameterType]
354
+ Callable[[Dict[str, Any], "CodeModel"], RequestBuilderBodyParameter]
365
355
  ):
366
- return get_request_body_parameter
356
+ return RequestBuilderBodyParameter.from_yaml
367
357
 
368
358
  @property
369
359
  def implementation(self) -> str:
@@ -372,14 +362,14 @@ class _RequestBuilderParameterList(
372
362
  @property
373
363
  def unsorted_method_params(
374
364
  self,
375
- ) -> List[Union[RequestBuilderParameter, RequestBuilderBodyParameterType]]:
365
+ ) -> List[Union[RequestBuilderParameter, RequestBuilderBodyParameter]]:
376
366
  # don't have access to client params in request builder
377
367
  retval = [
378
368
  p
379
369
  for p in super().unsorted_method_params
380
370
  if not (
381
371
  p.location == ParameterLocation.BODY
382
- and cast(RequestBuilderBodyParameterType, p).is_partial_body
372
+ and cast(RequestBuilderBodyParameter, p).is_partial_body
383
373
  )
384
374
  ]
385
375
  retval.extend(
@@ -400,7 +390,7 @@ class _RequestBuilderParameterList(
400
390
  @property
401
391
  def constant(
402
392
  self,
403
- ) -> List[Union[RequestBuilderParameter, RequestBuilderBodyParameterType]]:
393
+ ) -> List[Union[RequestBuilderParameter, RequestBuilderBodyParameter]]:
404
394
  """All constant parameters"""
405
395
  return [
406
396
  p for p in super().constant if p.location != ParameterLocation.ENDPOINT_PATH
@@ -3,7 +3,7 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
- from typing import Any, Dict, Optional, TYPE_CHECKING, List
6
+ from typing import Any, Dict, Optional, TYPE_CHECKING, List, cast
7
7
 
8
8
  from .base import BaseModel
9
9
  from .constant_type import ConstantType
@@ -37,6 +37,9 @@ class Property(BaseModel): # pylint: disable=too-many-instance-attributes
37
37
  if self.client_default_value is None:
38
38
  self.client_default_value = self.type.client_default_value
39
39
  self.flattened_names: List[str] = yaml_data.get("flattenedNames", [])
40
+ self.is_multipart_file_input: bool = yaml_data.get(
41
+ "isMultipartFileInput", False
42
+ )
40
43
 
41
44
  @property
42
45
  def pylint_disable(self) -> str:
@@ -93,18 +96,31 @@ class Property(BaseModel): # pylint: disable=too-many-instance-attributes
93
96
  def is_enum_discriminator(self) -> bool:
94
97
  return self.is_discriminator and self.type.type == "enum"
95
98
 
96
- def type_annotation(self, *, is_operation_file: bool = False) -> str:
99
+ @property
100
+ def is_base_discriminator(self) -> bool:
101
+ """If this discriminator is on the base model for polymorphic inheritance"""
97
102
  if self.is_enum_discriminator:
98
- # here we are the enum discriminator property on the base model
99
- return "Literal[None]"
103
+ return self.is_polymorphic and self.client_default_value is None
104
+ return (
105
+ self.is_discriminator
106
+ and self.is_polymorphic
107
+ and cast(ConstantType, self.type).value is None
108
+ )
109
+
110
+ def type_annotation(self, *, is_operation_file: bool = False) -> str:
111
+ types_type_annotation = self.type.type_annotation(
112
+ is_operation_file=is_operation_file
113
+ )
114
+ if self.is_multipart_file_input:
115
+ # we only support FileType or list of FileType
116
+ types_type_annotation = types_type_annotation.replace("bytes", "FileType")
117
+ if self.is_base_discriminator:
118
+ return "str"
100
119
  if self.optional and self.client_default_value is None:
101
- return f"Optional[{self.type.type_annotation(is_operation_file=is_operation_file)}]"
102
- return self.type.type_annotation(is_operation_file=is_operation_file)
120
+ return f"Optional[{types_type_annotation}]"
121
+ return types_type_annotation
103
122
 
104
123
  def get_declaration(self, value: Any = None) -> Any:
105
- if self.is_enum_discriminator:
106
- # here we are the enum discriminator property on the base model
107
- return None
108
124
  return self.type.get_declaration(value)
109
125
 
110
126
  def get_json_template_representation(
@@ -114,6 +130,8 @@ class Property(BaseModel): # pylint: disable=too-many-instance-attributes
114
130
  client_default_value_declaration: Optional[str] = None,
115
131
  description: Optional[str] = None,
116
132
  ) -> Any:
133
+ if self.is_multipart_file_input:
134
+ return "[filetype]" if self.type.type == "list" else "filetype"
117
135
  if self.client_default_value:
118
136
  client_default_value_declaration = self.get_declaration(
119
137
  self.client_default_value
@@ -161,6 +179,8 @@ class Property(BaseModel): # pylint: disable=too-many-instance-attributes
161
179
  "rest_discriminator" if self.is_discriminator else "rest_field",
162
180
  ImportType.LOCAL,
163
181
  )
182
+ if self.is_multipart_file_input:
183
+ file_import.add_submodule_import(".._vendor", "FileType", ImportType.LOCAL)
164
184
  return file_import
165
185
 
166
186
  @classmethod
@@ -33,7 +33,7 @@ ParameterListType = TypeVar(
33
33
  )
34
34
 
35
35
 
36
- class RequestBuilderBase(BaseBuilder[ParameterListType]):
36
+ class RequestBuilderBase(BaseBuilder[ParameterListType, List["RequestBuilder"]]):
37
37
  def __init__(
38
38
  self,
39
39
  yaml_data: Dict[str, Any],
@@ -3,13 +3,12 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
- from typing import TYPE_CHECKING, Any, Dict, Union
6
+ from typing import TYPE_CHECKING, Any, Dict
7
7
  from .parameter import (
8
8
  ParameterLocation,
9
9
  ParameterMethodLocation,
10
10
  Parameter,
11
11
  BodyParameter,
12
- _MultipartBodyParameter,
13
12
  )
14
13
  from .base import BaseType
15
14
  from .primitive_types import BinaryType, StringType
@@ -67,32 +66,6 @@ class RequestBuilderBodyParameter(BodyParameter):
67
66
  return "_content"
68
67
 
69
68
 
70
- class RequestBuilderMultipartBodyParameter(
71
- _MultipartBodyParameter[ # pylint: disable=unsubscriptable-object
72
- RequestBuilderBodyParameter
73
- ]
74
- ):
75
- """Multipart body parameter for Request BUilders"""
76
-
77
- @property
78
- def name_in_high_level_operation(self) -> str:
79
- return f"_{self.client_name}"
80
-
81
- @classmethod
82
- def from_yaml(
83
- cls, yaml_data: Dict[str, Any], code_model: "CodeModel"
84
- ) -> "RequestBuilderMultipartBodyParameter":
85
- return cls(
86
- yaml_data=yaml_data,
87
- code_model=code_model,
88
- type=code_model.lookup_type(id(yaml_data["type"])),
89
- entries=[
90
- RequestBuilderBodyParameter.from_yaml(entry, code_model)
91
- for entry in yaml_data["entries"]
92
- ],
93
- )
94
-
95
-
96
69
  class RequestBuilderParameter(Parameter):
97
70
  """Basic RequestBuilder Parameter."""
98
71
 
@@ -149,12 +122,3 @@ class RequestBuilderParameter(Parameter):
149
122
  if self.implementation == "Client":
150
123
  return f"self._config.{self.client_name}"
151
124
  return self.client_name
152
-
153
-
154
- def get_request_body_parameter(
155
- yaml_data: Dict[str, Any], code_model: "CodeModel"
156
- ) -> Union[RequestBuilderBodyParameter, RequestBuilderMultipartBodyParameter]:
157
- """Get body parameter for a request builder"""
158
- if yaml_data.get("entries"):
159
- return RequestBuilderMultipartBodyParameter.from_yaml(yaml_data, code_model)
160
- return RequestBuilderBodyParameter.from_yaml(yaml_data, code_model)
@@ -26,7 +26,6 @@ from ..models import (
26
26
  ParameterMethodLocation,
27
27
  RequestBuilderBodyParameter,
28
28
  OverloadedRequestBuilder,
29
- MultipartBodyParameter,
30
29
  Property,
31
30
  RequestBuilderType,
32
31
  CombinedType,
@@ -207,7 +206,7 @@ def _serialize_json_model_body(
207
206
 
208
207
  def _serialize_multipart_body(builder: BuilderType) -> List[str]:
209
208
  retval: List[str] = []
210
- body_param = cast(MultipartBodyParameter, builder.parameters.body_parameter)
209
+ body_param = builder.parameters.body_parameter
211
210
  # we have to construct our form data before passing to the request as well
212
211
  retval.append("# Construct form data")
213
212
  retval.append(f"_{body_param.client_name} = {{")
@@ -349,7 +348,11 @@ class _BuilderBaseSerializer(Generic[BuilderType]): # pylint: disable=abstract-
349
348
  def param_description(self, builder: BuilderType) -> List[str]:
350
349
  description_list: List[str] = []
351
350
  for param in builder.parameters.method:
352
- if not param.in_docstring or param.hide_in_operation_signature:
351
+ if (
352
+ not param.in_docstring
353
+ or param.hide_in_operation_signature
354
+ or param.method_location == ParameterMethodLocation.KWARG
355
+ ):
353
356
  continue
354
357
  description_list.extend(
355
358
  f":{param.description_keyword} {param.client_name}: {param.description}".replace(
@@ -748,17 +751,39 @@ class _OperationSerializer(
748
751
 
749
752
  This function serializes the body params that need to be serialized.
750
753
  """
751
- body_param = cast(BodyParameter, builder.parameters.body_parameter)
754
+ retval: List[str] = []
755
+ body_param = builder.parameters.body_parameter
752
756
  if body_param.is_form_data:
753
- return [
754
- f"if isinstance({body_param.client_name}, _model_base.Model):",
755
- f" _body = handle_multipart_form_data_model({body_param.client_name})",
756
- "else:",
757
- f" _body = {body_param.client_name}",
758
- "_files = {k: multipart_file(v) for k, v in _body.items() if isinstance(v, (IOBase, bytes))}",
759
- "_data = {k: multipart_data(v) for k, v in _body.items() if not isinstance(v, (IOBase, bytes))}",
757
+ model_type = cast(
758
+ ModelType,
759
+ (
760
+ body_param.type.target_model_subtype((JSONModelType, DPGModelType))
761
+ if isinstance(body_param.type, CombinedType)
762
+ else body_param.type
763
+ ),
764
+ )
765
+ file_fields = [
766
+ p.wire_name for p in model_type.properties if p.is_multipart_file_input
760
767
  ]
761
- retval: List[str] = []
768
+ data_fields = [
769
+ p.wire_name
770
+ for p in model_type.properties
771
+ if not p.is_multipart_file_input
772
+ ]
773
+ retval.extend(
774
+ [
775
+ "_body = (",
776
+ f" {body_param.client_name}.as_dict()",
777
+ f" if isinstance({body_param.client_name}, _model_base.Model) else",
778
+ f" {body_param.client_name}",
779
+ ")",
780
+ f"_file_fields: List[str] = {file_fields}",
781
+ f"_data_fields: List[str] = {data_fields}",
782
+ "_files, _data = prepare_multipart_form_data(_body, _file_fields, _data_fields)",
783
+ ]
784
+ )
785
+ return retval
786
+
762
787
  body_kwarg_name = builder.request_builder.parameters.body_parameter.client_name
763
788
  send_xml = builder.parameters.body_parameter.type.is_xml
764
789
  xml_serialization_ctxt = (
@@ -810,8 +835,8 @@ class _OperationSerializer(
810
835
  ) -> List[str]:
811
836
  """Create the body parameter before we pass it as either json or content to the request builder"""
812
837
  retval = []
813
- body_param = cast(BodyParameter, builder.parameters.body_parameter)
814
- if hasattr(body_param, "entries"):
838
+ body_param = builder.parameters.body_parameter
839
+ if body_param.entries:
815
840
  return _serialize_multipart_body(builder)
816
841
  body_kwarg_name = builder.request_builder.parameters.body_parameter.client_name
817
842
  body_param_type = body_param.type
@@ -972,9 +997,13 @@ class _OperationSerializer(
972
997
  f" {parameter.client_name}={parameter.name_in_high_level_operation},"
973
998
  f"{' # type: ignore' if type_ignore else ''}"
974
999
  )
975
- if request_builder.has_form_data_body:
976
- retval.append(" data=_data,")
1000
+ if builder.parameters.has_body and builder.parameters.body_parameter.entries:
1001
+ # this is for legacy
1002
+ client_name = builder.parameters.body_parameter.client_name
1003
+ retval.append(f" {client_name}=_{client_name},")
1004
+ elif request_builder.has_form_data_body:
977
1005
  retval.append(" files=_files,")
1006
+ retval.append(" data=_data,")
978
1007
  elif request_builder.overloads:
979
1008
  seen_body_params = set()
980
1009
  for overload in request_builder.overloads:
@@ -130,18 +130,30 @@ class GeneralSerializer(BaseSerializer):
130
130
  "MatchConditions",
131
131
  ImportType.SDKCORE,
132
132
  )
133
- if self.code_model.has_form_data:
133
+ if (
134
+ self.code_model.has_form_data
135
+ and self.code_model.options["models_mode"] == "dpg"
136
+ ):
137
+ file_import.add_submodule_import("typing", "IO", ImportType.STDLIB)
138
+ file_import.add_submodule_import("typing", "Tuple", ImportType.STDLIB)
134
139
  file_import.add_submodule_import("typing", "Union", ImportType.STDLIB)
140
+ file_import.add_submodule_import("typing", "Optional", ImportType.STDLIB)
141
+ file_import.add_submodule_import("typing", "Mapping", ImportType.STDLIB)
142
+ file_import.add_submodule_import("typing", "Sequence", ImportType.STDLIB)
143
+ file_import.add_submodule_import("typing", "Dict", ImportType.STDLIB)
135
144
  file_import.add_submodule_import("typing", "Any", ImportType.STDLIB)
136
- file_import.add_submodule_import("io", "IOBase", ImportType.STDLIB)
137
- file_import.add_submodule_import("io", "BytesIO", ImportType.STDLIB)
138
- file_import.add_import("uuid", ImportType.STDLIB)
139
- file_import.add_import("json", ImportType.STDLIB)
140
- file_import.add_mutable_mapping_import()
141
- file_import.add_submodule_import("._model_base", "Model", ImportType.LOCAL)
145
+ file_import.add_submodule_import("typing", "List", ImportType.STDLIB)
142
146
  file_import.add_submodule_import(
143
- "._model_base", "SdkJSONEncoder", ImportType.LOCAL
147
+ "._model_base",
148
+ "SdkJSONEncoder",
149
+ ImportType.LOCAL,
144
150
  )
151
+ file_import.add_submodule_import(
152
+ "._model_base",
153
+ "Model",
154
+ ImportType.LOCAL,
155
+ )
156
+ file_import.add_import("json", ImportType.STDLIB)
145
157
 
146
158
  return template.render(
147
159
  code_model=self.code_model,
@@ -197,6 +197,9 @@ class MsrestModelSerializer(_ModelSerializer):
197
197
 
198
198
 
199
199
  class DpgModelSerializer(_ModelSerializer):
200
+ def super_call(self, model: ModelType):
201
+ return f"super().__init__({self.properties_to_pass_to_super(model)})"
202
+
200
203
  def imports(self) -> FileImport:
201
204
  file_import = FileImport(self.code_model)
202
205
  file_import.add_submodule_import(
@@ -238,6 +241,7 @@ class DpgModelSerializer(_ModelSerializer):
238
241
  if not any(
239
242
  p.client_name == pp.client_name
240
243
  and p.type_annotation() == pp.type_annotation()
244
+ and not p.is_base_discriminator
241
245
  for pp in parent_properties
242
246
  )
243
247
  ]
@@ -257,7 +261,10 @@ class DpgModelSerializer(_ModelSerializer):
257
261
  args.append(f"visibility=[{v_list}]")
258
262
  if prop.client_default_value is not None:
259
263
  args.append(f"default={prop.client_default_value_declaration}")
260
- if hasattr(prop.type, "encode") and prop.type.encode: # type: ignore
264
+
265
+ if prop.is_multipart_file_input:
266
+ args.append("is_multipart_file_input=True")
267
+ elif hasattr(prop.type, "encode") and prop.type.encode: # type: ignore
261
268
  args.append(f'format="{prop.type.encode}"') # type: ignore
262
269
 
263
270
  field = "rest_discriminator" if prop.is_discriminator else "rest_field"
@@ -274,7 +281,7 @@ class DpgModelSerializer(_ModelSerializer):
274
281
  def initialize_properties(self, model: ModelType) -> List[str]:
275
282
  init_args = []
276
283
  for prop in self.get_properties_to_declare(model):
277
- if prop.constant or prop.is_discriminator:
284
+ if prop.constant and not prop.is_base_discriminator:
278
285
  init_args.append(
279
286
  f"self.{prop.client_name}: {prop.type_annotation()} = "
280
287
  f"{prop.get_declaration()}"
@@ -286,5 +293,30 @@ class DpgModelSerializer(_ModelSerializer):
286
293
  return [
287
294
  p
288
295
  for p in model.properties
289
- if not p.is_discriminator and not p.constant and p.visibility != ["read"]
296
+ if p.is_base_discriminator
297
+ or not p.is_discriminator
298
+ and not p.constant
299
+ and p.visibility != ["read"]
290
300
  ]
301
+
302
+ @staticmethod
303
+ def properties_to_pass_to_super(model: ModelType) -> str:
304
+ properties_to_pass_to_super = ["*args"]
305
+ for parent in model.parents:
306
+ for prop in model.properties:
307
+ if (
308
+ prop.client_name
309
+ in [
310
+ prop.client_name
311
+ for prop in parent.properties
312
+ if prop.is_base_discriminator
313
+ ]
314
+ and prop.is_discriminator
315
+ and not prop.constant
316
+ and not prop.readonly
317
+ ):
318
+ properties_to_pass_to_super.append(
319
+ f"{prop.client_name}={prop.get_declaration()}"
320
+ )
321
+ properties_to_pass_to_super.append("**kwargs")
322
+ return ", ".join(properties_to_pass_to_super)
@@ -5,7 +5,6 @@
5
5
  # license information.
6
6
  # --------------------------------------------------------------------------
7
7
  # pylint: disable=protected-access, arguments-differ, signature-differs, broad-except
8
- # pyright: reportGeneralTypeIssues=false
9
8
 
10
9
  import calendar
11
10
  import decimal
@@ -16,9 +15,10 @@ import base64
16
15
  import re
17
16
  import copy
18
17
  import typing
19
- import email
18
+ import email.utils
20
19
  from datetime import datetime, date, time, timedelta, timezone
21
20
  from json import JSONEncoder
21
+ from typing_extensions import Self
22
22
  import isodate
23
23
  from {{ code_model.core_library }}.exceptions import DeserializationError
24
24
  from {{ code_model.core_library }}{{ ".utils" if code_model.options["unbranded"] else "" }} import CaseInsensitiveEnumMeta
@@ -35,6 +35,7 @@ _LOGGER = logging.getLogger(__name__)
35
35
  __all__ = ["SdkJSONEncoder", "Model", "rest_field", "rest_discriminator"]
36
36
 
37
37
  TZ_UTC = timezone.utc
38
+ _T = typing.TypeVar("_T")
38
39
 
39
40
 
40
41
  def _timedelta_as_isostr(td: timedelta) -> str:
@@ -242,7 +243,7 @@ def _deserialize_date(attr: typing.Union[str, date]) -> date:
242
243
  # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception.
243
244
  if isinstance(attr, date):
244
245
  return attr
245
- return isodate.parse_date(attr, defaultmonth=None, defaultday=None)
246
+ return isodate.parse_date(attr, defaultmonth=None, defaultday=None) # type: ignore
246
247
 
247
248
 
248
249
  def _deserialize_time(attr: typing.Union[str, time]) -> time:
@@ -383,8 +384,12 @@ class _MyMutableMapping(MutableMapping[str, typing.Any]): # pylint: disable=uns
383
384
  except KeyError:
384
385
  return default
385
386
 
386
- @typing.overload # type: ignore
387
- def pop(self, key: str) -> typing.Any: # pylint: disable=no-member
387
+ @typing.overload
388
+ def pop(self, key: str) -> typing.Any:
389
+ ...
390
+
391
+ @typing.overload
392
+ def pop(self, key: str, default: _T) -> _T:
388
393
  ...
389
394
 
390
395
  @typing.overload
@@ -405,8 +410,8 @@ class _MyMutableMapping(MutableMapping[str, typing.Any]): # pylint: disable=uns
405
410
  def update(self, *args: typing.Any, **kwargs: typing.Any) -> None:
406
411
  self._data.update(*args, **kwargs)
407
412
 
408
- @typing.overload # type: ignore
409
- def setdefault(self, key: str) -> typing.Any:
413
+ @typing.overload
414
+ def setdefault(self, key: str, default: None = None) -> None:
410
415
  ...
411
416
 
412
417
  @typing.overload
@@ -470,7 +475,13 @@ def _get_rest_field(
470
475
 
471
476
 
472
477
  def _create_value(rf: typing.Optional["_RestField"], value: typing.Any) -> typing.Any:
473
- return _deserialize(rf._type, value) if (rf and rf._is_model) else _serialize(value, rf._format if rf else None)
478
+ if not rf:
479
+ return _serialize(value, None)
480
+ if rf._is_multipart_file_input:
481
+ return value
482
+ if rf._is_model:
483
+ return _deserialize(rf._type, value)
484
+ return _serialize(value, rf._format)
474
485
 
475
486
 
476
487
  class Model(_MyMutableMapping):
@@ -506,7 +517,7 @@ class Model(_MyMutableMapping):
506
517
  def copy(self) -> "Model":
507
518
  return Model(self.__dict__)
508
519
 
509
- def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> "Model": # pylint: disable=unused-argument
520
+ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self: # pylint: disable=unused-argument
510
521
  # we know the last three classes in mro are going to be 'Model', 'dict', and 'object'
511
522
  mros = cls.__mro__[:-3][::-1] # ignore model, dict, and object parents, and reverse the mro order
512
523
  attr_to_rest_field: typing.Dict[str, _RestField] = { # map attribute name to rest_field property
@@ -548,7 +559,7 @@ class Model(_MyMutableMapping):
548
559
  exist_discriminators.append(discriminator)
549
560
  mapped_cls = cls.__mapping__.get(
550
561
  data.get(discriminator), cls
551
- ) # pylint: disable=no-member
562
+ ) # pyright: ignore # pylint: disable=no-member
552
563
  if mapped_cls == cls:
553
564
  return cls(data)
554
565
  return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access
@@ -565,9 +576,14 @@ class Model(_MyMutableMapping):
565
576
  if exclude_readonly:
566
577
  readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)]
567
578
  for k, v in self.items():
568
- if exclude_readonly and k in readonly_props: # pyright: ignore[reportUnboundVariable]
579
+ if exclude_readonly and k in readonly_props: # pyright: ignore
569
580
  continue
570
- result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly)
581
+ is_multipart_file_input = False
582
+ try:
583
+ is_multipart_file_input = next(rf for rf in self._attr_to_rest_field.values() if rf._rest_name == k)._is_multipart_file_input
584
+ except StopIteration:
585
+ pass
586
+ result[k] = v if is_multipart_file_input else Model._as_dict_value(v, exclude_readonly=exclude_readonly)
571
587
  return result
572
588
 
573
589
  @staticmethod
@@ -575,10 +591,10 @@ class Model(_MyMutableMapping):
575
591
  if v is None or isinstance(v, _Null):
576
592
  return None
577
593
  if isinstance(v, (list, tuple, set)):
578
- return [
594
+ return type(v)(
579
595
  Model._as_dict_value(x, exclude_readonly=exclude_readonly)
580
596
  for x in v
581
- ]
597
+ )
582
598
  if isinstance(v, dict):
583
599
  return {
584
600
  dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly)
@@ -619,22 +635,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
619
635
  return obj
620
636
  return _deserialize(model_deserializer, obj)
621
637
 
622
- return functools.partial(_deserialize_model, annotation)
638
+ return functools.partial(_deserialize_model, annotation) # pyright: ignore
623
639
  except Exception:
624
640
  pass
625
641
 
626
642
  # is it a literal?
627
643
  try:
628
- if annotation.__origin__ is typing.Literal:
644
+ if annotation.__origin__ is typing.Literal: # pyright: ignore
629
645
  return None
630
646
  except AttributeError:
631
647
  pass
632
648
 
633
649
  # is it optional?
634
650
  try:
635
- if any(a for a in annotation.__args__ if a == type(None)):
651
+ if any(a for a in annotation.__args__ if a == type(None)): # pyright: ignore
636
652
  if_obj_deserializer = _get_deserialize_callable_from_annotation(
637
- next(a for a in annotation.__args__ if a != type(None)), module, rf
653
+ next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore
638
654
  )
639
655
 
640
656
  def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Callable], obj):
@@ -647,7 +663,9 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
647
663
  pass
648
664
 
649
665
  if getattr(annotation, "__origin__", None) is typing.Union:
650
- deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__]
666
+ deserializers = [
667
+ _get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__ # pyright: ignore
668
+ ]
651
669
 
652
670
  def _deserialize_with_union(deserializers, obj):
653
671
  for deserializer in deserializers:
@@ -660,8 +678,10 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
660
678
  return functools.partial(_deserialize_with_union, deserializers)
661
679
 
662
680
  try:
663
- if annotation._name == "Dict":
664
- value_deserializer = _get_deserialize_callable_from_annotation(annotation.__args__[1], module, rf)
681
+ if annotation._name == "Dict": # pyright: ignore
682
+ value_deserializer = _get_deserialize_callable_from_annotation(
683
+ annotation.__args__[1], module, rf # pyright: ignore
684
+ )
665
685
 
666
686
  def _deserialize_dict(
667
687
  value_deserializer: typing.Optional[typing.Callable],
@@ -681,8 +701,8 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
681
701
  except (AttributeError, IndexError):
682
702
  pass
683
703
  try:
684
- if annotation._name in ["List", "Set", "Tuple", "Sequence"]:
685
- if len(annotation.__args__) > 1:
704
+ if annotation._name in ["List", "Set", "Tuple", "Sequence"]: # pyright: ignore
705
+ if len(annotation.__args__) > 1: # pyright: ignore
686
706
 
687
707
  def _deserialize_multiple_sequence(
688
708
  entry_deserializers: typing.List[typing.Optional[typing.Callable]],
@@ -696,10 +716,12 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
696
716
  )
697
717
 
698
718
  entry_deserializers = [
699
- _get_deserialize_callable_from_annotation(dt, module, rf) for dt in annotation.__args__
719
+ _get_deserialize_callable_from_annotation(dt, module, rf) for dt in annotation.__args__ # pyright: ignore
700
720
  ]
701
721
  return functools.partial(_deserialize_multiple_sequence, entry_deserializers)
702
- deserializer = _get_deserialize_callable_from_annotation(annotation.__args__[0], module, rf)
722
+ deserializer = _get_deserialize_callable_from_annotation(
723
+ annotation.__args__[0], module, rf # pyright: ignore
724
+ )
703
725
 
704
726
  def _deserialize_sequence(
705
727
  deserializer: typing.Optional[typing.Callable],
@@ -779,6 +801,7 @@ class _RestField:
779
801
  visibility: typing.Optional[typing.List[str]] = None,
780
802
  default: typing.Any = _UNSET,
781
803
  format: typing.Optional[str] = None,
804
+ is_multipart_file_input: bool = False,
782
805
  ):
783
806
  self._type = type
784
807
  self._rest_name_input = name
@@ -788,6 +811,7 @@ class _RestField:
788
811
  self._is_model = False
789
812
  self._default = default
790
813
  self._format = format
814
+ self._is_multipart_file_input = is_multipart_file_input
791
815
 
792
816
  @property
793
817
  def _rest_name(self) -> str:
@@ -833,8 +857,9 @@ def rest_field(
833
857
  visibility: typing.Optional[typing.List[str]] = None,
834
858
  default: typing.Any = _UNSET,
835
859
  format: typing.Optional[str] = None,
860
+ is_multipart_file_input: bool = False,
836
861
  ) -> typing.Any:
837
- return _RestField(name=name, type=type, visibility=visibility, default=default, format=format)
862
+ return _RestField(name=name, type=type, visibility=visibility, default=default, format=format, is_multipart_file_input=is_multipart_file_input)
838
863
 
839
864
 
840
865
  def rest_discriminator(
@@ -60,7 +60,7 @@
60
60
  {% set initialize_properties = serializer.initialize_properties(model) %}
61
61
  {% if not model.internal and serializer.init_line(model) or initialize_properties %}
62
62
  def __init__(self, *args: Any, **kwargs: Any) -> None:{{ '# pylint: disable=useless-super-delegation' if not initialize_properties else '' }}
63
- super().__init__(*args, **kwargs)
63
+ {{ serializer.super_call(model) }}
64
64
  {% for initialize_property in initialize_properties %}
65
65
  {{ initialize_property }}
66
66
  {% endfor %}
@@ -94,7 +94,10 @@ setup(
94
94
  {% elif azure_arm %}
95
95
  "azure-mgmt-core<2.0.0,>=1.3.2",
96
96
  {% else %}
97
- "azure-core<2.0.0,>=1.29.5",
97
+ "azure-core<2.0.0,>=1.30.0",
98
+ {% endif %}
99
+ {% if code_model.need_typing_extensions %}
100
+ "typing-extensions>=4.6.0",
98
101
  {% endif %}
99
102
  ],
100
103
  {% if package_mode %}
@@ -170,13 +170,6 @@ class RawDeserializer:
170
170
  return None
171
171
 
172
172
 
173
- try:
174
- basestring # type: ignore
175
- unicode_str = unicode # type: ignore
176
- except NameError:
177
- basestring = str
178
- unicode_str = str
179
-
180
173
  _LOGGER = logging.getLogger(__name__)
181
174
 
182
175
  try:
@@ -547,7 +540,7 @@ class Serializer(object):
547
540
  "multiple": lambda x, y: x % y != 0,
548
541
  }
549
542
 
550
- def __init__(self, classes: Optional[Mapping[str, Type[ModelType]]]=None):
543
+ def __init__(self, classes: Optional[Mapping[str, type]]=None):
551
544
  self.serialize_type = {
552
545
  "iso-8601": Serializer.serialize_iso,
553
546
  "rfc-1123": Serializer.serialize_rfc,
@@ -563,7 +556,7 @@ class Serializer(object):
563
556
  "[]": self.serialize_iter,
564
557
  "{}": self.serialize_dict,
565
558
  }
566
- self.dependencies: Dict[str, Type[ModelType]] = dict(classes) if classes else {}
559
+ self.dependencies: Dict[str, type] = dict(classes) if classes else {}
567
560
  self.key_transformer = full_restapi_key_transformer
568
561
  self.client_side_validation = True
569
562
 
@@ -651,7 +644,7 @@ class Serializer(object):
651
644
  else: # That's a basic type
652
645
  # Integrate namespace if necessary
653
646
  local_node = _create_xml_node(xml_name, xml_prefix, xml_ns)
654
- local_node.text = unicode_str(new_attr)
647
+ local_node.text = str(new_attr)
655
648
  serialized.append(local_node) # type: ignore
656
649
  else: # JSON
657
650
  for k in reversed(keys): # type: ignore
@@ -1000,7 +993,7 @@ class Serializer(object):
1000
993
  return self.serialize_basic(attr, self.basic_types[obj_type], **kwargs)
1001
994
  if obj_type is _long_type:
1002
995
  return self.serialize_long(attr)
1003
- if obj_type is unicode_str:
996
+ if obj_type is str:
1004
997
  return self.serialize_unicode(attr)
1005
998
  if obj_type is datetime.datetime:
1006
999
  return self.serialize_iso(attr)
@@ -1376,7 +1369,7 @@ class Deserializer(object):
1376
1369
 
1377
1370
  valid_date = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}" r"\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?")
1378
1371
 
1379
- def __init__(self, classes: Optional[Mapping[str, Type[ModelType]]]=None):
1372
+ def __init__(self, classes: Optional[Mapping[str, type]]=None):
1380
1373
  self.deserialize_type = {
1381
1374
  "iso-8601": Deserializer.deserialize_iso,
1382
1375
  "rfc-1123": Deserializer.deserialize_rfc,
@@ -1396,7 +1389,7 @@ class Deserializer(object):
1396
1389
  "duration": (isodate.Duration, datetime.timedelta),
1397
1390
  "iso-8601": (datetime.datetime),
1398
1391
  }
1399
- self.dependencies: Dict[str, Type[ModelType]] = dict(classes) if classes else {}
1392
+ self.dependencies: Dict[str, type] = dict(classes) if classes else {}
1400
1393
  self.key_extractors = [rest_key_extractor, xml_key_extractor]
1401
1394
  # Additional properties only works if the "rest_key_extractor" is used to
1402
1395
  # extract the keys. Making it to work whatever the key extractor is too much
@@ -1449,7 +1442,7 @@ class Deserializer(object):
1449
1442
 
1450
1443
  response, class_name = self._classify_target(target_obj, data)
1451
1444
 
1452
- if isinstance(response, basestring):
1445
+ if isinstance(response, str):
1453
1446
  return self.deserialize_data(data, response)
1454
1447
  elif isinstance(response, type) and issubclass(response, Enum):
1455
1448
  return self.deserialize_enum(data, response)
@@ -1520,7 +1513,7 @@ class Deserializer(object):
1520
1513
  if target is None:
1521
1514
  return None, None
1522
1515
 
1523
- if isinstance(target, basestring):
1516
+ if isinstance(target, str):
1524
1517
  try:
1525
1518
  target = self.dependencies[target]
1526
1519
  except KeyError:
@@ -1583,7 +1576,7 @@ class Deserializer(object):
1583
1576
  if hasattr(raw_data, "_content_consumed"):
1584
1577
  return RawDeserializer.deserialize_from_http_generics(raw_data.text, raw_data.headers)
1585
1578
 
1586
- if isinstance(raw_data, (basestring, bytes)) or hasattr(raw_data, "read"):
1579
+ if isinstance(raw_data, (str, bytes)) or hasattr(raw_data, "read"):
1587
1580
  return RawDeserializer.deserialize_from_text(raw_data, content_type) # type: ignore
1588
1581
  return raw_data
1589
1582
 
@@ -1705,7 +1698,7 @@ class Deserializer(object):
1705
1698
  if isinstance(attr, ET.Element):
1706
1699
  # Do no recurse on XML, just return the tree as-is
1707
1700
  return attr
1708
- if isinstance(attr, basestring):
1701
+ if isinstance(attr, str):
1709
1702
  return self.deserialize_basic(attr, "str")
1710
1703
  obj_type = type(attr)
1711
1704
  if obj_type in self.basic_types:
@@ -1762,7 +1755,7 @@ class Deserializer(object):
1762
1755
  if data_type == "bool":
1763
1756
  if attr in [True, False, 1, 0]:
1764
1757
  return bool(attr)
1765
- elif isinstance(attr, basestring):
1758
+ elif isinstance(attr, str):
1766
1759
  if attr.lower() in ["true", "1"]:
1767
1760
  return True
1768
1761
  elif attr.lower() in ["false", "0"]:
@@ -65,38 +65,42 @@ def prep_if_none_match(etag: Optional[str], match_condition: Optional[MatchCondi
65
65
  return "*"
66
66
  return None
67
67
  {% endif %}
68
- {% if code_model.has_form_data %}
69
- class NamedBytesIO(BytesIO):
70
- def __init__(self, name: str, *args, **kwargs):
71
- super().__init__(*args, **kwargs)
72
- self.name = name
68
+ {% if code_model.has_form_data and code_model.options["models_mode"] == "dpg" %}
69
+ # file-like tuple could be `(filename, IO (or bytes))` or `(filename, IO (or bytes), content_type)`
70
+ FileContent = Union[str, bytes, IO[str], IO[bytes]]
73
71
 
74
- def multipart_file(file: Union[IOBase, bytes]) -> IOBase:
75
- if isinstance(file, IOBase):
76
- return file
77
- return NamedBytesIO("auto-name-" + str(uuid.uuid4()), file)
72
+ FileType = Union[
73
+ # file (or bytes)
74
+ FileContent,
75
+ # (filename, file (or bytes))
76
+ Tuple[Optional[str], FileContent],
77
+ # (filename, file (or bytes), content_type)
78
+ Tuple[Optional[str], FileContent, Optional[str]],
79
+ ]
78
80
 
79
- def multipart_data(data: Any) -> Any:
80
- if isinstance(data, (list, tuple, dict, Model)):
81
- return json.dumps(data, cls=SdkJSONEncoder, exclude_readonly=True)
82
- return data
81
+ FilesType = Union[Mapping[str, FileType], Sequence[Tuple[str, FileType]]]
83
82
 
84
- def handle_multipart_form_data_model(body: Model) -> MutableMapping[str, Any]: # pylint: disable=unsubscriptable-object
85
- """handle first layer of model.
86
- If its value is bytes or IO, replace it with raw value instead of serialized value.
83
+ def serialize_multipart_data_entry(data_entry: Any) -> Any:
84
+ if isinstance(data_entry, (list, tuple, dict, Model)):
85
+ return json.dumps(data_entry, cls=SdkJSONEncoder, exclude_readonly=True)
86
+ return data_entry
87
87
 
88
- :param body: The model to handle.
89
- :type body: ~payload.multipart._model_base.Model
90
- :return: The handled model.
91
- :rtype: MutableMapping[str, Any]
92
- """
93
- result = body.as_dict()
94
- rest_name_attr = {v._rest_name: k for k, v in body._attr_to_rest_field.items()} # pylint: disable=protected-access
95
- for rest_name in result.keys():
96
- attr = rest_name_attr.get(rest_name)
97
- if attr is not None:
98
- raw_value = getattr(body, attr, None)
99
- if isinstance(raw_value, (bytes, IOBase)):
100
- result[rest_name] = raw_value
101
- return result
102
- {% endif %}
88
+ def prepare_multipart_form_data(
89
+ body: Mapping[str, Any], multipart_fields: List[str], data_fields: List[str]
90
+ ) -> Tuple[List[FileType], Dict[str, Any]]:
91
+ files: List[FileType] = []
92
+ data: Dict[str, Any] = {}
93
+ for multipart_field in multipart_fields:
94
+ multipart_entry = body.get(multipart_field)
95
+ if isinstance(multipart_entry, list):
96
+ files.extend([(multipart_field, e) for e in multipart_entry ])
97
+ elif multipart_entry:
98
+ files.append((multipart_field, multipart_entry))
99
+
100
+ for data_field in data_fields:
101
+ data_entry = body.get(data_field)
102
+ if data_entry:
103
+ data[data_field] = serialize_multipart_data_entry(data_entry)
104
+
105
+ return files, data
106
+ {% endif %}
@@ -240,6 +240,10 @@ def update_primitive( # pylint: disable=too-many-return-statements
240
240
  base = _update_type_base("bytes", yaml_data)
241
241
  base["encode"] = yaml_data["format"]
242
242
  return base
243
+ if type_group == "date":
244
+ return _update_type_base("plainDate", yaml_data)
245
+ if type_group == "time":
246
+ return _update_type_base("plainTime", yaml_data)
243
247
  return _update_type_base(type_group, yaml_data)
244
248
 
245
249
 
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@autorest/python",
3
- "version": "6.12.3",
3
+ "version": "6.13.0",
4
4
  "description": "The Python extension for generators in AutoRest.",
5
5
  "main": "index.js",
6
6
  "repository": {
package/run_cadl.py CHANGED
@@ -22,7 +22,7 @@ if __name__ == "__main__":
22
22
  env_builder = venv.EnvBuilder(with_pip=True)
23
23
  venv_context = env_builder.ensure_directories(venv_path)
24
24
 
25
- if "--debug" in sys.argv:
25
+ if "--debug" in sys.argv or "--debug=true" in sys.argv:
26
26
  try:
27
27
  import debugpy # pylint: disable=import-outside-toplevel
28
28
  except ImportError: