@azure-tools/typespec-python 0.29.0 → 0.31.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. package/dist/scripts/eng/format.d.ts +2 -0
  2. package/dist/scripts/eng/format.d.ts.map +1 -0
  3. package/dist/scripts/eng/format.js +4 -0
  4. package/dist/scripts/eng/format.js.map +1 -0
  5. package/dist/scripts/eng/lint.d.ts +5 -0
  6. package/dist/scripts/eng/lint.d.ts.map +1 -0
  7. package/dist/scripts/eng/lint.js +66 -0
  8. package/dist/scripts/eng/lint.js.map +1 -0
  9. package/dist/scripts/{regenerate.d.ts.map → eng/regenerate.d.ts.map} +1 -1
  10. package/dist/scripts/{regenerate.js → eng/regenerate.js} +27 -29
  11. package/dist/scripts/eng/regenerate.js.map +1 -0
  12. package/dist/scripts/{run-tests.d.ts.map → eng/run-tests.d.ts.map} +1 -1
  13. package/dist/scripts/eng/run-tests.js +80 -0
  14. package/dist/scripts/eng/run-tests.js.map +1 -0
  15. package/dist/scripts/eng/utils.d.ts +3 -0
  16. package/dist/scripts/eng/utils.d.ts.map +1 -0
  17. package/dist/scripts/eng/utils.js +39 -0
  18. package/dist/scripts/eng/utils.js.map +1 -0
  19. package/dist/scripts/system-requirements.d.ts.map +1 -1
  20. package/dist/scripts/system-requirements.js.map +1 -1
  21. package/dist/src/code-model.d.ts.map +1 -1
  22. package/dist/src/code-model.js +4 -1
  23. package/dist/src/code-model.js.map +1 -1
  24. package/dist/src/emitter.d.ts.map +1 -1
  25. package/dist/src/emitter.js +17 -4
  26. package/dist/src/emitter.js.map +1 -1
  27. package/dist/src/external-process.d.ts +0 -1
  28. package/dist/src/external-process.d.ts.map +1 -1
  29. package/dist/src/http.js +14 -8
  30. package/dist/src/http.js.map +1 -1
  31. package/dist/src/lib.d.ts +1 -0
  32. package/dist/src/lib.d.ts.map +1 -1
  33. package/dist/src/lib.js +1 -0
  34. package/dist/src/lib.js.map +1 -1
  35. package/dist/src/types.d.ts +1 -1
  36. package/dist/src/types.d.ts.map +1 -1
  37. package/dist/src/types.js +82 -9
  38. package/dist/src/types.js.map +1 -1
  39. package/dist/src/utils.d.ts +2 -2
  40. package/dist/src/utils.d.ts.map +1 -1
  41. package/dist/src/utils.js +65 -10
  42. package/dist/src/utils.js.map +1 -1
  43. package/generator/pygen/black.py +2 -3
  44. package/generator/pygen/codegen/models/combined_type.py +1 -1
  45. package/generator/pygen/codegen/models/credential_types.py +7 -14
  46. package/generator/pygen/codegen/models/enum_type.py +1 -1
  47. package/generator/pygen/codegen/models/lro_operation.py +0 -1
  48. package/generator/pygen/codegen/models/lro_paging_operation.py +1 -1
  49. package/generator/pygen/codegen/models/model_type.py +5 -7
  50. package/generator/pygen/codegen/models/operation.py +13 -2
  51. package/generator/pygen/codegen/models/paging_operation.py +0 -1
  52. package/generator/pygen/codegen/models/parameter.py +5 -1
  53. package/generator/pygen/codegen/models/parameter_list.py +2 -5
  54. package/generator/pygen/codegen/models/primitive_types.py +11 -4
  55. package/generator/pygen/codegen/models/property.py +5 -1
  56. package/generator/pygen/codegen/serializers/__init__.py +1 -1
  57. package/generator/pygen/codegen/serializers/builder_serializer.py +22 -20
  58. package/generator/pygen/codegen/serializers/general_serializer.py +2 -1
  59. package/generator/pygen/codegen/serializers/model_serializer.py +3 -0
  60. package/generator/pygen/codegen/serializers/sample_serializer.py +1 -3
  61. package/generator/pygen/codegen/serializers/test_serializer.py +6 -0
  62. package/generator/pygen/codegen/templates/model_base.py.jinja2 +319 -67
  63. package/generator/pygen/codegen/templates/model_dpg.py.jinja2 +5 -0
  64. package/generator/pygen/codegen/templates/serialization.py.jinja2 +271 -162
  65. package/generator/pygen/codegen/templates/test.py.jinja2 +2 -2
  66. package/generator/pygen/m2r.py +1 -1
  67. package/generator/pygen/postprocess/__init__.py +2 -2
  68. package/generator/pygen/postprocess/venvtools.py +1 -3
  69. package/generator/pygen/preprocess/__init__.py +1 -1
  70. package/generator/pygen/utils.py +10 -3
  71. package/generator/setup.py +1 -1
  72. package/package.json +13 -10
  73. package/scripts/__pycache__/venvtools.cpython-310.pyc +0 -0
  74. package/scripts/eng/format.ts +5 -0
  75. package/scripts/eng/lint.ts +75 -0
  76. package/scripts/eng/mypy.ini +38 -0
  77. package/scripts/eng/pylintrc +58 -0
  78. package/scripts/eng/pyrightconfig.json +6 -0
  79. package/scripts/{regenerate.ts → eng/regenerate.ts} +40 -33
  80. package/scripts/eng/run-tests.ts +90 -0
  81. package/scripts/eng/utils.ts +38 -0
  82. package/scripts/run_tsp.py +7 -5
  83. package/dist/scripts/regenerate.js.map +0 -1
  84. package/dist/scripts/run-tests.js +0 -50
  85. package/dist/scripts/run-tests.js.map +0 -1
  86. package/scripts/run-tests.ts +0 -57
  87. /package/dist/scripts/{regenerate.d.ts → eng/regenerate.d.ts} +0 -0
  88. /package/dist/scripts/{run-tests.d.ts → eng/run-tests.d.ts} +0 -0
@@ -4,7 +4,7 @@
4
4
  # Licensed under the MIT License. See License.txt in the project root for
5
5
  # license information.
6
6
  # --------------------------------------------------------------------------
7
- # pylint: disable=protected-access, arguments-differ, signature-differs, broad-except
7
+ # pylint: disable=protected-access, arguments-differ, signature-differs, broad-except, too-many-lines
8
8
 
9
9
  import copy
10
10
  import calendar
@@ -19,6 +19,7 @@ import enum
19
19
  import email.utils
20
20
  from datetime import datetime, date, time, timedelta, timezone
21
21
  from json import JSONEncoder
22
+ import xml.etree.ElementTree as ET
22
23
  from typing_extensions import Self
23
24
  import isodate
24
25
  from {{ code_model.core_library }}.exceptions import DeserializationError
@@ -123,7 +124,7 @@ def _serialize_datetime(o, format: typing.Optional[str] = None):
123
124
 
124
125
  def _is_readonly(p):
125
126
  try:
126
- return p._visibility == ["read"] # pylint: disable=protected-access
127
+ return p._visibility == ["read"]
127
128
  except AttributeError:
128
129
  return False
129
130
 
@@ -286,6 +287,12 @@ def _deserialize_decimal(attr):
286
287
  return decimal.Decimal(str(attr))
287
288
 
288
289
 
290
+ def _deserialize_int_as_str(attr):
291
+ if isinstance(attr, int):
292
+ return attr
293
+ return int(attr)
294
+
295
+
289
296
  _DESERIALIZE_MAPPING = {
290
297
  datetime: _deserialize_datetime,
291
298
  date: _deserialize_date,
@@ -307,9 +314,11 @@ _DESERIALIZE_MAPPING_WITHFORMAT = {
307
314
 
308
315
 
309
316
  def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = None):
317
+ if annotation is int and rf and rf._format == "str":
318
+ return _deserialize_int_as_str
310
319
  if rf and rf._format:
311
320
  return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format)
312
- return _DESERIALIZE_MAPPING.get(annotation)
321
+ return _DESERIALIZE_MAPPING.get(annotation) # pyright: ignore
313
322
 
314
323
 
315
324
  def _get_type_alias_type(module_name: str, alias_name: str):
@@ -324,17 +333,9 @@ def _get_type_alias_type(module_name: str, alias_name: str):
324
333
 
325
334
 
326
335
  def _get_model(module_name: str, model_name: str):
327
- models = {
328
- k: v
329
- for k, v in sys.modules[module_name].__dict__.items()
330
- if isinstance(v, type)
331
- }
336
+ models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)}
332
337
  module_end = module_name.rsplit(".", 1)[0]
333
- models.update({
334
- k: v
335
- for k, v in sys.modules[module_end].__dict__.items()
336
- if isinstance(v, type)
337
- })
338
+ models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)})
338
339
  if isinstance(model_name, str):
339
340
  model_name = model_name.split(".")[-1]
340
341
  if model_name not in models:
@@ -386,16 +387,13 @@ class _MyMutableMapping(MutableMapping[str, typing.Any]): # pylint: disable=uns
386
387
  return default
387
388
 
388
389
  @typing.overload
389
- def pop(self, key: str) -> typing.Any:
390
- ...
390
+ def pop(self, key: str) -> typing.Any: ...
391
391
 
392
392
  @typing.overload
393
- def pop(self, key: str, default: _T) -> _T:
394
- ...
393
+ def pop(self, key: str, default: _T) -> _T: ...
395
394
 
396
395
  @typing.overload
397
- def pop(self, key: str, default: typing.Any) -> typing.Any:
398
- ...
396
+ def pop(self, key: str, default: typing.Any) -> typing.Any: ...
399
397
 
400
398
  def pop(self, key: str, default: typing.Any = _UNSET) -> typing.Any:
401
399
  if default is _UNSET:
@@ -412,12 +410,10 @@ class _MyMutableMapping(MutableMapping[str, typing.Any]): # pylint: disable=uns
412
410
  self._data.update(*args, **kwargs)
413
411
 
414
412
  @typing.overload
415
- def setdefault(self, key: str, default: None = None) -> None:
416
- ...
413
+ def setdefault(self, key: str, default: None = None) -> None: ...
417
414
 
418
415
  @typing.overload
419
- def setdefault(self, key: str, default: typing.Any) -> typing.Any:
420
- ...
416
+ def setdefault(self, key: str, default: typing.Any) -> typing.Any: ...
421
417
 
422
418
  def setdefault(self, key: str, default: typing.Any = _UNSET) -> typing.Any:
423
419
  if default is _UNSET:
@@ -454,6 +450,10 @@ def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-m
454
450
  return float(o)
455
451
  if isinstance(o, enum.Enum):
456
452
  return o.value
453
+ if isinstance(o, int):
454
+ if format == "str":
455
+ return str(o)
456
+ return o
457
457
  try:
458
458
  # First try datetime.datetime
459
459
  return _serialize_datetime(o, format)
@@ -484,6 +484,8 @@ def _create_value(rf: typing.Optional["_RestField"], value: typing.Any) -> typin
484
484
  return value
485
485
  if rf._is_model:
486
486
  return _deserialize(rf._type, value)
487
+ if isinstance(value, ET.Element):
488
+ value = _deserialize(rf._type, value)
487
489
  return _serialize(value, rf._format)
488
490
 
489
491
 
@@ -502,10 +504,58 @@ class Model(_MyMutableMapping):
502
504
  for rest_field in self._attr_to_rest_field.values()
503
505
  if rest_field._default is not _UNSET
504
506
  }
505
- if args:
506
- dict_to_pass.update(
507
- {k: _create_value(_get_rest_field(self._attr_to_rest_field, k), v) for k, v in args[0].items()}
508
- )
507
+ if args: # pylint: disable=too-many-nested-blocks
508
+ if isinstance(args[0], ET.Element):
509
+ existed_attr_keys = []
510
+ model_meta = getattr(self, "_xml", {})
511
+
512
+ for rf in self._attr_to_rest_field.values():
513
+ prop_meta = getattr(rf, "_xml", {})
514
+ xml_name = prop_meta.get("name", rf._rest_name)
515
+ xml_ns = prop_meta.get("ns", model_meta.get("ns", None))
516
+ if xml_ns:
517
+ xml_name = "{" + xml_ns + "}" + xml_name
518
+
519
+ # attribute
520
+ if prop_meta.get("attribute", False) and args[0].get(xml_name) is not None:
521
+ existed_attr_keys.append(xml_name)
522
+ dict_to_pass[rf._rest_name] = _deserialize(rf._type, args[0].get(xml_name))
523
+ continue
524
+
525
+ # unwrapped element is array
526
+ if prop_meta.get("unwrapped", False):
527
+ # unwrapped array could either use prop items meta/prop meta
528
+ if prop_meta.get("itemsName"):
529
+ xml_name = prop_meta.get("itemsName")
530
+ xml_ns = prop_meta.get("itemNs")
531
+ if xml_ns:
532
+ xml_name = "{" + xml_ns + "}" + xml_name
533
+ items = args[0].findall(xml_name) # pyright: ignore
534
+ if len(items) > 0:
535
+ existed_attr_keys.append(xml_name)
536
+ dict_to_pass[rf._rest_name] = _deserialize(rf._type, items)
537
+ continue
538
+
539
+ # text element is primitive type
540
+ if prop_meta.get("text", False):
541
+ if args[0].text is not None:
542
+ dict_to_pass[rf._rest_name] = _deserialize(rf._type, args[0].text)
543
+ continue
544
+
545
+ # wrapped element could be normal property or array, it should only have one element
546
+ item = args[0].find(xml_name)
547
+ if item is not None:
548
+ existed_attr_keys.append(xml_name)
549
+ dict_to_pass[rf._rest_name] = _deserialize(rf._type, item)
550
+
551
+ # rest thing is additional properties
552
+ for e in args[0]:
553
+ if e.tag not in existed_attr_keys:
554
+ dict_to_pass[e.tag] = _convert_element(e)
555
+ else:
556
+ dict_to_pass.update(
557
+ {k: _create_value(_get_rest_field(self._attr_to_rest_field, k), v) for k, v in args[0].items()}
558
+ )
509
559
  else:
510
560
  non_attr_kwargs = [k for k in kwargs if k not in self._attr_to_rest_field]
511
561
  if non_attr_kwargs:
@@ -525,7 +575,7 @@ class Model(_MyMutableMapping):
525
575
 
526
576
  def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self: # pylint: disable=unused-argument
527
577
  if f"{cls.__module__}.{cls.__qualname__}" not in cls._calculated:
528
- # we know the last nine classes in mro are going to be 'Model', '_MyMutableMapping', 'MutableMapping',
578
+ # we know the last nine classes in mro are going to be 'Model', '_MyMutableMapping', 'MutableMapping',
529
579
  # 'Mapping', 'Collection', 'Sized', 'Iterable', 'Container' and 'object'
530
580
  mros = cls.__mro__[:-9][::-1] # ignore parents, and reverse the mro order
531
581
  attr_to_rest_field: typing.Dict[str, _RestField] = { # map attribute name to rest_field property
@@ -554,10 +604,10 @@ class Model(_MyMutableMapping):
554
604
  base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member
555
605
 
556
606
  @classmethod
557
- def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]:
607
+ def _get_discriminator(cls, exist_discriminators) -> typing.Optional["_RestField"]:
558
608
  for v in cls.__dict__.values():
559
- if isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators: # pylint: disable=protected-access
560
- return v._rest_name # pylint: disable=protected-access
609
+ if isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators:
610
+ return v
561
611
  return None
562
612
 
563
613
  @classmethod
@@ -565,13 +615,25 @@ class Model(_MyMutableMapping):
565
615
  if not hasattr(cls, "__mapping__"): # pylint: disable=no-member
566
616
  return cls(data)
567
617
  discriminator = cls._get_discriminator(exist_discriminators)
568
- exist_discriminators.append(discriminator)
569
- mapped_cls = cls.__mapping__.get(
570
- data.get(discriminator), cls
571
- ) # pyright: ignore # pylint: disable=no-member
572
- if mapped_cls == cls:
618
+ if discriminator is None:
573
619
  return cls(data)
574
- return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access
620
+ exist_discriminators.append(discriminator._rest_name)
621
+ if isinstance(data, ET.Element):
622
+ model_meta = getattr(cls, "_xml", {})
623
+ prop_meta = getattr(discriminator, "_xml", {})
624
+ xml_name = prop_meta.get("name", discriminator._rest_name)
625
+ xml_ns = prop_meta.get("ns", model_meta.get("ns", None))
626
+ if xml_ns:
627
+ xml_name = "{" + xml_ns + "}" + xml_name
628
+
629
+ if data.get(xml_name) is not None:
630
+ discriminator_value = data.get(xml_name)
631
+ else:
632
+ discriminator_value = data.find(xml_name).text # pyright: ignore
633
+ else:
634
+ discriminator_value = data.get(discriminator._rest_name)
635
+ mapped_cls = cls.__mapping__.get(discriminator_value, cls) # pyright: ignore # pylint: disable=no-member
636
+ return mapped_cls._deserialize(data, exist_discriminators)
575
637
 
576
638
  def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]:
577
639
  """Return a dict that can be JSONify using json.dump.
@@ -590,7 +652,9 @@ class Model(_MyMutableMapping):
590
652
  continue
591
653
  is_multipart_file_input = False
592
654
  try:
593
- is_multipart_file_input = next(rf for rf in self._attr_to_rest_field.values() if rf._rest_name == k)._is_multipart_file_input
655
+ is_multipart_file_input = next(
656
+ rf for rf in self._attr_to_rest_field.values() if rf._rest_name == k
657
+ )._is_multipart_file_input
594
658
  except StopIteration:
595
659
  pass
596
660
  result[k] = v if is_multipart_file_input else Model._as_dict_value(v, exclude_readonly=exclude_readonly)
@@ -601,27 +665,24 @@ class Model(_MyMutableMapping):
601
665
  if v is None or isinstance(v, _Null):
602
666
  return None
603
667
  if isinstance(v, (list, tuple, set)):
604
- return type(v)(
605
- Model._as_dict_value(x, exclude_readonly=exclude_readonly)
606
- for x in v
607
- )
668
+ return type(v)(Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v)
608
669
  if isinstance(v, dict):
609
- return {
610
- dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly)
611
- for dk, dv in v.items()
612
- }
670
+ return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()}
613
671
  return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v
614
672
 
673
+
615
674
  def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj):
616
675
  if _is_model(obj):
617
676
  return obj
618
677
  return _deserialize(model_deserializer, obj)
619
678
 
679
+
620
680
  def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Callable], obj):
621
681
  if obj is None:
622
682
  return obj
623
683
  return _deserialize_with_callable(if_obj_deserializer, obj)
624
684
 
685
+
625
686
  def _deserialize_with_union(deserializers, obj):
626
687
  for deserializer in deserializers:
627
688
  try:
@@ -630,6 +691,7 @@ def _deserialize_with_union(deserializers, obj):
630
691
  pass
631
692
  raise DeserializationError()
632
693
 
694
+
633
695
  def _deserialize_dict(
634
696
  value_deserializer: typing.Optional[typing.Callable],
635
697
  module: typing.Optional[str],
@@ -637,10 +699,10 @@ def _deserialize_dict(
637
699
  ):
638
700
  if obj is None:
639
701
  return obj
640
- return {
641
- k: _deserialize(value_deserializer, v, module)
642
- for k, v in obj.items()
643
- }
702
+ if isinstance(obj, ET.Element):
703
+ obj = {child.tag: child for child in obj}
704
+ return {k: _deserialize(value_deserializer, v, module) for k, v in obj.items()}
705
+
644
706
 
645
707
  def _deserialize_multiple_sequence(
646
708
  entry_deserializers: typing.List[typing.Optional[typing.Callable]],
@@ -649,10 +711,8 @@ def _deserialize_multiple_sequence(
649
711
  ):
650
712
  if obj is None:
651
713
  return obj
652
- return type(obj)(
653
- _deserialize(deserializer, entry, module)
654
- for entry, deserializer in zip(obj, entry_deserializers)
655
- )
714
+ return type(obj)(_deserialize(deserializer, entry, module) for entry, deserializer in zip(obj, entry_deserializers))
715
+
656
716
 
657
717
  def _deserialize_sequence(
658
718
  deserializer: typing.Optional[typing.Callable],
@@ -661,20 +721,24 @@ def _deserialize_sequence(
661
721
  ):
662
722
  if obj is None:
663
723
  return obj
724
+ if isinstance(obj, ET.Element):
725
+ obj = list(obj)
664
726
  return type(obj)(_deserialize(deserializer, entry, module) for entry in obj)
665
727
 
728
+
666
729
  def _sorted_annotations(types: typing.List[typing.Any]) -> typing.List[typing.Any]:
667
730
  return sorted(
668
731
  types,
669
732
  key=lambda x: hasattr(x, "__name__") and x.__name__.lower() in ("str", "float", "int", "bool"),
670
733
  )
671
734
 
735
+
672
736
  def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912
673
737
  annotation: typing.Any,
674
738
  module: typing.Optional[str],
675
739
  rf: typing.Optional["_RestField"] = None,
676
740
  ) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]:
677
- if not annotation or annotation in [int, float]:
741
+ if not annotation:
678
742
  return None
679
743
 
680
744
  # is it a type alias?
@@ -710,15 +774,15 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
710
774
  # is it optional?
711
775
  try:
712
776
  if any(a for a in annotation.__args__ if a == type(None)): # pyright: ignore
713
- if len(annotation.__args__) <= 2: # pyright: ignore
777
+ if len(annotation.__args__) <= 2: # pyright: ignore
714
778
  if_obj_deserializer = _get_deserialize_callable_from_annotation(
715
- next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore
779
+ next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore
716
780
  )
717
781
 
718
782
  return functools.partial(_deserialize_with_optional, if_obj_deserializer)
719
783
  # the type is Optional[Union[...]], we need to remove the None type from the Union
720
784
  annotation_copy = copy.copy(annotation)
721
- annotation_copy.__args__ = [a for a in annotation_copy.__args__ if a != type(None)] # pyright: ignore
785
+ annotation_copy.__args__ = [a for a in annotation_copy.__args__ if a != type(None)] # pyright: ignore
722
786
  return _get_deserialize_callable_from_annotation(annotation_copy, module, rf)
723
787
  except AttributeError:
724
788
  pass
@@ -739,7 +803,6 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
739
803
  annotation.__args__[1], module, rf # pyright: ignore
740
804
  )
741
805
 
742
-
743
806
  return functools.partial(
744
807
  _deserialize_dict,
745
808
  value_deserializer,
@@ -750,18 +813,15 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
750
813
  try:
751
814
  if annotation._name in ["List", "Set", "Tuple", "Sequence"]: # pyright: ignore
752
815
  if len(annotation.__args__) > 1: # pyright: ignore
753
-
754
-
755
816
  entry_deserializers = [
756
- _get_deserialize_callable_from_annotation(dt, module, rf) for dt in annotation.__args__ # pyright: ignore
817
+ _get_deserialize_callable_from_annotation(dt, module, rf)
818
+ for dt in annotation.__args__ # pyright: ignore
757
819
  ]
758
820
  return functools.partial(_deserialize_multiple_sequence, entry_deserializers, module)
759
821
  deserializer = _get_deserialize_callable_from_annotation(
760
822
  annotation.__args__[0], module, rf # pyright: ignore
761
823
  )
762
824
 
763
-
764
-
765
825
  return functools.partial(_deserialize_sequence, deserializer, module)
766
826
  except (TypeError, IndexError, AttributeError, SyntaxError):
767
827
  pass
@@ -787,12 +847,23 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
787
847
  def _deserialize_with_callable(
788
848
  deserializer: typing.Optional[typing.Callable[[typing.Any], typing.Any]],
789
849
  value: typing.Any,
790
- ):
850
+ ): # pylint: disable=too-many-return-statements
791
851
  try:
792
852
  if value is None or isinstance(value, _Null):
793
853
  return None
854
+ if isinstance(value, ET.Element):
855
+ if deserializer is str:
856
+ return value.text or ""
857
+ if deserializer is int:
858
+ return int(value.text) if value.text else None
859
+ if deserializer is float:
860
+ return float(value.text) if value.text else None
861
+ if deserializer is bool:
862
+ return value.text == "true" if value.text else None
794
863
  if deserializer is None:
795
864
  return value
865
+ if deserializer in [int, float, bool]:
866
+ return deserializer(value)
796
867
  if isinstance(deserializer, CaseInsensitiveEnumMeta):
797
868
  try:
798
869
  return deserializer(value)
@@ -833,6 +904,7 @@ class _RestField:
833
904
  default: typing.Any = _UNSET,
834
905
  format: typing.Optional[str] = None,
835
906
  is_multipart_file_input: bool = False,
907
+ xml: typing.Optional[typing.Dict[str, typing.Any]] = None,
836
908
  ):
837
909
  self._type = type
838
910
  self._rest_name_input = name
@@ -843,6 +915,7 @@ class _RestField:
843
915
  self._default = default
844
916
  self._format = format
845
917
  self._is_multipart_file_input = is_multipart_file_input
918
+ self._xml = xml if xml is not None else {}
846
919
 
847
920
  @property
848
921
  def _class_type(self) -> typing.Any:
@@ -893,8 +966,17 @@ def rest_field(
893
966
  default: typing.Any = _UNSET,
894
967
  format: typing.Optional[str] = None,
895
968
  is_multipart_file_input: bool = False,
969
+ xml: typing.Optional[typing.Dict[str, typing.Any]] = None,
896
970
  ) -> typing.Any:
897
- return _RestField(name=name, type=type, visibility=visibility, default=default, format=format, is_multipart_file_input=is_multipart_file_input)
971
+ return _RestField(
972
+ name=name,
973
+ type=type,
974
+ visibility=visibility,
975
+ default=default,
976
+ format=format,
977
+ is_multipart_file_input=is_multipart_file_input,
978
+ xml=xml,
979
+ )
898
980
 
899
981
 
900
982
  def rest_discriminator(
@@ -902,5 +984,175 @@ def rest_discriminator(
902
984
  name: typing.Optional[str] = None,
903
985
  type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin
904
986
  visibility: typing.Optional[typing.List[str]] = None,
987
+ xml: typing.Optional[typing.Dict[str, typing.Any]] = None,
905
988
  ) -> typing.Any:
906
- return _RestField(name=name, type=type, is_discriminator=True, visibility=visibility)
989
+ return _RestField(name=name, type=type, is_discriminator=True, visibility=visibility, xml=xml)
990
+
991
+
992
+ def serialize_xml(model: Model, exclude_readonly: bool = False) -> str:
993
+ """Serialize a model to XML.
994
+
995
+ :param Model model: The model to serialize.
996
+ :param bool exclude_readonly: Whether to exclude readonly properties.
997
+ :returns: The XML representation of the model.
998
+ :rtype: str
999
+ """
1000
+ return ET.tostring(_get_element(model, exclude_readonly), encoding="unicode") # type: ignore
1001
+
1002
+
1003
+ def _get_element(
1004
+ o: typing.Any,
1005
+ exclude_readonly: bool = False,
1006
+ parent_meta: typing.Optional[typing.Dict[str, typing.Any]] = None,
1007
+ wrapped_element: typing.Optional[ET.Element] = None,
1008
+ ) -> typing.Union[ET.Element, typing.List[ET.Element]]:
1009
+ if _is_model(o):
1010
+ model_meta = getattr(o, "_xml", {})
1011
+
1012
+ # if prop is a model, then use the prop element directly, else generate a wrapper of model
1013
+ if wrapped_element is None:
1014
+ wrapped_element = _create_xml_element(
1015
+ model_meta.get("name", o.__class__.__name__),
1016
+ model_meta.get("prefix"),
1017
+ model_meta.get("ns"),
1018
+ )
1019
+
1020
+ readonly_props = []
1021
+ if exclude_readonly:
1022
+ readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)]
1023
+
1024
+ for k, v in o.items():
1025
+ # do not serialize readonly properties
1026
+ if exclude_readonly and k in readonly_props:
1027
+ continue
1028
+
1029
+ prop_rest_field = _get_rest_field(o._attr_to_rest_field, k)
1030
+ if prop_rest_field:
1031
+ prop_meta = getattr(prop_rest_field, "_xml").copy()
1032
+ # use the wire name as xml name if no specific name is set
1033
+ if prop_meta.get("name") is None:
1034
+ prop_meta["name"] = k
1035
+ else:
1036
+ # additional properties will not have rest field, use the wire name as xml name
1037
+ prop_meta = {"name": k}
1038
+
1039
+ # if no ns for prop, use model's
1040
+ if prop_meta.get("ns") is None and model_meta.get("ns"):
1041
+ prop_meta["ns"] = model_meta.get("ns")
1042
+ prop_meta["prefix"] = model_meta.get("prefix")
1043
+
1044
+ if prop_meta.get("unwrapped", False):
1045
+ # unwrapped could only set on array
1046
+ wrapped_element.extend(_get_element(v, exclude_readonly, prop_meta))
1047
+ elif prop_meta.get("text", False):
1048
+ # text could only set on primitive type
1049
+ wrapped_element.text = _get_primitive_type_value(v)
1050
+ elif prop_meta.get("attribute", False):
1051
+ xml_name = prop_meta.get("name", k)
1052
+ if prop_meta.get("ns"):
1053
+ ET.register_namespace(prop_meta.get("prefix"), prop_meta.get("ns")) # pyright: ignore
1054
+ xml_name = "{" + prop_meta.get("ns") + "}" + xml_name # pyright: ignore
1055
+ # attribute should be primitive type
1056
+ wrapped_element.set(xml_name, _get_primitive_type_value(v))
1057
+ else:
1058
+ # other wrapped prop element
1059
+ wrapped_element.append(_get_wrapped_element(v, exclude_readonly, prop_meta))
1060
+ return wrapped_element
1061
+ if isinstance(o, list):
1062
+ return [_get_element(x, exclude_readonly, parent_meta) for x in o] # type: ignore
1063
+ if isinstance(o, dict):
1064
+ result = []
1065
+ for k, v in o.items():
1066
+ result.append(
1067
+ _get_wrapped_element(
1068
+ v,
1069
+ exclude_readonly,
1070
+ {
1071
+ "name": k,
1072
+ "ns": parent_meta.get("ns") if parent_meta else None,
1073
+ "prefix": parent_meta.get("prefix") if parent_meta else None,
1074
+ },
1075
+ )
1076
+ )
1077
+ return result
1078
+
1079
+ # primitive case need to create element based on parent_meta
1080
+ if parent_meta:
1081
+ return _get_wrapped_element(
1082
+ o,
1083
+ exclude_readonly,
1084
+ {
1085
+ "name": parent_meta.get("itemsName", parent_meta.get("name")),
1086
+ "prefix": parent_meta.get("itemsPrefix", parent_meta.get("prefix")),
1087
+ "ns": parent_meta.get("itemsNs", parent_meta.get("ns")),
1088
+ },
1089
+ )
1090
+
1091
+ raise ValueError("Could not serialize value into xml: " + o)
1092
+
1093
+
1094
+ def _get_wrapped_element(
1095
+ v: typing.Any,
1096
+ exclude_readonly: bool,
1097
+ meta: typing.Optional[typing.Dict[str, typing.Any]],
1098
+ ) -> ET.Element:
1099
+ wrapped_element = _create_xml_element(
1100
+ meta.get("name") if meta else None, meta.get("prefix") if meta else None, meta.get("ns") if meta else None
1101
+ )
1102
+ if isinstance(v, (dict, list)):
1103
+ wrapped_element.extend(_get_element(v, exclude_readonly, meta))
1104
+ elif _is_model(v):
1105
+ _get_element(v, exclude_readonly, meta, wrapped_element)
1106
+ else:
1107
+ wrapped_element.text = _get_primitive_type_value(v)
1108
+ return wrapped_element
1109
+
1110
+
1111
+ def _get_primitive_type_value(v) -> str:
1112
+ if v is True:
1113
+ return "true"
1114
+ if v is False:
1115
+ return "false"
1116
+ if isinstance(v, _Null):
1117
+ return ""
1118
+ return str(v)
1119
+
1120
+
1121
+ def _create_xml_element(tag, prefix=None, ns=None):
1122
+ if prefix and ns:
1123
+ ET.register_namespace(prefix, ns)
1124
+ if ns:
1125
+ return ET.Element("{" + ns + "}" + tag)
1126
+ return ET.Element(tag)
1127
+
1128
+
1129
+ def _deserialize_xml(
1130
+ deserializer: typing.Any,
1131
+ value: str,
1132
+ ) -> typing.Any:
1133
+ element = ET.fromstring(value)
1134
+ return _deserialize(deserializer, element)
1135
+
1136
+
1137
+ def _convert_element(e: ET.Element):
1138
+ # dict case
1139
+ if len(e.attrib) > 0 or len({child.tag for child in e}) > 1:
1140
+ dict_result: typing.Dict[str, typing.Any] = {}
1141
+ for child in e:
1142
+ if dict_result.get(child.tag) is not None:
1143
+ if isinstance(dict_result[child.tag], list):
1144
+ dict_result[child.tag].append(_convert_element(child))
1145
+ else:
1146
+ dict_result[child.tag] = [dict_result[child.tag], _convert_element(child)]
1147
+ else:
1148
+ dict_result[child.tag] = _convert_element(child)
1149
+ dict_result.update(e.attrib)
1150
+ return dict_result
1151
+ # array case
1152
+ if len(e) > 0:
1153
+ array_result: typing.List[typing.Any] = []
1154
+ for child in e:
1155
+ array_result.append(_convert_element(child))
1156
+ return array_result
1157
+ # primitive case
1158
+ return e.text
@@ -45,6 +45,11 @@
45
45
  __flattened_items = ["{{ model.flattened_items|join('\", \"') }}"]
46
46
  {% endif %}
47
47
 
48
+ {% if model.xml_metadata %}
49
+ _xml = {{model.xml_metadata}}
50
+ {% endif %}
51
+
52
+
48
53
  {% set need_init = (not model.internal) and (serializer.init_line(model) or model.discriminator)%}
49
54
  {% if need_init %}
50
55
  @overload