@autorest/python 6.7.3 → 6.7.5

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the contents of those package versions as they appear in their respective public registries.
@@ -406,9 +406,13 @@ class ClientParameter(Parameter):
406
406
  def method_location(self) -> ParameterMethodLocation:
407
407
  if self.constant:
408
408
  return ParameterMethodLocation.KWARG
409
- if self.is_host and (
410
- self.code_model.options["version_tolerant"]
411
- or self.code_model.options["low_level_client"]
409
+ if (
410
+ self.is_host
411
+ and (
412
+ self.code_model.options["version_tolerant"]
413
+ or self.code_model.options["low_level_client"]
414
+ )
415
+ and not self.code_model.options["azure_arm"]
412
416
  ):
413
417
  # this means i am the base url
414
418
  return ParameterMethodLocation.KEYWORD_ONLY
@@ -751,7 +751,7 @@ class _OperationSerializer(
751
751
  elif self.code_model.options["models_mode"] == "dpg":
752
752
  create_body_call = (
753
753
  f"_{body_kwarg_name} = json.dumps({body_param.client_name}, "
754
- "cls=AzureJSONEncoder) # type: ignore"
754
+ "cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore"
755
755
  )
756
756
  else:
757
757
  create_body_call = f"_{body_kwarg_name} = {body_param.client_name}"
@@ -128,10 +128,16 @@ def _is_readonly(p):
128
128
  class AzureJSONEncoder(JSONEncoder):
129
129
  """A JSON encoder that's capable of serializing datetime objects and bytes."""
130
130
 
131
+ def __init__(self, *args, exclude_readonly: bool = False, **kwargs):
132
+ super().__init__(*args, **kwargs)
133
+ self.exclude_readonly = exclude_readonly
134
+
131
135
  def default(self, o): # pylint: disable=too-many-return-statements
132
136
  if _is_model(o):
133
- readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)]
134
- return {k: v for k, v in o.items() if k not in readonly_props}
137
+ if self.exclude_readonly:
138
+ readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)]
139
+ return {k: v for k, v in o.items() if k not in readonly_props}
140
+ return dict(o.items())
135
141
  if isinstance(o, (bytes, bytearray)):
136
142
  return base64.b64encode(o).decode()
137
143
  if isinstance(o, _Null):
@@ -295,11 +301,29 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] =
295
301
  return _DESERIALIZE_MAPPING.get(annotation)
296
302
 
297
303
 
304
+ def _get_type_alias_type(module_name: str, alias_name: str):
305
+ types = {
306
+ k: v
307
+ for k, v in sys.modules[module_name].__dict__.items()
308
+ if isinstance(v, typing._GenericAlias) # type: ignore
309
+ }
310
+ if alias_name not in types:
311
+ return alias_name
312
+ return types[alias_name]
313
+
314
+
298
315
  def _get_model(module_name: str, model_name: str):
299
- models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)}
316
+ models = {
317
+ k: v
318
+ for k, v in sys.modules[module_name].__dict__.items()
319
+ if isinstance(v, type)
320
+ }
300
321
  module_end = module_name.rsplit(".", 1)[0]
301
- module = sys.modules[module_end]
302
- models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)})
322
+ models.update({
323
+ k: v
324
+ for k, v in sys.modules[module_end].__dict__.items()
325
+ if isinstance(v, type)
326
+ })
303
327
  if isinstance(model_name, str):
304
328
  model_name = model_name.split(".")[-1]
305
329
  if model_name not in models:
@@ -461,7 +485,7 @@ class Model(_MyMutableMapping):
461
485
  raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'")
462
486
  dict_to_pass.update(
463
487
  {
464
- self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format)
488
+ self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v)
465
489
  for k, v in kwargs.items()
466
490
  if v is not None
467
491
  }
@@ -499,24 +523,60 @@ class Model(_MyMutableMapping):
499
523
  base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member
500
524
 
501
525
  @classmethod
502
- def _get_discriminator(cls) -> typing.Optional[str]:
526
+ def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]:
503
527
  for v in cls.__dict__.values():
504
- if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access
528
+ if isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators: # pylint: disable=protected-access
505
529
  return v._rest_name # pylint: disable=protected-access
506
530
  return None
507
531
 
508
532
  @classmethod
509
- def _deserialize(cls, data):
533
+ def _deserialize(cls, data, exist_discriminators):
510
534
  if not hasattr(cls, "__mapping__"): # pylint: disable=no-member
511
535
  return cls(data)
512
- discriminator = cls._get_discriminator()
513
- mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member
536
+ discriminator = cls._get_discriminator(exist_discriminators)
537
+ exist_discriminators.append(discriminator)
538
+ mapped_cls = cls.__mapping__.get(
539
+ data.get(discriminator), cls
540
+ ) # pylint: disable=no-member
514
541
  if mapped_cls == cls:
515
542
  return cls(data)
516
- return mapped_cls._deserialize(data) # pylint: disable=protected-access
517
-
518
-
519
- def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements
543
+ return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access
544
+
545
+ def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]:
546
+ """Return a dict that can be JSONify using json.dump.
547
+
548
+ :keyword bool exclude_readonly: Whether to remove the readonly properties.
549
+ :returns: A dict JSON compatible object
550
+ :rtype: dict
551
+ """
552
+
553
+ result = {}
554
+ if exclude_readonly:
555
+ readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)]
556
+ for k, v in self.items():
557
+ if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false
558
+ continue
559
+ result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly)
560
+ return result
561
+
562
+ @staticmethod
563
+ def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any:
564
+ if v is None or isinstance(v, _Null):
565
+ return None
566
+ if isinstance(v, (list, tuple, set)):
567
+ return [
568
+ Model._as_dict_value(x, exclude_readonly=exclude_readonly)
569
+ for x in v
570
+ ]
571
+ if isinstance(v, dict):
572
+ return {
573
+ dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly)
574
+ for dk, dv in v.items()
575
+ }
576
+ return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v
577
+
578
+
579
+ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912
520
580
  annotation: typing.Any,
521
581
  module: typing.Optional[str],
522
582
  rf: typing.Optional["_RestField"] = None,
@@ -524,8 +584,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur
524
584
  if not annotation or annotation in [int, float]:
525
585
  return None
526
586
 
587
+ # is it a type alias?
588
+ if isinstance(annotation, str):
589
+ if module is not None:
590
+ annotation = _get_type_alias_type(module, annotation)
591
+
592
+ # is it a forward ref / in quotes?
593
+ if isinstance(annotation, (str, typing.ForwardRef)):
594
+ try:
595
+ model_name = annotation.__forward_arg__ # type: ignore
596
+ except AttributeError:
597
+ model_name = annotation
598
+ if module is not None:
599
+ annotation = _get_model(module, model_name)
600
+
527
601
  try:
528
- if module and _is_model(_get_model(module, annotation)):
602
+ if module and _is_model(annotation):
529
603
  if rf:
530
604
  rf._is_model = True
531
605
 
@@ -534,7 +608,7 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur
534
608
  return obj
535
609
  return _deserialize(model_deserializer, obj)
536
610
 
537
- return functools.partial(_deserialize_model, _get_model(module, annotation))
611
+ return functools.partial(_deserialize_model, annotation)
538
612
  except Exception:
539
613
  pass
540
614
 
@@ -552,22 +626,8 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur
552
626
  except AttributeError:
553
627
  pass
554
628
 
555
- if getattr(annotation, "__origin__", None) is typing.Union:
556
-
557
- def _deserialize_with_union(union_annotation, obj):
558
- for t in union_annotation.__args__:
559
- try:
560
- return _deserialize(t, obj, module, rf)
561
- except DeserializationError:
562
- pass
563
- raise DeserializationError()
564
-
565
- return functools.partial(_deserialize_with_union, annotation)
566
-
567
629
  # is it optional?
568
630
  try:
569
- # right now, assuming we don't have unions, since we're getting rid of the only
570
- # union we used to have in msrest models, which was union of str and enum
571
631
  if any(a for a in annotation.__args__ if a == type(None)):
572
632
  if_obj_deserializer = _get_deserialize_callable_from_annotation(
573
633
  next(a for a in annotation.__args__ if a != type(None)), module, rf
@@ -582,14 +642,18 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur
582
642
  except AttributeError:
583
643
  pass
584
644
 
585
- # is it a forward ref / in quotes?
586
- if isinstance(annotation, (str, typing.ForwardRef)):
587
- try:
588
- model_name = annotation.__forward_arg__ # type: ignore
589
- except AttributeError:
590
- model_name = annotation
591
- if module is not None:
592
- annotation = _get_model(module, model_name)
645
+ if getattr(annotation, "__origin__", None) is typing.Union:
646
+ deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__]
647
+
648
+ def _deserialize_with_union(deserializers, obj):
649
+ for deserializer in deserializers:
650
+ try:
651
+ return _deserialize(deserializer, obj)
652
+ except DeserializationError:
653
+ pass
654
+ raise DeserializationError()
655
+
656
+ return functools.partial(_deserialize_with_union, deserializers)
593
657
 
594
658
  try:
595
659
  if annotation._name == "Dict":
@@ -680,7 +744,7 @@ def _deserialize_with_callable(
680
744
  # for unknown value, return raw value
681
745
  return value
682
746
  if isinstance(deserializer, type) and issubclass(deserializer, Model):
683
- return deserializer._deserialize(value)
747
+ return deserializer._deserialize(value, [])
684
748
  return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value)
685
749
  except Exception as e:
686
750
  raise DeserializationError() from e
@@ -730,6 +794,8 @@ class _RestField:
730
794
  item = obj.get(self._rest_name)
731
795
  if item is None:
732
796
  return item
797
+ if self._is_model:
798
+ return item
733
799
  return _deserialize(self._type, _serialize(item, self._format), rf=self)
734
800
 
735
801
  def __set__(self, obj: Model, value) -> None:
@@ -740,8 +806,11 @@ class _RestField:
740
806
  except KeyError:
741
807
  pass
742
808
  return
743
- if self._is_model and not _is_model(value):
744
- obj.__setitem__(self._rest_name, _deserialize(self._type, value))
809
+ if self._is_model:
810
+ if not _is_model(value):
811
+ value = _deserialize(self._type, value)
812
+ obj.__setitem__(self._rest_name, value)
813
+ return
745
814
  obj.__setitem__(self._rest_name, _serialize(value, self._format))
746
815
 
747
816
  def _get_deserialize_callable_from_annotation(
@@ -1918,7 +1918,7 @@ class Deserializer(object):
1918
1918
  if re.search(r"[^\W\d_]", attr, re.I + re.U): # type: ignore
1919
1919
  raise DeserializationError("Date must have only digits and -. Received: %s" % attr)
1920
1920
  # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception.
1921
- return isodate.parse_date(attr, defaultmonth=None, defaultday=None)
1921
+ return isodate.parse_date(attr, defaultmonth=0, defaultday=0)
1922
1922
 
1923
1923
  @staticmethod
1924
1924
  def deserialize_time(attr):
@@ -82,15 +82,15 @@ def main() -> None:
82
82
  "AUTOREST_PYTHON_ATTACH_VSCODE_DEBUG", False
83
83
  ):
84
84
  try:
85
- import ptvsd # pylint: disable=import-outside-toplevel
85
+ import debugpy # pylint: disable=import-outside-toplevel
86
86
  except ImportError as exc:
87
87
  raise SystemExit(
88
88
  "Please pip install ptvsd in order to use VSCode debugging"
89
89
  ) from exc
90
90
 
91
91
  # 5678 is the default attach port in the VS Code debug configurations
92
- ptvsd.enable_attach(address=("localhost", 5678), redirect_output=True)
93
- ptvsd.wait_for_attach()
92
+ debugpy.listen(("localhost", 5678))
93
+ debugpy.wait_for_client()
94
94
  breakpoint() # pylint: disable=undefined-variable,forgotten-debug-statement
95
95
 
96
96
  _LOGGER.debug("Starting JSON RPC server")
@@ -1020,9 +1020,7 @@ class M4Reformatter(
1020
1020
  retval = {
1021
1021
  "type": OAUTH_TYPE,
1022
1022
  "policy": {
1023
- "type": "ARMChallengeAuthenticationPolicy"
1024
- if self.azure_arm
1025
- else "BearerTokenCredentialPolicy",
1023
+ "type": "BearerTokenCredentialPolicy",
1026
1024
  "credentialScopes": credential_scopes,
1027
1025
  },
1028
1026
  }
@@ -1068,11 +1066,7 @@ class M4Reformatter(
1068
1066
  return credential_scopes
1069
1067
 
1070
1068
  def update_credential_from_flags(self) -> Dict[str, Any]:
1071
- default_auth_policy = (
1072
- "ARMChallengeAuthenticationPolicy"
1073
- if self.azure_arm
1074
- else "BearerTokenCredentialPolicy"
1075
- )
1069
+ default_auth_policy = "BearerTokenCredentialPolicy"
1076
1070
  auth_policy = (
1077
1071
  self._autorestapi.get_value("credential-default-policy-type")
1078
1072
  or default_auth_policy
@@ -201,6 +201,10 @@ HEADERS_CONVERT_IN_METHOD = {
201
201
  }
202
202
 
203
203
 
204
+ def get_wire_name_lower(parameter: Dict[str, Any]) -> str:
205
+ return (parameter.get("wireName") or "").lower()
206
+
207
+
204
208
  def headers_convert(yaml_data: Dict[str, Any], replace_data: Any) -> None:
205
209
  if isinstance(replace_data, dict):
206
210
  for k, v in replace_data.items():
@@ -210,6 +214,10 @@ def headers_convert(yaml_data: Dict[str, Any], replace_data: Any) -> None:
210
214
  class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
211
215
  """Add Python naming information."""
212
216
 
217
+ @property
218
+ def azure_arm(self) -> bool:
219
+ return self.options.get("azure-arm", False)
220
+
213
221
  @property
214
222
  def version_tolerant(self) -> bool:
215
223
  return self.options.get("version-tolerant", True)
@@ -285,25 +293,59 @@ class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
285
293
  yaml_data["legacyFilename"] = to_snake_case(yaml_data["name"].replace(" ", "_"))
286
294
  for parameter in yaml_data["parameters"]:
287
295
  self.update_parameter(parameter)
296
+ if parameter["clientName"] == "credential":
297
+ policy = parameter["type"].get("policy")
298
+ if (
299
+ policy
300
+ and policy["type"] == "BearerTokenCredentialPolicy"
301
+ and self.azure_arm
302
+ ):
303
+ policy["type"] = "ARMChallengeAuthenticationPolicy"
304
+ policy["credentialScopes"] = [
305
+ "https://management.azure.com/.default"
306
+ ]
307
+
288
308
  prop_name = yaml_data["name"]
289
309
  if prop_name.endswith("Client"):
290
310
  prop_name = prop_name[: len(prop_name) - len("Client")]
291
311
  yaml_data["builderPadName"] = to_snake_case(prop_name)
292
312
  for og in yaml_data["operationGroups"]:
293
313
  for o in og["operations"]:
314
+ property_if_match = None
315
+ property_if_none_match = None
294
316
  for p in o["parameters"]:
317
+ wire_name_lower = get_wire_name_lower(p)
295
318
  if (
296
319
  p["location"] == "header"
297
- and p["wireName"] == "client-request-id"
298
- ):
299
- yaml_data["requestIdHeaderName"] = p["wireName"]
300
- if (
301
- self.version_tolerant
302
- and p["location"] == "header"
303
- and p["clientName"] in ("if_match", "if_none_match")
320
+ and wire_name_lower == "client-request-id"
304
321
  ):
305
- o["hasEtag"] = True
306
- yaml_data["hasEtag"] = True
322
+ yaml_data["requestIdHeaderName"] = wire_name_lower
323
+ if self.version_tolerant and p["location"] == "header":
324
+ if wire_name_lower == "if-match":
325
+ property_if_match = p
326
+ elif wire_name_lower == "if-none-match":
327
+ property_if_none_match = p
328
+ # pylint: disable=line-too-long
329
+ # some service(e.g. https://github.com/Azure/azure-rest-api-specs/blob/main/specification/cosmos-db/data-plane/Microsoft.Tables/preview/2019-02-02/table.json)
330
+ # only has one, so we need to add "if-none-match" or "if-match" if it's missing
331
+ if not property_if_match and property_if_none_match:
332
+ property_if_match = property_if_none_match.copy()
333
+ property_if_match["wireName"] = "if-match"
334
+ if not property_if_none_match and property_if_match:
335
+ property_if_none_match = property_if_match.copy()
336
+ property_if_none_match["wireName"] = "if-none-match"
337
+
338
+ if property_if_match and property_if_none_match:
339
+ # arrange if-match and if-none-match to the end of parameters
340
+ o["parameters"] = [
341
+ item
342
+ for item in o["parameters"]
343
+ if get_wire_name_lower(item)
344
+ not in ("if-match", "if-none-match")
345
+ ] + [property_if_match, property_if_none_match]
346
+
347
+ o["hasEtag"] = True
348
+ yaml_data["hasEtag"] = True
307
349
 
308
350
  def get_operation_updater(
309
351
  self, yaml_data: Dict[str, Any]
@@ -380,7 +422,6 @@ class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
380
422
  add_overloads_for_body_param(yaml_data)
381
423
 
382
424
  def _update_lro_operation_helper(self, yaml_data: Dict[str, Any]) -> None:
383
- azure_arm = self.options.get("azure-arm", False)
384
425
  for response in yaml_data.get("responses", []):
385
426
  response["discriminator"] = "lro"
386
427
  response["pollerSync"] = (
@@ -392,13 +433,13 @@ class PreProcessPlugin(YamlUpdatePlugin): # pylint: disable=abstract-method
392
433
  if not response.get("pollingMethodSync"):
393
434
  response["pollingMethodSync"] = (
394
435
  "azure.mgmt.core.polling.arm_polling.ARMPolling"
395
- if azure_arm
436
+ if self.azure_arm
396
437
  else "azure.core.polling.base_polling.LROBasePolling"
397
438
  )
398
439
  if not response.get("pollingMethodAsync"):
399
440
  response["pollingMethodAsync"] = (
400
441
  "azure.mgmt.core.polling.async_arm_polling.AsyncARMPolling"
401
- if azure_arm
442
+ if self.azure_arm
402
443
  else "azure.core.polling.async_base_polling.AsyncLROBasePolling"
403
444
  )
404
445
 
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@autorest/python",
3
- "version": "6.7.3",
3
+ "version": "6.7.5",
4
4
  "description": "The Python extension for generators in AutoRest.",
5
5
  "main": "index.js",
6
6
  "repository": {