@azure-tools/typespec-python 0.30.0 → 0.31.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. package/dist/scripts/eng/lint.js.map +1 -1
  2. package/dist/scripts/eng/regenerate.js +0 -6
  3. package/dist/scripts/eng/regenerate.js.map +1 -1
  4. package/dist/scripts/eng/run-tests.js +16 -7
  5. package/dist/scripts/eng/run-tests.js.map +1 -1
  6. package/dist/scripts/eng/utils.js.map +1 -1
  7. package/dist/scripts/system-requirements.d.ts.map +1 -1
  8. package/dist/scripts/system-requirements.js.map +1 -1
  9. package/dist/src/code-model.js.map +1 -1
  10. package/dist/src/emitter.d.ts.map +1 -1
  11. package/dist/src/emitter.js +3 -1
  12. package/dist/src/emitter.js.map +1 -1
  13. package/dist/src/external-process.d.ts +0 -1
  14. package/dist/src/external-process.d.ts.map +1 -1
  15. package/dist/src/http.js.map +1 -1
  16. package/dist/src/types.d.ts.map +1 -1
  17. package/dist/src/types.js +77 -4
  18. package/dist/src/types.js.map +1 -1
  19. package/dist/src/utils.d.ts.map +1 -1
  20. package/dist/src/utils.js +60 -5
  21. package/dist/src/utils.js.map +1 -1
  22. package/generator/pygen/codegen/models/operation.py +13 -2
  23. package/generator/pygen/codegen/models/parameter.py +5 -1
  24. package/generator/pygen/codegen/models/primitive_types.py +2 -1
  25. package/generator/pygen/codegen/models/property.py +5 -1
  26. package/generator/pygen/codegen/serializers/builder_serializer.py +12 -9
  27. package/generator/pygen/codegen/serializers/model_serializer.py +3 -0
  28. package/generator/pygen/codegen/serializers/test_serializer.py +6 -0
  29. package/generator/pygen/codegen/templates/model_base.py.jinja2 +315 -69
  30. package/generator/pygen/codegen/templates/model_dpg.py.jinja2 +5 -0
  31. package/generator/pygen/codegen/templates/serialization.py.jinja2 +271 -162
  32. package/generator/pygen/codegen/templates/test.py.jinja2 +2 -2
  33. package/generator/pygen/utils.py +9 -0
  34. package/package.json +6 -5
  35. package/scripts/__pycache__/venvtools.cpython-310.pyc +0 -0
  36. package/scripts/eng/regenerate.ts +0 -6
  37. package/scripts/eng/run-tests.ts +17 -7
@@ -4,7 +4,7 @@
4
4
  # Licensed under the MIT License. See License.txt in the project root for
5
5
  # license information.
6
6
  # --------------------------------------------------------------------------
7
- # pylint: disable=protected-access, arguments-differ, signature-differs, broad-except
7
+ # pylint: disable=protected-access, arguments-differ, signature-differs, broad-except, too-many-lines
8
8
 
9
9
  import copy
10
10
  import calendar
@@ -19,6 +19,7 @@ import enum
19
19
  import email.utils
20
20
  from datetime import datetime, date, time, timedelta, timezone
21
21
  from json import JSONEncoder
22
+ import xml.etree.ElementTree as ET
22
23
  from typing_extensions import Self
23
24
  import isodate
24
25
  from {{ code_model.core_library }}.exceptions import DeserializationError
@@ -123,7 +124,7 @@ def _serialize_datetime(o, format: typing.Optional[str] = None):
123
124
 
124
125
  def _is_readonly(p):
125
126
  try:
126
- return p._visibility == ["read"] # pylint: disable=protected-access
127
+ return p._visibility == ["read"]
127
128
  except AttributeError:
128
129
  return False
129
130
 
@@ -286,6 +287,12 @@ def _deserialize_decimal(attr):
286
287
  return decimal.Decimal(str(attr))
287
288
 
288
289
 
290
+ def _deserialize_int_as_str(attr):
291
+ if isinstance(attr, int):
292
+ return attr
293
+ return int(attr)
294
+
295
+
289
296
  _DESERIALIZE_MAPPING = {
290
297
  datetime: _deserialize_datetime,
291
298
  date: _deserialize_date,
@@ -307,9 +314,11 @@ _DESERIALIZE_MAPPING_WITHFORMAT = {
307
314
 
308
315
 
def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = None):
    """Pick a primitive deserializer for an annotation and optional rest field.

    Resolution order:
      * ``int`` annotated with format ``"str"`` -> ``_deserialize_int_as_str``;
      * any other non-empty format -> lookup by format in ``_DESERIALIZE_MAPPING_WITHFORMAT``
        (``None`` when the format is not in that table);
      * otherwise -> lookup by annotation in ``_DESERIALIZE_MAPPING`` (``None`` when absent).

    :param annotation: The type annotation of the value.
    :param rf: The rest field describing the value, if any.
    :returns: A deserializer callable, or None.
    """
    wire_format = rf._format if rf else None
    if annotation is int and wire_format == "str":
        return _deserialize_int_as_str
    if wire_format:
        return _DESERIALIZE_MAPPING_WITHFORMAT.get(wire_format)
    return _DESERIALIZE_MAPPING.get(annotation)  # pyright: ignore
313
322
 
314
323
 
315
324
  def _get_type_alias_type(module_name: str, alias_name: str):
@@ -324,17 +333,9 @@ def _get_type_alias_type(module_name: str, alias_name: str):
324
333
 
325
334
 
326
335
  def _get_model(module_name: str, model_name: str):
327
- models = {
328
- k: v
329
- for k, v in sys.modules[module_name].__dict__.items()
330
- if isinstance(v, type)
331
- }
336
+ models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)}
332
337
  module_end = module_name.rsplit(".", 1)[0]
333
- models.update({
334
- k: v
335
- for k, v in sys.modules[module_end].__dict__.items()
336
- if isinstance(v, type)
337
- })
338
+ models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)})
338
339
  if isinstance(model_name, str):
339
340
  model_name = model_name.split(".")[-1]
340
341
  if model_name not in models:
@@ -386,16 +387,13 @@ class _MyMutableMapping(MutableMapping[str, typing.Any]): # pylint: disable=uns
386
387
  return default
387
388
 
388
389
  @typing.overload
389
- def pop(self, key: str) -> typing.Any:
390
- ...
390
+ def pop(self, key: str) -> typing.Any: ...
391
391
 
392
392
  @typing.overload
393
- def pop(self, key: str, default: _T) -> _T:
394
- ...
393
+ def pop(self, key: str, default: _T) -> _T: ...
395
394
 
396
395
  @typing.overload
397
- def pop(self, key: str, default: typing.Any) -> typing.Any:
398
- ...
396
+ def pop(self, key: str, default: typing.Any) -> typing.Any: ...
399
397
 
400
398
  def pop(self, key: str, default: typing.Any = _UNSET) -> typing.Any:
401
399
  if default is _UNSET:
@@ -412,12 +410,10 @@ class _MyMutableMapping(MutableMapping[str, typing.Any]): # pylint: disable=uns
412
410
  self._data.update(*args, **kwargs)
413
411
 
414
412
  @typing.overload
415
- def setdefault(self, key: str, default: None = None) -> None:
416
- ...
413
+ def setdefault(self, key: str, default: None = None) -> None: ...
417
414
 
418
415
  @typing.overload
419
- def setdefault(self, key: str, default: typing.Any) -> typing.Any:
420
- ...
416
+ def setdefault(self, key: str, default: typing.Any) -> typing.Any: ...
421
417
 
422
418
  def setdefault(self, key: str, default: typing.Any = _UNSET) -> typing.Any:
423
419
  if default is _UNSET:
@@ -488,6 +484,8 @@ def _create_value(rf: typing.Optional["_RestField"], value: typing.Any) -> typin
488
484
  return value
489
485
  if rf._is_model:
490
486
  return _deserialize(rf._type, value)
487
+ if isinstance(value, ET.Element):
488
+ value = _deserialize(rf._type, value)
491
489
  return _serialize(value, rf._format)
492
490
 
493
491
 
@@ -506,10 +504,58 @@ class Model(_MyMutableMapping):
506
504
  for rest_field in self._attr_to_rest_field.values()
507
505
  if rest_field._default is not _UNSET
508
506
  }
509
- if args:
510
- dict_to_pass.update(
511
- {k: _create_value(_get_rest_field(self._attr_to_rest_field, k), v) for k, v in args[0].items()}
512
- )
507
+ if args: # pylint: disable=too-many-nested-blocks
508
+ if isinstance(args[0], ET.Element):
509
+ existed_attr_keys = []
510
+ model_meta = getattr(self, "_xml", {})
511
+
512
+ for rf in self._attr_to_rest_field.values():
513
+ prop_meta = getattr(rf, "_xml", {})
514
+ xml_name = prop_meta.get("name", rf._rest_name)
515
+ xml_ns = prop_meta.get("ns", model_meta.get("ns", None))
516
+ if xml_ns:
517
+ xml_name = "{" + xml_ns + "}" + xml_name
518
+
519
+ # attribute
520
+ if prop_meta.get("attribute", False) and args[0].get(xml_name) is not None:
521
+ existed_attr_keys.append(xml_name)
522
+ dict_to_pass[rf._rest_name] = _deserialize(rf._type, args[0].get(xml_name))
523
+ continue
524
+
525
+ # unwrapped element is array
526
+ if prop_meta.get("unwrapped", False):
527
+ # unwrapped array could either use prop items meta/prop meta
528
+ if prop_meta.get("itemsName"):
529
+ xml_name = prop_meta.get("itemsName")
530
+ xml_ns = prop_meta.get("itemNs")
531
+ if xml_ns:
532
+ xml_name = "{" + xml_ns + "}" + xml_name
533
+ items = args[0].findall(xml_name) # pyright: ignore
534
+ if len(items) > 0:
535
+ existed_attr_keys.append(xml_name)
536
+ dict_to_pass[rf._rest_name] = _deserialize(rf._type, items)
537
+ continue
538
+
539
+ # text element is primitive type
540
+ if prop_meta.get("text", False):
541
+ if args[0].text is not None:
542
+ dict_to_pass[rf._rest_name] = _deserialize(rf._type, args[0].text)
543
+ continue
544
+
545
+ # wrapped element could be normal property or array, it should only have one element
546
+ item = args[0].find(xml_name)
547
+ if item is not None:
548
+ existed_attr_keys.append(xml_name)
549
+ dict_to_pass[rf._rest_name] = _deserialize(rf._type, item)
550
+
551
+ # rest thing is additional properties
552
+ for e in args[0]:
553
+ if e.tag not in existed_attr_keys:
554
+ dict_to_pass[e.tag] = _convert_element(e)
555
+ else:
556
+ dict_to_pass.update(
557
+ {k: _create_value(_get_rest_field(self._attr_to_rest_field, k), v) for k, v in args[0].items()}
558
+ )
513
559
  else:
514
560
  non_attr_kwargs = [k for k in kwargs if k not in self._attr_to_rest_field]
515
561
  if non_attr_kwargs:
@@ -529,7 +575,7 @@ class Model(_MyMutableMapping):
529
575
 
530
576
  def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self: # pylint: disable=unused-argument
531
577
  if f"{cls.__module__}.{cls.__qualname__}" not in cls._calculated:
532
- # we know the last nine classes in mro are going to be 'Model', '_MyMutableMapping', 'MutableMapping',
578
+ # we know the last nine classes in mro are going to be 'Model', '_MyMutableMapping', 'MutableMapping',
533
579
  # 'Mapping', 'Collection', 'Sized', 'Iterable', 'Container' and 'object'
534
580
  mros = cls.__mro__[:-9][::-1] # ignore parents, and reverse the mro order
535
581
  attr_to_rest_field: typing.Dict[str, _RestField] = { # map attribute name to rest_field property
@@ -558,10 +604,10 @@ class Model(_MyMutableMapping):
558
604
  base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member
559
605
 
@classmethod
def _get_discriminator(cls, exist_discriminators) -> typing.Optional["_RestField"]:
    """Find this class's own discriminator field that has not been used yet.

    Only attributes defined directly on ``cls`` (its ``__dict__``) are inspected.

    :param exist_discriminators: Wire names of discriminators already consumed.
    :returns: The first matching ``_RestField``, or None if there is none.
    """
    unused_discriminators = (
        field
        for field in cls.__dict__.values()
        if isinstance(field, _RestField)
        and field._is_discriminator
        and field._rest_name not in exist_discriminators
    )
    return next(unused_discriminators, None)
566
612
 
567
613
  @classmethod
@@ -569,13 +615,25 @@ class Model(_MyMutableMapping):
569
615
  if not hasattr(cls, "__mapping__"): # pylint: disable=no-member
570
616
  return cls(data)
571
617
  discriminator = cls._get_discriminator(exist_discriminators)
572
- exist_discriminators.append(discriminator)
573
- mapped_cls = cls.__mapping__.get(
574
- data.get(discriminator), cls
575
- ) # pyright: ignore # pylint: disable=no-member
576
- if mapped_cls == cls:
618
+ if discriminator is None:
577
619
  return cls(data)
578
- return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access
620
+ exist_discriminators.append(discriminator._rest_name)
621
+ if isinstance(data, ET.Element):
622
+ model_meta = getattr(cls, "_xml", {})
623
+ prop_meta = getattr(discriminator, "_xml", {})
624
+ xml_name = prop_meta.get("name", discriminator._rest_name)
625
+ xml_ns = prop_meta.get("ns", model_meta.get("ns", None))
626
+ if xml_ns:
627
+ xml_name = "{" + xml_ns + "}" + xml_name
628
+
629
+ if data.get(xml_name) is not None:
630
+ discriminator_value = data.get(xml_name)
631
+ else:
632
+ discriminator_value = data.find(xml_name).text # pyright: ignore
633
+ else:
634
+ discriminator_value = data.get(discriminator._rest_name)
635
+ mapped_cls = cls.__mapping__.get(discriminator_value, cls) # pyright: ignore # pylint: disable=no-member
636
+ return mapped_cls._deserialize(data, exist_discriminators)
579
637
 
580
638
  def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]:
581
639
  """Return a dict that can be JSONify using json.dump.
@@ -594,7 +652,9 @@ class Model(_MyMutableMapping):
594
652
  continue
595
653
  is_multipart_file_input = False
596
654
  try:
597
- is_multipart_file_input = next(rf for rf in self._attr_to_rest_field.values() if rf._rest_name == k)._is_multipart_file_input
655
+ is_multipart_file_input = next(
656
+ rf for rf in self._attr_to_rest_field.values() if rf._rest_name == k
657
+ )._is_multipart_file_input
598
658
  except StopIteration:
599
659
  pass
600
660
  result[k] = v if is_multipart_file_input else Model._as_dict_value(v, exclude_readonly=exclude_readonly)
@@ -605,27 +665,24 @@ class Model(_MyMutableMapping):
605
665
  if v is None or isinstance(v, _Null):
606
666
  return None
607
667
  if isinstance(v, (list, tuple, set)):
608
- return type(v)(
609
- Model._as_dict_value(x, exclude_readonly=exclude_readonly)
610
- for x in v
611
- )
668
+ return type(v)(Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v)
612
669
  if isinstance(v, dict):
613
- return {
614
- dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly)
615
- for dk, dv in v.items()
616
- }
670
+ return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()}
617
671
  return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v
618
672
 
673
+
def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj):
    """Return ``obj`` as-is if it is already a model, otherwise deserialize it.

    :param model_deserializer: Deserializer forwarded to ``_deserialize``.
    :param obj: A raw value or an existing model instance.
    :returns: ``obj`` itself for models, else the result of ``_deserialize``.
    """
    if not _is_model(obj):
        return _deserialize(model_deserializer, obj)
    return obj
623
678
 
679
+
624
680
  def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Callable], obj):
625
681
  if obj is None:
626
682
  return obj
627
683
  return _deserialize_with_callable(if_obj_deserializer, obj)
628
684
 
685
+
629
686
  def _deserialize_with_union(deserializers, obj):
630
687
  for deserializer in deserializers:
631
688
  try:
@@ -634,6 +691,7 @@ def _deserialize_with_union(deserializers, obj):
634
691
  pass
635
692
  raise DeserializationError()
636
693
 
694
+
637
695
  def _deserialize_dict(
638
696
  value_deserializer: typing.Optional[typing.Callable],
639
697
  module: typing.Optional[str],
@@ -641,10 +699,10 @@ def _deserialize_dict(
641
699
  ):
642
700
  if obj is None:
643
701
  return obj
644
- return {
645
- k: _deserialize(value_deserializer, v, module)
646
- for k, v in obj.items()
647
- }
702
+ if isinstance(obj, ET.Element):
703
+ obj = {child.tag: child for child in obj}
704
+ return {k: _deserialize(value_deserializer, v, module) for k, v in obj.items()}
705
+
648
706
 
649
707
  def _deserialize_multiple_sequence(
650
708
  entry_deserializers: typing.List[typing.Optional[typing.Callable]],
@@ -653,10 +711,8 @@ def _deserialize_multiple_sequence(
653
711
  ):
654
712
  if obj is None:
655
713
  return obj
656
- return type(obj)(
657
- _deserialize(deserializer, entry, module)
658
- for entry, deserializer in zip(obj, entry_deserializers)
659
- )
714
+ return type(obj)(_deserialize(deserializer, entry, module) for entry, deserializer in zip(obj, entry_deserializers))
715
+
660
716
 
661
717
  def _deserialize_sequence(
662
718
  deserializer: typing.Optional[typing.Callable],
@@ -665,22 +721,24 @@ def _deserialize_sequence(
665
721
  ):
666
722
  if obj is None:
667
723
  return obj
724
+ if isinstance(obj, ET.Element):
725
+ obj = list(obj)
668
726
  return type(obj)(_deserialize(deserializer, entry, module) for entry in obj)
669
727
 
728
+
670
729
  def _sorted_annotations(types: typing.List[typing.Any]) -> typing.List[typing.Any]:
671
730
  return sorted(
672
731
  types,
673
732
  key=lambda x: hasattr(x, "__name__") and x.__name__.lower() in ("str", "float", "int", "bool"),
674
733
  )
675
734
 
735
+
676
736
  def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912
677
737
  annotation: typing.Any,
678
738
  module: typing.Optional[str],
679
739
  rf: typing.Optional["_RestField"] = None,
680
740
  ) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]:
681
- if not annotation or annotation in [int, float]:
682
- if annotation is int and rf and rf._format == "str":
683
- return int
741
+ if not annotation:
684
742
  return None
685
743
 
686
744
  # is it a type alias?
@@ -716,15 +774,15 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
716
774
  # is it optional?
717
775
  try:
718
776
  if any(a for a in annotation.__args__ if a == type(None)): # pyright: ignore
719
- if len(annotation.__args__) <= 2: # pyright: ignore
777
+ if len(annotation.__args__) <= 2: # pyright: ignore
720
778
  if_obj_deserializer = _get_deserialize_callable_from_annotation(
721
- next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore
779
+ next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore
722
780
  )
723
781
 
724
782
  return functools.partial(_deserialize_with_optional, if_obj_deserializer)
725
783
  # the type is Optional[Union[...]], we need to remove the None type from the Union
726
784
  annotation_copy = copy.copy(annotation)
727
- annotation_copy.__args__ = [a for a in annotation_copy.__args__ if a != type(None)] # pyright: ignore
785
+ annotation_copy.__args__ = [a for a in annotation_copy.__args__ if a != type(None)] # pyright: ignore
728
786
  return _get_deserialize_callable_from_annotation(annotation_copy, module, rf)
729
787
  except AttributeError:
730
788
  pass
@@ -745,7 +803,6 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
745
803
  annotation.__args__[1], module, rf # pyright: ignore
746
804
  )
747
805
 
748
-
749
806
  return functools.partial(
750
807
  _deserialize_dict,
751
808
  value_deserializer,
@@ -756,18 +813,15 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
756
813
  try:
757
814
  if annotation._name in ["List", "Set", "Tuple", "Sequence"]: # pyright: ignore
758
815
  if len(annotation.__args__) > 1: # pyright: ignore
759
-
760
-
761
816
  entry_deserializers = [
762
- _get_deserialize_callable_from_annotation(dt, module, rf) for dt in annotation.__args__ # pyright: ignore
817
+ _get_deserialize_callable_from_annotation(dt, module, rf)
818
+ for dt in annotation.__args__ # pyright: ignore
763
819
  ]
764
820
  return functools.partial(_deserialize_multiple_sequence, entry_deserializers, module)
765
821
  deserializer = _get_deserialize_callable_from_annotation(
766
822
  annotation.__args__[0], module, rf # pyright: ignore
767
823
  )
768
824
 
769
-
770
-
771
825
  return functools.partial(_deserialize_sequence, deserializer, module)
772
826
  except (TypeError, IndexError, AttributeError, SyntaxError):
773
827
  pass
@@ -793,12 +847,23 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
793
847
  def _deserialize_with_callable(
794
848
  deserializer: typing.Optional[typing.Callable[[typing.Any], typing.Any]],
795
849
  value: typing.Any,
796
- ):
850
+ ): # pylint: disable=too-many-return-statements
797
851
  try:
798
852
  if value is None or isinstance(value, _Null):
799
853
  return None
854
+ if isinstance(value, ET.Element):
855
+ if deserializer is str:
856
+ return value.text or ""
857
+ if deserializer is int:
858
+ return int(value.text) if value.text else None
859
+ if deserializer is float:
860
+ return float(value.text) if value.text else None
861
+ if deserializer is bool:
862
+ return value.text == "true" if value.text else None
800
863
  if deserializer is None:
801
864
  return value
865
+ if deserializer in [int, float, bool]:
866
+ return deserializer(value)
802
867
  if isinstance(deserializer, CaseInsensitiveEnumMeta):
803
868
  try:
804
869
  return deserializer(value)
@@ -839,6 +904,7 @@ class _RestField:
839
904
  default: typing.Any = _UNSET,
840
905
  format: typing.Optional[str] = None,
841
906
  is_multipart_file_input: bool = False,
907
+ xml: typing.Optional[typing.Dict[str, typing.Any]] = None,
842
908
  ):
843
909
  self._type = type
844
910
  self._rest_name_input = name
@@ -849,6 +915,7 @@ class _RestField:
849
915
  self._default = default
850
916
  self._format = format
851
917
  self._is_multipart_file_input = is_multipart_file_input
918
+ self._xml = xml if xml is not None else {}
852
919
 
853
920
  @property
854
921
  def _class_type(self) -> typing.Any:
@@ -899,8 +966,17 @@ def rest_field(
899
966
  default: typing.Any = _UNSET,
900
967
  format: typing.Optional[str] = None,
901
968
  is_multipart_file_input: bool = False,
969
+ xml: typing.Optional[typing.Dict[str, typing.Any]] = None,
902
970
  ) -> typing.Any:
903
- return _RestField(name=name, type=type, visibility=visibility, default=default, format=format, is_multipart_file_input=is_multipart_file_input)
971
+ return _RestField(
972
+ name=name,
973
+ type=type,
974
+ visibility=visibility,
975
+ default=default,
976
+ format=format,
977
+ is_multipart_file_input=is_multipart_file_input,
978
+ xml=xml,
979
+ )
904
980
 
905
981
 
906
982
  def rest_discriminator(
@@ -908,5 +984,175 @@ def rest_discriminator(
908
984
  name: typing.Optional[str] = None,
909
985
  type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin
910
986
  visibility: typing.Optional[typing.List[str]] = None,
987
+ xml: typing.Optional[typing.Dict[str, typing.Any]] = None,
911
988
  ) -> typing.Any:
912
- return _RestField(name=name, type=type, is_discriminator=True, visibility=visibility)
989
+ return _RestField(name=name, type=type, is_discriminator=True, visibility=visibility, xml=xml)
990
+
991
+
def serialize_xml(model: Model, exclude_readonly: bool = False) -> str:
    """Render a model as an XML string.

    :param Model model: The model to render.
    :param bool exclude_readonly: If True, read-only properties are left out.
    :returns: The XML text of the model's root element.
    :rtype: str
    """
    root_element = _get_element(model, exclude_readonly)
    return ET.tostring(root_element, encoding="unicode")  # type: ignore
1001
+
1002
+
def _get_element(
    o: typing.Any,
    exclude_readonly: bool = False,
    parent_meta: typing.Optional[typing.Dict[str, typing.Any]] = None,
    wrapped_element: typing.Optional[ET.Element] = None,
) -> typing.Union[ET.Element, typing.List[ET.Element]]:
    """Convert a value into an XML element, or a list of elements.

    :param o: The value to convert: a model, a list, a dict, or a primitive.
    :param bool exclude_readonly: Whether to skip read-only model properties.
    :param parent_meta: XML metadata (``name``/``ns``/``prefix`` and the ``items*`` variants)
        of the property holding ``o``; used to name elements created for list items,
        dict entries and primitives.
    :param wrapped_element: For models, an existing element to populate instead of
        creating a new wrapper element.
    :returns: One element for models and primitives; a list of elements for lists and dicts.
    :raises ValueError: If ``o`` is a primitive and no ``parent_meta`` is available to name it.
    """
    if _is_model(o):
        model_meta = getattr(o, "_xml", {})

        # if prop is a model, then use the prop element directly, else generate a wrapper of model
        if wrapped_element is None:
            wrapped_element = _create_xml_element(
                model_meta.get("name", o.__class__.__name__),
                model_meta.get("prefix"),
                model_meta.get("ns"),
            )

        readonly_props = []
        if exclude_readonly:
            readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)]

        for k, v in o.items():
            # do not serialize readonly properties
            if exclude_readonly and k in readonly_props:
                continue

            prop_rest_field = _get_rest_field(o._attr_to_rest_field, k)
            if prop_rest_field:
                # copy so per-property defaults below never mutate the field's shared metadata
                prop_meta = prop_rest_field._xml.copy()
                # use the wire name as xml name if no specific name is set
                if prop_meta.get("name") is None:
                    prop_meta["name"] = k
            else:
                # additional properties will not have rest field, use the wire name as xml name
                prop_meta = {"name": k}

            # if no ns for prop, use model's
            if prop_meta.get("ns") is None and model_meta.get("ns"):
                prop_meta["ns"] = model_meta.get("ns")
                prop_meta["prefix"] = model_meta.get("prefix")

            if prop_meta.get("unwrapped", False):
                # unwrapped could only set on array
                wrapped_element.extend(_get_element(v, exclude_readonly, prop_meta))
            elif prop_meta.get("text", False):
                # text could only set on primitive type
                wrapped_element.text = _get_primitive_type_value(v)
            elif prop_meta.get("attribute", False):
                xml_name = prop_meta.get("name", k)
                if prop_meta.get("ns"):
                    prefix = prop_meta.get("prefix")
                    # ET.register_namespace raises TypeError for a None prefix; only register
                    # when a prefix is known. ElementTree picks a prefix itself otherwise.
                    if prefix is not None:
                        ET.register_namespace(prefix, prop_meta.get("ns"))  # pyright: ignore
                    xml_name = "{" + prop_meta.get("ns") + "}" + xml_name  # pyright: ignore
                # attribute should be primitive type
                wrapped_element.set(xml_name, _get_primitive_type_value(v))
            else:
                # other wrapped prop element
                wrapped_element.append(_get_wrapped_element(v, exclude_readonly, prop_meta))
        return wrapped_element
    if isinstance(o, list):
        return [_get_element(x, exclude_readonly, parent_meta) for x in o]  # type: ignore
    if isinstance(o, dict):
        result = []
        for k, v in o.items():
            result.append(
                _get_wrapped_element(
                    v,
                    exclude_readonly,
                    {
                        "name": k,
                        "ns": parent_meta.get("ns") if parent_meta else None,
                        "prefix": parent_meta.get("prefix") if parent_meta else None,
                    },
                )
            )
        return result

    # primitive case need to create element based on parent_meta
    if parent_meta:
        return _get_wrapped_element(
            o,
            exclude_readonly,
            {
                "name": parent_meta.get("itemsName", parent_meta.get("name")),
                "prefix": parent_meta.get("itemsPrefix", parent_meta.get("prefix")),
                "ns": parent_meta.get("itemsNs", parent_meta.get("ns")),
            },
        )

    # format with f-string: plain "+" concatenation raised TypeError for non-str values
    raise ValueError(f"Could not serialize value into xml: {o}")
1092
+
1093
+
def _get_wrapped_element(
    v: typing.Any,
    exclude_readonly: bool,
    meta: typing.Optional[typing.Dict[str, typing.Any]],
) -> ET.Element:
    """Create an element named by ``meta`` and fill it with ``v``.

    Dicts and lists become child elements, models populate the new element in place, and
    anything else becomes the element's text.

    :param v: The value to wrap.
    :param bool exclude_readonly: Whether to skip read-only model properties.
    :param meta: XML metadata providing ``name``/``prefix``/``ns`` for the new element.
    :returns: The newly created element.
    """
    lookup = meta if meta else {}
    element = _create_xml_element(lookup.get("name"), lookup.get("prefix"), lookup.get("ns"))
    if isinstance(v, (dict, list)):
        element.extend(_get_element(v, exclude_readonly, meta))
    elif _is_model(v):
        _get_element(v, exclude_readonly, meta, element)
    else:
        element.text = _get_primitive_type_value(v)
    return element
1109
+
1110
+
1111
+ def _get_primitive_type_value(v) -> str:
1112
+ if v is True:
1113
+ return "true"
1114
+ if v is False:
1115
+ return "false"
1116
+ if isinstance(v, _Null):
1117
+ return ""
1118
+ return str(v)
1119
+
1120
+
1121
+ def _create_xml_element(tag, prefix=None, ns=None):
1122
+ if prefix and ns:
1123
+ ET.register_namespace(prefix, ns)
1124
+ if ns:
1125
+ return ET.Element("{" + ns + "}" + tag)
1126
+ return ET.Element(tag)
1127
+
1128
+
def _deserialize_xml(
    deserializer: typing.Any,
    value: str,
) -> typing.Any:
    """Parse an XML document and deserialize its root element.

    NOTE(review): ``ET.fromstring`` is the stdlib ElementTree parser; if ``value`` can come
    from an untrusted source, review the XML vulnerabilities listed in the Python docs.

    :param deserializer: Deserializer forwarded to ``_deserialize``.
    :param str value: The XML text.
    :returns: The deserialized value.
    """
    root = ET.fromstring(value)
    return _deserialize(deserializer, root)
1135
+
1136
+
1137
+ def _convert_element(e: ET.Element):
1138
+ # dict case
1139
+ if len(e.attrib) > 0 or len({child.tag for child in e}) > 1:
1140
+ dict_result: typing.Dict[str, typing.Any] = {}
1141
+ for child in e:
1142
+ if dict_result.get(child.tag) is not None:
1143
+ if isinstance(dict_result[child.tag], list):
1144
+ dict_result[child.tag].append(_convert_element(child))
1145
+ else:
1146
+ dict_result[child.tag] = [dict_result[child.tag], _convert_element(child)]
1147
+ else:
1148
+ dict_result[child.tag] = _convert_element(child)
1149
+ dict_result.update(e.attrib)
1150
+ return dict_result
1151
+ # array case
1152
+ if len(e) > 0:
1153
+ array_result: typing.List[typing.Any] = []
1154
+ for child in e:
1155
+ array_result.append(_convert_element(child))
1156
+ return array_result
1157
+ # primitive case
1158
+ return e.text
@@ -45,6 +45,11 @@
45
45
  __flattened_items = ["{{ model.flattened_items|join('\", \"') }}"]
46
46
  {% endif %}
47
47
 
48
+ {% if model.xml_metadata %}
49
+ _xml = {{model.xml_metadata}}
50
+ {% endif %}
51
+
52
+
48
53
  {% set need_init = (not model.internal) and (serializer.init_line(model) or model.discriminator)%}
49
54
  {% if need_init %}
50
55
  @overload