@autorest/python 6.2.3 → 6.2.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -87,7 +87,6 @@ class BaseType(BaseModel, ABC):
87
87
  return ", ".join(attrs_list)
88
88
 
89
89
  @property
90
- @abstractmethod
91
90
  def serialization_type(self) -> str:
92
91
  """The tag recognized by 'msrest' as a serialization/deserialization.
93
92
 
@@ -100,6 +99,11 @@ class BaseType(BaseModel, ABC):
100
99
  If dict: '{str}'
101
100
  """
102
101
  ...
102
+ raise NotImplementedError()
103
+
104
+ @property
105
+ def msrest_deserialization_key(self) -> str:
106
+ return self.serialization_type
103
107
 
104
108
  @property
105
109
  def client_default_value(self) -> Any:
@@ -74,7 +74,9 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
74
74
  ]
75
75
  self.format_lro_operations()
76
76
 
77
- def _build_request_builders(self):
77
+ def _build_request_builders(
78
+ self,
79
+ ) -> List[Union[RequestBuilder, OverloadedRequestBuilder]]:
78
80
  request_builders: List[Union[RequestBuilder, OverloadedRequestBuilder]] = []
79
81
  for og_group in self.yaml_data["operationGroups"]:
80
82
  for operation_yaml in og_group["operations"]:
@@ -172,10 +174,7 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
172
174
  )
173
175
 
174
176
  for gp in self.parameters:
175
- if (
176
- gp.method_location == ParameterMethodLocation.KWARG
177
- and gp not in self.parameters.kwargs_to_pop
178
- ):
177
+ if gp.method_location == ParameterMethodLocation.KWARG:
179
178
  continue
180
179
  file_import.merge(gp.imports(async_mode))
181
180
  file_import.add_submodule_import(
@@ -288,23 +287,9 @@ class Client(_ClientConfigBase[ClientGlobalParameterList]):
288
287
  and self.code_model.options["models_mode"] == "msrest"
289
288
  ):
290
289
  path_to_models = ".." if async_mode else "."
291
- if len(self.code_model.model_types) != len(
292
- self.code_model.public_model_types
293
- ):
294
- # this means we have hidden models. In that case, we import directly from the models
295
- # file, not the module, bc we don't expose the hidden models in the models module
296
-
297
- # Also in this case, we're in version tolerant, so python3 only is true
298
- file_import.add_submodule_import(
299
- f"{path_to_models}models",
300
- self.code_model.models_filename,
301
- ImportType.LOCAL,
302
- alias="models",
303
- )
304
- else:
305
- file_import.add_submodule_import(
306
- path_to_models, "models", ImportType.LOCAL
307
- )
290
+ file_import.add_submodule_import(
291
+ path_to_models, "models", ImportType.LOCAL, alias="_models"
292
+ )
308
293
  elif self.code_model.options["models_mode"] == "msrest":
309
294
  # in this case, we have client_models = {} in the service client, which needs a type annotation
310
295
  # this import will always be commented, so will always add it to the typing section
@@ -414,7 +399,11 @@ class Config(_ClientConfigBase[ConfigGlobalParameterList]):
414
399
  and gp not in self.parameters.kwargs_to_pop
415
400
  ):
416
401
  continue
417
- file_import.merge(gp.imports(async_mode=async_mode))
402
+ file_import.merge(
403
+ gp.imports(
404
+ async_mode=async_mode, relative_path=".." if async_mode else "."
405
+ )
406
+ )
418
407
  return file_import
419
408
 
420
409
  def imports_for_multiapi(self, async_mode: bool) -> FileImport:
@@ -426,7 +415,11 @@ class Config(_ClientConfigBase[ConfigGlobalParameterList]):
426
415
  and gp.client_name == "api_version"
427
416
  ):
428
417
  continue
429
- file_import.merge(gp.imports_for_multiapi(async_mode=async_mode))
418
+ file_import.merge(
419
+ gp.imports_for_multiapi(
420
+ async_mode=async_mode, relative_path=".." if async_mode else "."
421
+ )
422
+ )
430
423
  return file_import
431
424
 
432
425
  @classmethod
@@ -41,7 +41,7 @@ class CombinedType(BaseType):
41
41
  If list: '[str]'
42
42
  If dict: '{str}'
43
43
  """
44
- ...
44
+ raise ValueError("Shouldn't get serialization type of a combinedtype")
45
45
 
46
46
  @property
47
47
  def client_default_value(self) -> Any:
@@ -69,7 +69,7 @@ class ImportModel:
69
69
  except AttributeError:
70
70
  return False
71
71
 
72
- def __hash__(self):
72
+ def __hash__(self) -> int:
73
73
  retval: int = 0
74
74
  for attr in dir(self):
75
75
  if attr[0] != "_":
@@ -88,7 +88,7 @@ class TypeDefinition:
88
88
 
89
89
 
90
90
  class FileImport:
91
- def __init__(self, imports: List[ImportModel] = None) -> None:
91
+ def __init__(self, imports: Optional[List[ImportModel]] = None) -> None:
92
92
  self.imports = imports or []
93
93
  # has sync and async type definitions
94
94
  self.type_definitions: Dict[str, TypeDefinition] = {}
@@ -81,6 +81,10 @@ class ModelType(
81
81
  return f"{'' if self.is_public else '_models.'}_models.{self.name}"
82
82
  return "object"
83
83
 
84
+ @property
85
+ def msrest_deserialization_key(self) -> str:
86
+ return self.name
87
+
84
88
  @property
85
89
  def is_polymorphic(self) -> bool:
86
90
  return any(p.is_polymorphic for p in self.properties)
@@ -48,11 +48,16 @@ class OperationGroup(BaseModel):
48
48
 
49
49
  def imports_for_multiapi(self, async_mode: bool) -> FileImport:
50
50
  file_import = FileImport()
51
+ relative_path = ".." if async_mode else "."
51
52
  for operation in self.operations:
52
53
  file_import.merge(
53
- operation.imports_for_multiapi(
54
- async_mode, relative_path=".." if async_mode else "."
55
- )
54
+ operation.imports_for_multiapi(async_mode, relative_path=relative_path)
55
+ )
56
+ if (
57
+ self.code_model.model_types or self.code_model.enums
58
+ ) and self.code_model.options["models_mode"] == "msrest":
59
+ file_import.add_submodule_import(
60
+ relative_path, "models", ImportType.LOCAL, alias="_models"
56
61
  )
57
62
  return file_import
58
63
 
@@ -84,8 +89,10 @@ class OperationGroup(BaseModel):
84
89
  )
85
90
  # for multiapi
86
91
  if (
87
- self.code_model.model_types or self.code_model.enums
88
- ) and self.code_model.options["models_mode"] == "msrest":
92
+ (self.code_model.public_model_types)
93
+ and self.code_model.options["models_mode"] == "msrest"
94
+ and not self.is_mixin
95
+ ):
89
96
  file_import.add_submodule_import(
90
97
  relative_path, "models", ImportType.LOCAL, alias="_models"
91
98
  )
@@ -393,6 +393,15 @@ class _RequestBuilderParameterList(
393
393
  p for p in super().path if p.location != ParameterLocation.ENDPOINT_PATH
394
394
  ]
395
395
 
396
+ @property
397
+ def constant(
398
+ self,
399
+ ) -> List[Union[RequestBuilderParameter, RequestBuilderBodyParameterType]]:
400
+ """All constant parameters"""
401
+ return [
402
+ p for p in super().constant if p.location != ParameterLocation.ENDPOINT_PATH
403
+ ]
404
+
396
405
 
397
406
  class RequestBuilderParameterList(_RequestBuilderParameterList):
398
407
  """Parameter list for Request Builder"""
@@ -463,15 +472,6 @@ class ClientGlobalParameterList(_ClientGlobalParameterList[ClientParameter]):
463
472
  except StopIteration:
464
473
  return None
465
474
 
466
- @property
467
- def kwargs_to_pop(self) -> List[Union[ClientParameter, BodyParameter]]:
468
- """We only want to pass base url path parameters in the client"""
469
- return [
470
- k
471
- for k in super().kwargs_to_pop
472
- if k.location == ParameterLocation.ENDPOINT_PATH
473
- ]
474
-
475
475
 
476
476
  class ConfigGlobalParameterList(_ClientGlobalParameterList[ConfigParameter]):
477
477
  """Parameter list for config"""
@@ -83,6 +83,10 @@ class Property(BaseModel): # pylint: disable=too-many-instance-attributes
83
83
  def serialization_type(self) -> str:
84
84
  return self.type.serialization_type
85
85
 
86
+ @property
87
+ def msrest_deserialization_key(self) -> str:
88
+ return self.type.msrest_deserialization_key
89
+
86
90
  def type_annotation(self, *, is_operation_file: bool = False) -> str:
87
91
  if self.optional and self.client_default_value is None:
88
92
  return f"Optional[{self.type.type_annotation(is_operation_file=is_operation_file)}]"
@@ -71,10 +71,18 @@ class RequestBuilderBase(BaseBuilder[ParameterListType]):
71
71
 
72
72
  def imports(self) -> FileImport:
73
73
  file_import = FileImport()
74
+ relative_path = ".."
75
+ if (
76
+ not self.code_model.options["builders_visibility"] == "embedded"
77
+ and self.group_name
78
+ ):
79
+ relative_path = "..." if self.group_name else ".."
74
80
  if self.abstract:
75
81
  return file_import
76
82
  for parameter in self.parameters.method:
77
- file_import.merge(parameter.imports(async_mode=False))
83
+ file_import.merge(
84
+ parameter.imports(async_mode=False, relative_path=relative_path)
85
+ )
78
86
 
79
87
  file_import.add_submodule_import(
80
88
  "azure.core.rest",
@@ -83,12 +91,6 @@ class RequestBuilderBase(BaseBuilder[ParameterListType]):
83
91
  )
84
92
 
85
93
  if self.parameters.path:
86
- relative_path = ".."
87
- if (
88
- not self.code_model.options["builders_visibility"] == "embedded"
89
- and self.group_name
90
- ):
91
- relative_path = "..." if self.group_name else ".."
92
94
  file_import.add_submodule_import(
93
95
  f"{relative_path}_vendor", "_format_url_section", ImportType.LOCAL
94
96
  )
@@ -9,7 +9,6 @@ from abc import abstractmethod
9
9
  from collections import defaultdict
10
10
  from typing import Any, Generic, List, Type, TypeVar, Dict, Union, Optional, cast
11
11
 
12
-
13
12
  from ..models import (
14
13
  Operation,
15
14
  PagingOperation,
@@ -967,11 +966,10 @@ class _OperationSerializer(
967
966
  retval.append(
968
967
  f"deserialized = self._deserialize('{response.serialization_type}', pipeline_response)"
969
968
  )
970
- elif self.code_model.options["models_mode"] == "dpg" and isinstance(
971
- response.type, ModelType
972
- ):
969
+ elif self.code_model.options["models_mode"] == "dpg":
973
970
  retval.append(
974
- f"deserialized = _deserialize({response.serialization_type}, response.json())"
971
+ f"deserialized = _deserialize({response.type.type_annotation(is_operation_file=True)}"
972
+ ", response.json())"
975
973
  )
976
974
  else:
977
975
  deserialized_value = (
@@ -1290,7 +1288,9 @@ class _PagingOperationSerializer(
1290
1288
  elif self.code_model.options["models_mode"]:
1291
1289
  cont_token_property = f"deserialized.{continuation_token_name} or None"
1292
1290
  else:
1293
- cont_token_property = f'deserialized.get("{continuation_token_name}", None)'
1291
+ cont_token_property = (
1292
+ f'deserialized.get("{continuation_token_name}") or None'
1293
+ )
1294
1294
  list_type = "AsyncList" if self.async_mode else "iter"
1295
1295
  retval.append(f" return {cont_token_property}, {list_type}(list_of_elem)")
1296
1296
  return retval
@@ -116,15 +116,28 @@ class ClientSerializer:
116
116
 
117
117
  def serializers_and_operation_groups_properties(self) -> List[str]:
118
118
  retval = []
119
- if self.client.code_model.model_types:
120
- client_models_value = (
121
- "{k: v for k, v in models.__dict__.items() if isinstance(v, type)}"
122
- )
123
- else:
124
- client_models_value = "{} # type: Dict[str, Any]"
119
+
120
+ def _get_client_models_value(models_dict_name: str) -> str:
121
+ if self.client.code_model.model_types:
122
+ return f"{{k: v for k, v in {models_dict_name}.__dict__.items() if isinstance(v, type)}}"
123
+ return "{} # type: Dict[str, Any]"
124
+
125
125
  is_msrest_model = self.client.code_model.options["models_mode"] == "msrest"
126
126
  if is_msrest_model:
127
- retval.append(f"client_models = {client_models_value}")
127
+ add_private_models = len(self.client.code_model.model_types) != len(
128
+ self.client.code_model.public_model_types
129
+ )
130
+ model_dict_name = (
131
+ f"_models.{self.client.code_model.models_filename}"
132
+ if add_private_models
133
+ else "_models"
134
+ )
135
+ retval.append(
136
+ f"client_models = {_get_client_models_value(model_dict_name)}"
137
+ )
138
+ if add_private_models and self.client.code_model.model_types:
139
+ update_dict = f"{{k: v for k, v in _models.__dict__.items() if isinstance(v, type)}}"
140
+ retval.append(f"client_models.update({update_dict})")
128
141
  client_models_str = "client_models" if is_msrest_model else ""
129
142
  retval.append(f"self._serialize = Serializer({client_models_str})")
130
143
  retval.append(f"self._deserialize = Deserializer({client_models_str})")
@@ -195,7 +195,10 @@ class MsrestModelSerializer(_ModelSerializer):
195
195
  xml_metadata = f", 'xml': {{{prop.type.xml_serialization_ctxt}}}"
196
196
  else:
197
197
  xml_metadata = ""
198
- return f'"{prop.client_name}": {{"key": "{attribute_key}", "type": "{prop.serialization_type}"{xml_metadata}}},'
198
+ return (
199
+ f'"{prop.client_name}": {{"key": "{attribute_key}",'
200
+ f' "type": "{prop.msrest_deserialization_key}"{xml_metadata}}},'
201
+ )
199
202
 
200
203
 
201
204
  class DpgModelSerializer(_ModelSerializer):
@@ -87,6 +87,15 @@ class SampleSerializer:
87
87
 
88
88
  return client_params
89
89
 
90
+ @staticmethod
91
+ def handle_param(param: Any) -> str:
92
+ if isinstance(param, str):
93
+ if any(i in param for i in '\r\n"'):
94
+ return f'"""{param}"""'
95
+ return f'"{param}"'
96
+
97
+ return str(param)
98
+
90
99
  # prepare operation parameters
91
100
  def _operation_params(self) -> Dict[str, Any]:
92
101
  params_positional = [
@@ -94,7 +103,6 @@ class SampleSerializer:
94
103
  for p in self.operation.parameters.positional
95
104
  if not p.client_default_value
96
105
  ]
97
- cls = lambda x: f'"{x}"' if isinstance(x, str) else str(x)
98
106
  failure_info = "fail to find required param named {} in example file {}"
99
107
  operation_params = {}
100
108
  for param in params_positional:
@@ -103,7 +111,7 @@ class SampleSerializer:
103
111
  if not param.optional:
104
112
  if not param_value:
105
113
  raise Exception(failure_info.format(name, self.sample_origin_name))
106
- operation_params[param.client_name] = cls(param_value)
114
+ operation_params[param.client_name] = self.handle_param(param_value)
107
115
  return operation_params
108
116
 
109
117
  def _operation_group_name(self) -> str:
@@ -5,9 +5,6 @@
5
5
  {{ serializer.init_signature_and_response_type_annotation(async_mode) | indent }}
6
6
  {% if serializer.should_init_super %}
7
7
  super().__init__()
8
- {% endif %}
9
- {% if client.parameters.kwargs_to_pop %}
10
- {{ op_tools.serialize(serializer.pop_kwargs_from_signature()) | indent(8) }}
11
8
  {% endif %}
12
9
  {% if client.has_parameterized_host %}
13
10
  {{ serializer.host_variable_name }} = {{ keywords.escape_str(client.url) }}
@@ -17,6 +17,7 @@ from datetime import datetime, date, time, timedelta
17
17
  from azure.core.utils._utils import _FixedOffset
18
18
  from collections.abc import MutableMapping
19
19
  from azure.core.exceptions import DeserializationError
20
+ from azure.core import CaseInsensitiveEnumMeta
20
21
  import copy
21
22
 
22
23
  _LOGGER = logging.getLogger(__name__)
@@ -121,9 +122,11 @@ try:
121
122
  except ImportError:
122
123
  TZ_UTC = _FixedOffset(0) # type: ignore
123
124
 
125
+
124
126
  def _serialize_bytes(o) -> str:
125
127
  return base64.b64encode(o).decode()
126
128
 
129
+
127
130
  def _serialize_datetime(o):
128
131
  if hasattr(o, "year") and hasattr(o, "hour"):
129
132
  # astimezone() fails for naive times in Python 2.7, so make make sure o is aware (tzinfo is set)
@@ -136,12 +139,14 @@ def _serialize_datetime(o):
136
139
  # Next try datetime.date or datetime.time
137
140
  return o.isoformat()
138
141
 
142
+
139
143
  def _is_readonly(p):
140
144
  try:
141
145
  return p._readonly
142
146
  except AttributeError:
143
147
  return False
144
148
 
149
+
145
150
  class AzureJSONEncoder(JSONEncoder):
146
151
  """A JSON encoder that's capable of serializing datetime objects and bytes."""
147
152
 
@@ -170,9 +175,8 @@ class AzureJSONEncoder(JSONEncoder):
170
175
  return super(AzureJSONEncoder, self).default(o)
171
176
 
172
177
 
173
- _VALID_DATE = re.compile(
174
- r'\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}'
175
- r'\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?')
178
+ _VALID_DATE = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}" r"\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?")
179
+
176
180
 
177
181
  def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime:
178
182
  """Deserialize ISO-8601 formatted string into Datetime object.
@@ -188,7 +192,7 @@ def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime:
188
192
  if not match:
189
193
  raise ValueError("Invalid datetime string: " + attr)
190
194
 
191
- check_decimal = attr.split('.')
195
+ check_decimal = attr.split(".")
192
196
  if len(check_decimal) > 1:
193
197
  decimal_str = ""
194
198
  for digit in check_decimal[1]:
@@ -205,6 +209,7 @@ def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime:
205
209
  raise OverflowError("Hit max or min date")
206
210
  return date_obj
207
211
 
212
+
208
213
  def _deserialize_date(attr: typing.Union[str, date]) -> date:
209
214
  """Deserialize ISO-8601 formatted string into Date object.
210
215
  :param str attr: response string to be deserialized.
@@ -215,6 +220,7 @@ def _deserialize_date(attr: typing.Union[str, date]) -> date:
215
220
  return attr
216
221
  return isodate.parse_date(attr, defaultmonth=None, defaultday=None)
217
222
 
223
+
218
224
  def _deserialize_time(attr: typing.Union[str, time]) -> time:
219
225
  """Deserialize ISO-8601 formatted string into time object.
220
226
 
@@ -225,18 +231,31 @@ def _deserialize_time(attr: typing.Union[str, time]) -> time:
225
231
  return attr
226
232
  return isodate.parse_time(attr)
227
233
 
234
+
228
235
  def deserialize_bytes(attr):
236
+ if isinstance(attr, (bytes, bytearray)):
237
+ return attr
229
238
  return bytes(base64.b64decode(attr))
230
239
 
240
+
241
+ def deserialize_duration(attr):
242
+ if isinstance(attr, timedelta):
243
+ return attr
244
+ return isodate.parse_duration(attr)
245
+
246
+
231
247
  _DESERIALIZE_MAPPING = {
232
248
  datetime: _deserialize_datetime,
233
249
  date: _deserialize_date,
234
250
  time: _deserialize_time,
235
251
  bytes: deserialize_bytes,
252
+ timedelta: deserialize_duration,
253
+ typing.Any: lambda x: x,
236
254
  }
237
255
 
256
+
238
257
  def _get_model(module_name: str, model_name: str):
239
- module_end = module_name.rsplit('.', 1)[0]
258
+ module_end = module_name.rsplit(".", 1)[0]
240
259
  module = sys.modules[module_end]
241
260
  models = {k: v for k, v in module.__dict__.items() if isinstance(v, type)}
242
261
  if model_name not in models:
@@ -244,10 +263,11 @@ def _get_model(module_name: str, model_name: str):
244
263
  return model_name
245
264
  return models[model_name]
246
265
 
266
+
247
267
  _UNSET = object()
248
268
 
249
- class _MyMutableMapping(MutableMapping):
250
269
 
270
+ class _MyMutableMapping(MutableMapping):
251
271
  def __init__(self, data: typing.Dict[str, typing.Any]) -> None:
252
272
  self._data = copy.deepcopy(data)
253
273
 
@@ -332,9 +352,11 @@ class _MyMutableMapping(MutableMapping):
332
352
  def __repr__(self) -> str:
333
353
  return str(self._data)
334
354
 
355
+
335
356
  def _is_model(obj: typing.Any) -> bool:
336
357
  return getattr(obj, "_is_model", False)
337
358
 
359
+
338
360
  def _serialize(o):
339
361
  if isinstance(o, (bytes, bytearray)):
340
362
  return _serialize_bytes(o)
@@ -351,7 +373,10 @@ def _serialize(o):
351
373
  pass
352
374
  return o
353
375
 
354
- def _get_rest_field(attr_to_rest_field: typing.Dict[str, "_RestField"], rest_name: str) -> typing.Optional["_RestField"]:
376
+
377
+ def _get_rest_field(
378
+ attr_to_rest_field: typing.Dict[str, "_RestField"], rest_name: str
379
+ ) -> typing.Optional["_RestField"]:
355
380
  try:
356
381
  return next(rf for rf in attr_to_rest_field.values() if rf._rest_name == rest_name)
357
382
  except StopIteration:
@@ -361,8 +386,10 @@ def _get_rest_field(attr_to_rest_field: typing.Dict[str, "_RestField"], rest_nam
361
386
  def _create_value(rest_field: typing.Optional["_RestField"], value: typing.Any) -> typing.Any:
362
387
  return _deserialize(rest_field._type, value) if (rest_field and rest_field._is_model) else _serialize(value)
363
388
 
389
+
364
390
  class Model(_MyMutableMapping):
365
391
  _is_model = True
392
+
366
393
  def __init__(self, *args, **kwargs):
367
394
  class_name = self.__class__.__name__
368
395
  if len(args) > 1:
@@ -373,20 +400,15 @@ class Model(_MyMutableMapping):
373
400
  if rest_field._default is not _UNSET
374
401
  }
375
402
  if args:
376
- dict_to_pass.update({
377
- k: _create_value(_get_rest_field(self._attr_to_rest_field, k), v)
378
- for k, v in args[0].items()
379
- })
403
+ dict_to_pass.update(
404
+ {k: _create_value(_get_rest_field(self._attr_to_rest_field, k), v) for k, v in args[0].items()}
405
+ )
380
406
  else:
381
407
  non_attr_kwargs = [k for k in kwargs if k not in self._attr_to_rest_field]
382
408
  if non_attr_kwargs:
383
409
  # actual type errors only throw the first wrong keyword arg they see, so following that.
384
- raise TypeError(
385
- f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'"
386
- )
387
- dict_to_pass.update({
388
- self._attr_to_rest_field[k]._rest_name: _serialize(v) for k, v in kwargs.items()
389
- })
410
+ raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'")
411
+ dict_to_pass.update({self._attr_to_rest_field[k]._rest_name: _serialize(v) for k, v in kwargs.items()})
390
412
  super().__init__(dict_to_pass)
391
413
 
392
414
  def copy(self):
@@ -394,15 +416,14 @@ class Model(_MyMutableMapping):
394
416
 
395
417
  def __new__(cls, *args: typing.Any, **kwargs: typing.Any):
396
418
  # we know the last three classes in mro are going to be 'Model', 'dict', and 'object'
397
- mros = cls.__mro__[:-3][::-1] # ignore model, dict, and object parents, and reverse the mro order
398
- attr_to_rest_field: typing.Dict[str, _RestField] = { # map attribute name to rest_field property
399
- k: v
400
- for mro_class in mros
401
- for k, v in mro_class.__dict__.items() if k[0] != "_" and hasattr(v, "_type")
419
+ mros = cls.__mro__[:-3][::-1] # ignore model, dict, and object parents, and reverse the mro order
420
+ attr_to_rest_field: typing.Dict[str, _RestField] = { # map attribute name to rest_field property
421
+ k: v for mro_class in mros for k, v in mro_class.__dict__.items() if k[0] != "_" and hasattr(v, "_type")
402
422
  }
403
423
  annotations = {
404
424
  k: v
405
- for mro_class in mros if hasattr(mro_class, '__annotations__')
425
+ for mro_class in mros
426
+ if hasattr(mro_class, "__annotations__")
406
427
  for k, v in mro_class.__annotations__.items()
407
428
  }
408
429
  for attr, rest_field in attr_to_rest_field.items():
@@ -411,18 +432,15 @@ class Model(_MyMutableMapping):
411
432
  rest_field._type = rest_field._get_deserialize_callable_from_annotation(annotations.get(attr, None))
412
433
  if not rest_field._rest_name_input:
413
434
  rest_field._rest_name_input = attr
414
- cls._attr_to_rest_field: typing.Dict[str, _RestField] = {
415
- k: v
416
- for k, v in attr_to_rest_field.items()
417
- }
435
+ cls._attr_to_rest_field: typing.Dict[str, _RestField] = {k: v for k, v in attr_to_rest_field.items()}
418
436
 
419
437
  return super().__new__(cls)
420
438
 
421
439
  def __init_subclass__(cls, discriminator=None):
422
440
  for base in cls.__bases__:
423
- if hasattr(base, '__mapping__'):
441
+ if hasattr(base, "__mapping__"):
424
442
  base.__mapping__[discriminator or cls.__name__] = cls
425
-
443
+
426
444
  @classmethod
427
445
  def _get_discriminator(cls) -> typing.Optional[str]:
428
446
  for v in cls.__dict__.values():
@@ -432,81 +450,29 @@ class Model(_MyMutableMapping):
432
450
 
433
451
  @classmethod
434
452
  def _deserialize(cls, data):
435
- if not hasattr(cls, '__mapping__'):
453
+ if not hasattr(cls, "__mapping__"):
436
454
  return cls(data)
437
455
  discriminator = cls._get_discriminator()
438
- mapped_cls = cls.__mapping__.get(data.get(discriminator),cls)
456
+ mapped_cls = cls.__mapping__.get(data.get(discriminator), cls)
439
457
  if mapped_cls == cls:
440
458
  return cls(data)
441
459
  return mapped_cls._deserialize(data)
442
460
 
443
461
 
444
- def _deserialize(deserializer: typing.Optional[typing.Callable[[typing.Any], typing.Any]], value: typing.Any):
445
- try:
446
- if value is None:
447
- return None
448
- if isinstance(deserializer, type) and issubclass(deserializer, Model):
449
- return deserializer._deserialize(value)
450
- return deserializer(value) if deserializer else value
451
- except Exception as e:
452
- raise DeserializationError() from e
453
-
454
- class _RestField:
455
- def __init__(
456
- self,
457
- *,
458
- name: typing.Optional[str] = None,
459
- type: typing.Optional[typing.Callable] = None,
460
- is_discriminator: bool = False,
461
- readonly: bool = False,
462
- default: typing.Any = _UNSET,
463
- ):
464
- self._type = type
465
- self._rest_name_input = name
466
- self._module: typing.Optional[str] = None
467
- self._is_discriminator = is_discriminator
468
- self._readonly = readonly
469
- self._is_model = False
470
- self._default = default
471
-
472
- @property
473
- def _rest_name(self) -> str:
474
- if self._rest_name_input is None:
475
- raise ValueError("Rest name was never set")
476
- return self._rest_name_input
477
-
478
- def __get__(self, obj: Model, type=None):
479
- # by this point, type and rest_name will have a value bc we default
480
- # them in __new__ of the Model class
481
- item = obj.get(self._rest_name)
482
- if item is None:
483
- return item
484
- return _deserialize(self._type, _serialize(item))
485
-
486
- def __set__(self, obj: Model, value) -> None:
487
- if value is None:
488
- # we want to wipe out entries if users set attr to None
489
- try:
490
- obj.__delitem__(self._rest_name)
491
- except KeyError:
492
- pass
493
- return
494
- if self._is_model and not _is_model(value):
495
- obj.__setitem__(self._rest_name, _deserialize(self._type, value))
496
- obj.__setitem__(self._rest_name, _serialize(value))
497
-
498
- def _get_deserialize_callable_from_annotation(self, annotation: typing.Any) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]:
462
+ def _get_deserialize_callable_from_annotation(
463
+ annotation: typing.Any, module: str,
464
+ ) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]:
499
465
  if not annotation or annotation in [int, float]:
500
466
  return None
501
467
 
502
468
  try:
503
- if _is_model(_get_model(self._module, annotation)):
504
- self._is_model = True
469
+ if _is_model(_get_model(module, annotation)):
505
470
  def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj):
506
471
  if _is_model(obj):
507
472
  return obj
508
473
  return _deserialize(model_deserializer, obj)
509
- return functools.partial(_deserialize_model, _get_model(self._module, annotation))
474
+
475
+ return functools.partial(_deserialize_model, _get_model(module, annotation))
510
476
  except Exception:
511
477
  pass
512
478
 
@@ -517,49 +483,61 @@ class _RestField:
517
483
  except AttributeError:
518
484
  pass
519
485
 
486
+ if isinstance(annotation, typing._GenericAlias): # pylint: disable=protected-access
487
+ if annotation.__origin__ is typing.Union:
488
+ def _deserialize_with_union(union_annotation: typing._GenericAlias, obj):
489
+ for t in union_annotation.__args__:
490
+ try:
491
+ return _deserialize(t, obj)
492
+ except DeserializationError:
493
+ pass
494
+ raise DeserializationError()
495
+ return functools.partial(_deserialize_with_union, annotation)
496
+
520
497
  # is it optional?
521
498
  try:
522
499
  # right now, assuming we don't have unions, since we're getting rid of the only
523
500
  # union we used to have in msrest models, which was union of str and enum
524
501
  if any(a for a in annotation.__args__ if a == type(None)):
525
502
 
526
- if_obj_deserializer = self._get_deserialize_callable_from_annotation(
527
- next(a for a in annotation.__args__ if a != type(None)),
503
+ if_obj_deserializer = _get_deserialize_callable_from_annotation(
504
+ next(a for a in annotation.__args__ if a != type(None)), module
528
505
  )
506
+
529
507
  def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Callable], obj):
530
508
  if obj is None:
531
509
  return obj
532
- return _deserialize(if_obj_deserializer, obj)
510
+ return _deserialize_with_callable(if_obj_deserializer, obj)
533
511
 
534
512
  return functools.partial(_deserialize_with_optional, if_obj_deserializer)
535
513
  except (AttributeError):
536
514
  pass
537
515
 
538
-
539
516
  # is it a forward ref / in quotes?
540
517
  if isinstance(annotation, str) or type(annotation) == typing.ForwardRef:
541
518
  try:
542
519
  model_name = annotation.__forward_arg__ # type: ignore
543
520
  except AttributeError:
544
521
  model_name = annotation
545
- if self._module is not None:
546
- annotation = _get_model(self._module, model_name)
522
+ if module is not None:
523
+ annotation = _get_model(module, model_name)
547
524
 
548
525
  try:
549
526
  if annotation._name == "Dict":
550
- key_deserializer = self._get_deserialize_callable_from_annotation(annotation.__args__[0])
551
- value_deserializer = self._get_deserialize_callable_from_annotation(annotation.__args__[1])
527
+ key_deserializer = _get_deserialize_callable_from_annotation(annotation.__args__[0], module)
528
+ value_deserializer = _get_deserialize_callable_from_annotation(annotation.__args__[1], module)
529
+
552
530
  def _deserialize_dict(
553
531
  key_deserializer: typing.Optional[typing.Callable],
554
532
  value_deserializer: typing.Optional[typing.Callable],
555
- obj: typing.Dict[typing.Any, typing.Any]
533
+ obj: typing.Dict[typing.Any, typing.Any],
556
534
  ):
557
535
  if obj is None:
558
536
  return obj
559
537
  return {
560
- _deserialize(key_deserializer, k): _deserialize(value_deserializer, v)
561
- for k, v in obj.items()
538
+ _deserialize(key_deserializer, k): _deserialize(value_deserializer, v) for k, v in obj.items()
562
539
  }
540
+
563
541
  return functools.partial(
564
542
  _deserialize_dict,
565
543
  key_deserializer,
@@ -570,38 +548,31 @@ class _RestField:
570
548
  try:
571
549
  if annotation._name in ["List", "Set", "Tuple", "Sequence"]:
572
550
  if len(annotation.__args__) > 1:
551
+
573
552
  def _deserialize_multiple_sequence(
574
- entry_deserializers: typing.List[typing.Optional[typing.Callable]],
575
- obj
553
+ entry_deserializers: typing.List[typing.Optional[typing.Callable]], obj
576
554
  ):
577
555
  if obj is None:
578
556
  return obj
579
557
  return type(obj)(
580
- _deserialize(deserializer, entry)
581
- for entry, deserializer in zip(obj, entry_deserializers)
558
+ _deserialize(deserializer, entry) for entry, deserializer in zip(obj, entry_deserializers)
582
559
  )
560
+
583
561
  entry_deserializers = [
584
- self._get_deserialize_callable_from_annotation(dt)
585
- for dt in annotation.__args__
562
+ _get_deserialize_callable_from_annotation(dt, module) for dt in annotation.__args__
586
563
  ]
587
- return functools.partial(
588
- _deserialize_multiple_sequence,
589
- entry_deserializers
590
- )
591
- deserializer = self._get_deserialize_callable_from_annotation(annotation.__args__[0])
564
+ return functools.partial(_deserialize_multiple_sequence, entry_deserializers)
565
+ deserializer = _get_deserialize_callable_from_annotation(annotation.__args__[0], module)
566
+
592
567
  def _deserialize_sequence(
593
568
  deserializer: typing.Optional[typing.Callable],
594
569
  obj,
595
570
  ):
596
571
  if obj is None:
597
572
  return obj
598
- return type(obj)(
599
- _deserialize(deserializer, entry) for entry in obj
600
- )
601
- return functools.partial(
602
- _deserialize_sequence,
603
- deserializer
604
- )
573
+ return type(obj)(_deserialize(deserializer, entry) for entry in obj)
574
+
575
+ return functools.partial(_deserialize_sequence, deserializer)
605
576
  except (TypeError, IndexError, AttributeError, SyntaxError):
606
577
  pass
607
578
 
@@ -613,15 +584,84 @@ class _RestField:
613
584
  if obj is None:
614
585
  return obj
615
586
  try:
616
- return _deserialize(annotation, obj)
587
+ return _deserialize_with_callable(annotation, obj)
617
588
  except Exception:
618
589
  pass
619
- return _deserialize(deserializer_from_mapping, obj)
620
- return functools.partial(
621
- _deserialize_default,
622
- annotation,
623
- _DESERIALIZE_MAPPING.get(annotation)
624
- )
590
+ return _deserialize_with_callable(deserializer_from_mapping, obj)
591
+
592
+ return functools.partial(_deserialize_default, annotation, _DESERIALIZE_MAPPING.get(annotation))
593
+
594
+
595
+ def _deserialize_with_callable(deserializer: typing.Optional[typing.Callable[[typing.Any], typing.Any]], value: typing.Any):
596
+ try:
597
+ if value is None:
598
+ return None
599
+ if isinstance(deserializer, CaseInsensitiveEnumMeta):
600
+ try:
601
+ return deserializer(value)
602
+ except ValueError:
603
+ # for unknown value, return raw value
604
+ return value
605
+ if isinstance(deserializer, type) and issubclass(deserializer, Model):
606
+ return deserializer._deserialize(value)
607
+ return deserializer(value) if deserializer else value
608
+ except Exception as e:
609
+ raise DeserializationError() from e
610
+
611
+
612
+ def _deserialize(deserializer: typing.Optional[typing.Callable[[typing.Any], typing.Any]], value: typing.Any):
613
+ deserializer = _get_deserialize_callable_from_annotation(deserializer, "")
614
+ return _deserialize_with_callable(deserializer, value)
615
+
616
+ class _RestField:
617
+ def __init__(
618
+ self,
619
+ *,
620
+ name: typing.Optional[str] = None,
621
+ type: typing.Optional[typing.Callable] = None,
622
+ is_discriminator: bool = False,
623
+ readonly: bool = False,
624
+ default: typing.Any = _UNSET,
625
+ ):
626
+ self._type = type
627
+ self._rest_name_input = name
628
+ self._module: typing.Optional[str] = None
629
+ self._is_discriminator = is_discriminator
630
+ self._readonly = readonly
631
+ self._is_model = False
632
+ self._default = default
633
+
634
+ @property
635
+ def _rest_name(self) -> str:
636
+ if self._rest_name_input is None:
637
+ raise ValueError("Rest name was never set")
638
+ return self._rest_name_input
639
+
640
+ def __get__(self, obj: Model, type=None):
641
+ # by this point, type and rest_name will have a value bc we default
642
+ # them in __new__ of the Model class
643
+ item = obj.get(self._rest_name)
644
+ if item is None:
645
+ return item
646
+ return _deserialize(self._type, _serialize(item))
647
+
648
+ def __set__(self, obj: Model, value) -> None:
649
+ if value is None:
650
+ # we want to wipe out entries if users set attr to None
651
+ try:
652
+ obj.__delitem__(self._rest_name)
653
+ except KeyError:
654
+ pass
655
+ return
656
+ if self._is_model and not _is_model(value):
657
+ obj.__setitem__(self._rest_name, _deserialize(self._type, value))
658
+ obj.__setitem__(self._rest_name, _serialize(value))
659
+
660
+ def _get_deserialize_callable_from_annotation(
661
+ self, annotation: typing.Any
662
+ ) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]:
663
+ return _get_deserialize_callable_from_annotation(annotation, self._module)
664
+
625
665
 
626
666
  def rest_field(
627
667
  *,
@@ -632,5 +672,8 @@ def rest_field(
632
672
  ) -> typing.Any:
633
673
  return _RestField(name=name, type=type, readonly=readonly, default=default)
634
674
 
635
- def rest_discriminator(*, name: typing.Optional[str] = None, type: typing.Optional[typing.Callable] = None) -> typing.Any:
675
+
676
+ def rest_discriminator(
677
+ *, name: typing.Optional[str] = None, type: typing.Optional[typing.Callable] = None
678
+ ) -> typing.Any:
636
679
  return _RestField(name=name, type=type, is_discriminator=True)
@@ -58,6 +58,7 @@ setup(
58
58
  "Programming Language :: Python :: 3.8",
59
59
  "Programming Language :: Python :: 3.9",
60
60
  "Programming Language :: Python :: 3.10",
61
+ "Programming Language :: Python :: 3.11",
61
62
  "License :: OSI Approved :: MIT License",
62
63
  ],
63
64
  zip_safe=False,
@@ -114,7 +114,9 @@ class AutorestAPI(ABC):
114
114
  def message(self, channel: Channel, text: str) -> None:
115
115
  """Send a log message to autorest."""
116
116
 
117
- def get_boolean_value(self, key: str, default: bool = None) -> Optional[bool]:
117
+ def get_boolean_value(
118
+ self, key: str, default: Optional[bool] = None
119
+ ) -> Optional[bool]:
118
120
  """Check if value is present on the line, and interpret it as bool if it was.
119
121
 
120
122
  If value is not not on the line, return the "default".
@@ -17,7 +17,9 @@ class LocalAutorestAPI(AutorestAPI):
17
17
  """A local API that will write on local disk."""
18
18
 
19
19
  def __init__(
20
- self, reachable_files: List[str] = None, output_folder: str = "generated"
20
+ self,
21
+ reachable_files: Optional[List[str]] = None,
22
+ output_folder: str = "generated",
21
23
  ) -> None:
22
24
  super().__init__()
23
25
  if reachable_files is None:
@@ -25,30 +25,34 @@ class TypingSection(str, Enum):
25
25
  class FileImport:
26
26
  def __init__(
27
27
  self,
28
- imports: Dict[
29
- TypingSection,
28
+ imports: Optional[
30
29
  Dict[
31
- ImportType,
30
+ TypingSection,
32
31
  Dict[
33
- str,
34
- Set[
35
- Optional[
36
- Union[
37
- str,
38
- Tuple[
39
- str,
32
+ ImportType,
33
+ Dict[
34
+ str,
35
+ Set[
36
+ Optional[
37
+ Union[
40
38
  str,
41
- ],
42
- Tuple[
43
- str,
44
- str,
45
- Tuple[Tuple[Tuple[int, int], str, Optional[str]]],
46
- ],
39
+ Tuple[
40
+ str,
41
+ str,
42
+ ],
43
+ Tuple[
44
+ str,
45
+ str,
46
+ Tuple[
47
+ Tuple[Tuple[int, int], str, Optional[str]]
48
+ ],
49
+ ],
50
+ ]
47
51
  ]
48
- ]
52
+ ],
49
53
  ],
50
54
  ],
51
- ],
55
+ ]
52
56
  ] = None,
53
57
  ) -> None:
54
58
  # Basic implementation
@@ -9,7 +9,11 @@ import copy
9
9
  from typing import Callable, Dict, Any, List, Optional
10
10
 
11
11
  from .._utils import to_snake_case
12
- from .helpers import pad_reserved_words, add_redefined_builtin_info
12
+ from .helpers import (
13
+ pad_reserved_words,
14
+ add_redefined_builtin_info,
15
+ pad_builtin_namespaces,
16
+ )
13
17
  from .python_mappings import PadType
14
18
 
15
19
  from .. import YamlUpdatePlugin, YamlUpdatePluginAutorest
@@ -362,6 +366,8 @@ class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
362
366
  for client in clients:
363
367
  update_client(client)
364
368
  self.update_operation_groups(yaml_data, client)
369
+ if yaml_data.get("namespace"):
370
+ yaml_data["namespace"] = pad_builtin_namespaces(yaml_data["namespace"])
365
371
 
366
372
 
367
373
  class PreProcessPluginAutorest(YamlUpdatePluginAutorest, PreProcessPlugin):
@@ -3,9 +3,14 @@
3
3
  # Licensed under the MIT License. See License.txt in the project root for
4
4
  # license information.
5
5
  # --------------------------------------------------------------------------
6
- from typing import Any, Dict
7
6
  import re
8
- from .python_mappings import PadType, RESERVED_WORDS, REDEFINED_BUILTINS
7
+ from typing import Any, Dict
8
+ from .python_mappings import (
9
+ PadType,
10
+ RESERVED_WORDS,
11
+ REDEFINED_BUILTINS,
12
+ BUILTIN_PACKAGES,
13
+ )
9
14
 
10
15
 
11
16
  def pad_reserved_words(name: str, pad_type: PadType):
@@ -26,5 +31,12 @@ def add_redefined_builtin_info(name: str, yaml_data: Dict[str, Any]) -> None:
26
31
  yaml_data["pylintDisable"] = "redefined-builtin"
27
32
 
28
33
 
34
+ def pad_builtin_namespaces(namespace: str) -> str:
35
+ items = namespace.split(".")
36
+ if items[0] in BUILTIN_PACKAGES:
37
+ items[0] = items[0] + "_"
38
+ return ".".join(items)
39
+
40
+
29
41
  def pad_special_chars(name: str) -> str:
30
42
  return re.sub(r"[^A-z0-9_]", "_", name)
@@ -174,3 +174,30 @@ REDEFINED_BUILTINS = [ # we don't pad, but we need to do lint ignores
174
174
  "max",
175
175
  "filter",
176
176
  ]
177
+
178
+ BUILTIN_PACKAGES = [
179
+ "array",
180
+ "atexit",
181
+ "binascii",
182
+ "builtins",
183
+ "cmath",
184
+ "errno",
185
+ "faulthandler",
186
+ "fcntl",
187
+ "gc",
188
+ "grp",
189
+ "itertools",
190
+ "marshal",
191
+ "math",
192
+ "posix",
193
+ "pwd",
194
+ "pyexpat",
195
+ "select",
196
+ "spwd",
197
+ "sys",
198
+ "syslog",
199
+ "time",
200
+ "unicodedata",
201
+ "xxsubtype",
202
+ "zlib",
203
+ ]
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@autorest/python",
3
- "version": "6.2.3",
3
+ "version": "6.2.6",
4
4
  "description": "The Python extension for generators in AutoRest.",
5
5
  "repository": {
6
6
  "type": "git",
@@ -21,7 +21,7 @@
21
21
  "@autorest/system-requirements": "~1.0.0"
22
22
  },
23
23
  "devDependencies": {
24
- "@microsoft.azure/autorest.testserver": "^3.3.41",
24
+ "@microsoft.azure/autorest.testserver": "^3.3.45",
25
25
  "typescript": "^4.8.3"
26
26
  },
27
27
  "files": [
package/setup.py CHANGED
@@ -38,6 +38,9 @@ setup(
38
38
  'Programming Language :: Python :: 3',
39
39
  'Programming Language :: Python :: 3.7',
40
40
  'Programming Language :: Python :: 3.8',
41
+ 'Programming Language :: Python :: 3.9',
42
+ 'Programming Language :: Python :: 3.10',
43
+ 'Programming Language :: Python :: 3.11',
41
44
  'License :: OSI Approved :: MIT License',
42
45
  ],
43
46
  packages=find_packages(exclude=[