pyglove 0.4.5.dev202412130809__py3-none-any.whl → 0.4.5.dev202412150808__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyglove/core/object_utils/common_traits.py +2 -2
- pyglove/core/object_utils/json_conversion.py +1 -1
- pyglove/core/symbolic/base.py +33 -5
- pyglove/core/symbolic/dict.py +50 -38
- pyglove/core/symbolic/dict_test.py +71 -31
- pyglove/core/symbolic/object.py +11 -0
- pyglove/core/symbolic/object_test.py +21 -1
- pyglove/core/symbolic/ref.py +21 -8
- pyglove/core/symbolic/ref_test.py +17 -0
- pyglove/core/typing/annotation_conversion.py +5 -0
- pyglove/core/typing/annotation_conversion_test.py +8 -0
- pyglove/core/typing/class_schema.py +4 -0
- pyglove/core/typing/value_specs.py +79 -5
- pyglove/core/typing/value_specs_test.py +136 -12
- pyglove/core/views/html/tree_view.py +2 -2
- {pyglove-0.4.5.dev202412130809.dist-info → pyglove-0.4.5.dev202412150808.dist-info}/METADATA +1 -1
- {pyglove-0.4.5.dev202412130809.dist-info → pyglove-0.4.5.dev202412150808.dist-info}/RECORD +20 -20
- {pyglove-0.4.5.dev202412130809.dist-info → pyglove-0.4.5.dev202412150808.dist-info}/LICENSE +0 -0
- {pyglove-0.4.5.dev202412130809.dist-info → pyglove-0.4.5.dev202412150808.dist-info}/WHEEL +0 -0
- {pyglove-0.4.5.dev202412130809.dist-info → pyglove-0.4.5.dev202412150808.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,7 @@ object, for example, partiality (MaybePartial), functor (Functor).
|
|
18
18
|
"""
|
19
19
|
|
20
20
|
import abc
|
21
|
-
from typing import Any, Dict, Optional
|
21
|
+
from typing import Any, Dict, Optional, Union
|
22
22
|
|
23
23
|
|
24
24
|
class MaybePartial(metaclass=abc.ABCMeta):
|
@@ -47,7 +47,7 @@ class MaybePartial(metaclass=abc.ABCMeta):
|
|
47
47
|
return len(self.missing_values()) > 0 # pylint: disable=g-explicit-length-test
|
48
48
|
|
49
49
|
@abc.abstractmethod
|
50
|
-
def missing_values(self, flatten: bool = True) -> Dict[str, Any]: # pylint: disable=redefined-outer-name
|
50
|
+
def missing_values(self, flatten: bool = True) -> Dict[Union[str, int], Any]: # pylint: disable=redefined-outer-name
|
51
51
|
"""Returns missing values from this object.
|
52
52
|
|
53
53
|
Args:
|
@@ -37,7 +37,7 @@ JSONPrimitiveType = Union[int, float, bool, str]
|
|
37
37
|
# pytype doesn't support recursion. Use Any instead of 'JSONValueType'
|
38
38
|
# in List and Dict.
|
39
39
|
JSONListType = List[Any]
|
40
|
-
JSONDictType = Dict[str, Any]
|
40
|
+
JSONDictType = Dict[Union[str, int], Any]
|
41
41
|
JSONValueType = Union[JSONPrimitiveType, JSONListType, JSONDictType]
|
42
42
|
|
43
43
|
# pylint: enable=invalid-name
|
pyglove/core/symbolic/base.py
CHANGED
@@ -310,7 +310,7 @@ class Symbolic(
|
|
310
310
|
"""Seals or unseals current object from further modification."""
|
311
311
|
return self._set_raw_attr('_sealed', is_seal)
|
312
312
|
|
313
|
-
def sym_missing(self, flatten: bool = True) -> Dict[str, Any]:
|
313
|
+
def sym_missing(self, flatten: bool = True) -> Dict[Union[str, int], Any]:
|
314
314
|
"""Returns missing values."""
|
315
315
|
missing = getattr(self, '_sym_missing_values')
|
316
316
|
if missing is None:
|
@@ -729,7 +729,7 @@ class Symbolic(
|
|
729
729
|
"""Returns if current object is deterministic."""
|
730
730
|
return is_deterministic(self)
|
731
731
|
|
732
|
-
def missing_values(self, flatten: bool = True) -> Dict[str, Any]:
|
732
|
+
def missing_values(self, flatten: bool = True) -> Dict[Union[str, int], Any]:
|
733
733
|
"""Alias for `sym_missing`."""
|
734
734
|
return self.sym_missing(flatten)
|
735
735
|
|
@@ -1094,7 +1094,7 @@ class Symbolic(
|
|
1094
1094
|
"""
|
1095
1095
|
|
1096
1096
|
@abc.abstractmethod
|
1097
|
-
def _sym_missing(self) -> Dict[str, Any]:
|
1097
|
+
def _sym_missing(self) -> Dict[Union[str, int], Any]:
|
1098
1098
|
"""Returns missing values."""
|
1099
1099
|
|
1100
1100
|
@abc.abstractmethod
|
@@ -2142,8 +2142,23 @@ def from_json_str(json_str: str,
|
|
2142
2142
|
Returns:
|
2143
2143
|
A deserialized value.
|
2144
2144
|
"""
|
2145
|
+
def _get_key(k: str) -> Union[str, int]:
|
2146
|
+
if k.startswith('n_:'):
|
2147
|
+
return int(k[3:])
|
2148
|
+
return k
|
2149
|
+
|
2150
|
+
def _decode_int_keys(v):
|
2151
|
+
if isinstance(v, dict):
|
2152
|
+
return {
|
2153
|
+
_get_key(k): _decode_int_keys(v)
|
2154
|
+
for k, v in v.items()
|
2155
|
+
}
|
2156
|
+
elif isinstance(v, list):
|
2157
|
+
return [_decode_int_keys(v) for v in v]
|
2158
|
+
return v
|
2159
|
+
|
2145
2160
|
return from_json(
|
2146
|
-
json.loads(json_str),
|
2161
|
+
_decode_int_keys(json.loads(json_str)),
|
2147
2162
|
allow_partial=allow_partial,
|
2148
2163
|
root_path=root_path,
|
2149
2164
|
auto_import=auto_import,
|
@@ -2217,7 +2232,20 @@ def to_json_str(value: Any,
|
|
2217
2232
|
Returns:
|
2218
2233
|
A JSON string.
|
2219
2234
|
"""
|
2220
|
-
|
2235
|
+
def _encode_int_keys(v):
|
2236
|
+
if isinstance(v, dict):
|
2237
|
+
return {
|
2238
|
+
f'n_:{k}' if isinstance(k, int) else k: _encode_int_keys(v)
|
2239
|
+
for k, v in v.items()
|
2240
|
+
}
|
2241
|
+
elif isinstance(v, list):
|
2242
|
+
return [
|
2243
|
+
_encode_int_keys(v) for v in v
|
2244
|
+
]
|
2245
|
+
return v
|
2246
|
+
return json.dumps(
|
2247
|
+
_encode_int_keys(to_json(value, **kwargs)), indent=json_indent
|
2248
|
+
)
|
2221
2249
|
|
2222
2250
|
|
2223
2251
|
def load(path: str, *args, **kwargs) -> Any:
|
pyglove/core/symbolic/dict.py
CHANGED
@@ -97,7 +97,7 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
97
97
|
|
98
98
|
@classmethod
|
99
99
|
def partial(cls,
|
100
|
-
dict_obj: Optional[typing.Dict[str, Any]] = None,
|
100
|
+
dict_obj: Optional[typing.Dict[Union[str, int], Any]] = None,
|
101
101
|
value_spec: Optional[pg_typing.Dict] = None,
|
102
102
|
*,
|
103
103
|
onchange_callback: Optional[Callable[
|
@@ -169,8 +169,8 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
169
169
|
def __init__(self,
|
170
170
|
dict_obj: Union[
|
171
171
|
None,
|
172
|
-
Iterable[Tuple[str, Any]],
|
173
|
-
typing.Dict[str, Any]] = None,
|
172
|
+
Iterable[Tuple[Union[str, int], Any]],
|
173
|
+
typing.Dict[Union[str, int], Any]] = None,
|
174
174
|
*,
|
175
175
|
value_spec: Optional[pg_typing.Dict] = None,
|
176
176
|
onchange_callback: Optional[Callable[
|
@@ -345,7 +345,7 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
345
345
|
updates.append(update)
|
346
346
|
return updates
|
347
347
|
|
348
|
-
def _sym_missing(self) -> typing.Dict[str, Any]:
|
348
|
+
def _sym_missing(self) -> typing.Dict[Union[str, int], Any]:
|
349
349
|
"""Returns missing values.
|
350
350
|
|
351
351
|
Returns:
|
@@ -375,7 +375,7 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
375
375
|
missing[k] = missing_child
|
376
376
|
return missing
|
377
377
|
|
378
|
-
def _sym_nondefault(self) -> typing.Dict[str, Any]:
|
378
|
+
def _sym_nondefault(self) -> typing.Dict[Union[str, int], Any]:
|
379
379
|
"""Returns non-default values as key/value pairs in a dict."""
|
380
380
|
non_defaults = dict()
|
381
381
|
if self._value_spec is not None and self._value_spec.schema:
|
@@ -444,7 +444,7 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
444
444
|
"""Tests if a symbolic attribute exists."""
|
445
445
|
return key in self
|
446
446
|
|
447
|
-
def sym_keys(self) -> Iterator[str]:
|
447
|
+
def sym_keys(self) -> Iterator[Union[str, int]]:
|
448
448
|
"""Iterates the keys of symbolic attributes."""
|
449
449
|
if self._value_spec is None or self._value_spec.schema is None:
|
450
450
|
for key in super().__iter__():
|
@@ -467,7 +467,7 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
467
467
|
yield self._sym_getattr(k)
|
468
468
|
|
469
469
|
def sym_items(self) -> Iterator[
|
470
|
-
Tuple[str, Any]]:
|
470
|
+
Tuple[Union[str, int], Any]]:
|
471
471
|
"""Iterates the (key, value) pairs of symbolic attributes."""
|
472
472
|
for k in self.sym_keys():
|
473
473
|
yield k, self._sym_getattr(k)
|
@@ -490,7 +490,7 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
490
490
|
if v != pg_typing.MISSING_VALUE])))
|
491
491
|
|
492
492
|
def _sym_getattr( # pytype: disable=signature-mismatch # overriding-parameter-type-checks
|
493
|
-
self, key: str) -> Any:
|
493
|
+
self, key: Union[str, int]) -> Any:
|
494
494
|
"""Gets symbolic attribute by key."""
|
495
495
|
return super().__getitem__(key)
|
496
496
|
|
@@ -524,11 +524,11 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
524
524
|
v.sym_setpath(object_utils.KeyPath(k, new_path))
|
525
525
|
|
526
526
|
def _set_item_without_permission_check( # pytype: disable=signature-mismatch # overriding-parameter-type-checks
|
527
|
-
self, key: str, value: Any) -> Optional[base.FieldUpdate]:
|
527
|
+
self, key: Union[str, int], value: Any) -> Optional[base.FieldUpdate]:
|
528
528
|
"""Set item without permission check."""
|
529
|
-
if not isinstance(key, str):
|
529
|
+
if not isinstance(key, (str, int)):
|
530
530
|
raise KeyError(self._error_message(
|
531
|
-
f'Key must be string type. Encountered {key!r}.'))
|
531
|
+
f'Key must be string or int type. Encountered {key!r}.'))
|
532
532
|
|
533
533
|
old_value = self.get(key, pg_typing.MISSING_VALUE)
|
534
534
|
if old_value is value:
|
@@ -545,7 +545,7 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
545
545
|
container_cls = self.__class__
|
546
546
|
raise KeyError(
|
547
547
|
self._error_message(
|
548
|
-
f'Key
|
548
|
+
f'Key {key!r} is not allowed for {container_cls}.'))
|
549
549
|
|
550
550
|
# Detach old value from object tree.
|
551
551
|
if isinstance(old_value, base.TopologyAware):
|
@@ -575,9 +575,11 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
575
575
|
return base.FieldUpdate(
|
576
576
|
self.sym_path + key, target, field, old_value, new_value)
|
577
577
|
|
578
|
-
def _formalized_value(
|
579
|
-
|
580
|
-
|
578
|
+
def _formalized_value(
|
579
|
+
self, name: Union[str, int],
|
580
|
+
field: Optional[pg_typing.Field],
|
581
|
+
value: Any
|
582
|
+
) -> Any:
|
581
583
|
"""Get transformed (formal) value from user input."""
|
582
584
|
allow_partial = base.accepts_partial(self)
|
583
585
|
if field and pg_typing.MISSING_VALUE == value:
|
@@ -625,14 +627,14 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
625
627
|
"""Customizes pickle.load."""
|
626
628
|
self.__init__(state['value'], **state['kwargs'])
|
627
629
|
|
628
|
-
def __getitem__(self, key: str) -> Any:
|
630
|
+
def __getitem__(self, key: Union[str, int]) -> Any:
|
629
631
|
"""Get item in this Dict."""
|
630
632
|
try:
|
631
633
|
return self.sym_inferred(key)
|
632
634
|
except AttributeError as e:
|
633
635
|
raise KeyError(key) from e
|
634
636
|
|
635
|
-
def __setitem__(self, key: str, value: Any) -> None:
|
637
|
+
def __setitem__(self, key: Union[str, int], value: Any) -> None:
|
636
638
|
"""Set item in this Dict.
|
637
639
|
|
638
640
|
Args:
|
@@ -733,11 +735,11 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
733
735
|
"""Iterate keys in field declaration order."""
|
734
736
|
return self.sym_keys()
|
735
737
|
|
736
|
-
def keys(self) -> Iterator[str]: # pytype: disable=signature-mismatch
|
738
|
+
def keys(self) -> Iterator[Union[str, int]]: # pytype: disable=signature-mismatch
|
737
739
|
"""Returns an iterator of keys in current dict."""
|
738
740
|
return self.sym_keys()
|
739
741
|
|
740
|
-
def items(self) -> Iterator[Tuple[str, Any]]: # pytype: disable=signature-mismatch
|
742
|
+
def items(self) -> Iterator[Tuple[Union[str, int], Any]]: # pytype: disable=signature-mismatch
|
741
743
|
"""Returns an iterator of (key, value) items in current dict."""
|
742
744
|
return self.sym_items()
|
743
745
|
|
@@ -750,7 +752,7 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
750
752
|
return self.sym_clone(deep=False)
|
751
753
|
|
752
754
|
def pop(
|
753
|
-
self, key:
|
755
|
+
self, key: Union[str, int], default: Any = base.RAISE_IF_NOT_FOUND # pylint: disable=protected-access
|
754
756
|
) -> Any:
|
755
757
|
"""Pops a key from current dict."""
|
756
758
|
if key in self:
|
@@ -762,7 +764,7 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
762
764
|
raise KeyError(key)
|
763
765
|
return default
|
764
766
|
|
765
|
-
def popitem(self) -> Tuple[str, Any]:
|
767
|
+
def popitem(self) -> Tuple[Union[str, int], Any]:
|
766
768
|
if self._value_spec is not None:
|
767
769
|
raise ValueError(
|
768
770
|
'\'popitem\' cannot be performed on a Dict with value spec.')
|
@@ -781,7 +783,7 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
781
783
|
if value_spec:
|
782
784
|
self.use_value_spec(value_spec, self._allow_partial)
|
783
785
|
|
784
|
-
def setdefault(self, key: str, default: Any = None) -> Any:
|
786
|
+
def setdefault(self, key: Union[str, int], default: Any = None) -> Any:
|
785
787
|
"""Sets default as the value to key if not present."""
|
786
788
|
value = pg_typing.MISSING_VALUE
|
787
789
|
if key in self:
|
@@ -791,12 +793,15 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
791
793
|
value = default
|
792
794
|
return value
|
793
795
|
|
794
|
-
def update(
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
796
|
+
def update(
|
797
|
+
self,
|
798
|
+
other: Union[
|
799
|
+
None,
|
800
|
+
typing.Dict[Union[str, int], Any],
|
801
|
+
Iterable[Tuple[Union[str, int], Any]]
|
802
|
+
] = None,
|
803
|
+
**kwargs
|
804
|
+
) -> None: # pytype: disable=signature-mismatch
|
800
805
|
"""Update Dict with the same semantic as update on standard dict."""
|
801
806
|
updates = dict(other) if other else {}
|
802
807
|
updates.update(kwargs)
|
@@ -807,9 +812,10 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
807
812
|
self,
|
808
813
|
hide_frozen: bool = True,
|
809
814
|
hide_default_values: bool = False,
|
810
|
-
exclude_keys: Optional[Sequence[str]] = None,
|
815
|
+
exclude_keys: Optional[Sequence[Union[str, int]]] = None,
|
811
816
|
use_inferred: bool = False,
|
812
|
-
**kwargs
|
817
|
+
**kwargs
|
818
|
+
) -> object_utils.JSONValueType:
|
813
819
|
"""Converts current object to a dict with plain Python objects."""
|
814
820
|
exclude_keys = set(exclude_keys or [])
|
815
821
|
if self._value_spec and self._value_spec.schema:
|
@@ -896,8 +902,8 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
896
902
|
hide_frozen: bool = True,
|
897
903
|
hide_default_values: bool = False,
|
898
904
|
hide_missing_values: bool = False,
|
899
|
-
include_keys: Optional[Set[str]] = None,
|
900
|
-
exclude_keys: Optional[Set[str]] = None,
|
905
|
+
include_keys: Optional[Set[Union[str, int]]] = None,
|
906
|
+
exclude_keys: Optional[Set[Union[str, int]]] = None,
|
901
907
|
use_inferred: bool = False,
|
902
908
|
cls_name: Optional[str] = None,
|
903
909
|
bracket_type: object_utils.BracketType = object_utils.BracketType.CURLY,
|
@@ -963,10 +969,12 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
963
969
|
extra_blankline_for_field_docstr=extra_blankline_for_field_docstr,
|
964
970
|
**kwargs)
|
965
971
|
if not python_format or key_as_attribute:
|
966
|
-
|
972
|
+
if isinstance(k, int):
|
973
|
+
k = f'[{k}]'
|
974
|
+
item = f'{k}={v_str}'
|
967
975
|
else:
|
968
|
-
|
969
|
-
|
976
|
+
item = f'{k!r}: {v_str}'
|
977
|
+
kv_strs.append(item)
|
970
978
|
s.append(', '.join(kv_strs))
|
971
979
|
s.append(close_bracket)
|
972
980
|
else:
|
@@ -993,15 +1001,19 @@ class Dict(dict, base.Symbolic, pg_typing.CustomTyping):
|
|
993
1001
|
use_inferred=use_inferred,
|
994
1002
|
extra_blankline_for_field_docstr=extra_blankline_for_field_docstr,
|
995
1003
|
**kwargs)
|
1004
|
+
|
996
1005
|
if not python_format:
|
997
1006
|
# Format in PyGlove's format (default).
|
998
|
-
|
1007
|
+
if isinstance(k, int):
|
1008
|
+
k = f'[{k}]'
|
1009
|
+
item = f'{k} = {v_str}'
|
999
1010
|
elif key_as_attribute:
|
1000
1011
|
# Format `pg.Objects` under Python format.
|
1001
|
-
|
1012
|
+
item = f'{k}={v_str}'
|
1002
1013
|
else:
|
1003
1014
|
# Format regular `pg.Dict` under Python format.
|
1004
|
-
|
1015
|
+
item = f'{k!r}: {v_str}'
|
1016
|
+
s.append(_indent(item, root_indent + 1))
|
1005
1017
|
s.append('\n')
|
1006
1018
|
s.append(_indent(close_bracket, root_indent))
|
1007
1019
|
return ''.join(s)
|
@@ -44,14 +44,14 @@ class DictTest(unittest.TestCase):
|
|
44
44
|
self.assertEqual(len(sd), 0)
|
45
45
|
|
46
46
|
# Schemaless dict created from a regular dict.
|
47
|
-
sd = Dict({'a': 1})
|
47
|
+
sd = Dict({'a': 1, 1: 1})
|
48
48
|
self.assertIsNone(sd.value_spec)
|
49
|
-
self.assertEqual(sd,
|
49
|
+
self.assertEqual(sd, {'a': 1, 1: 1})
|
50
50
|
|
51
51
|
# Schemaless dict created from key value pairs.
|
52
|
-
sd = Dict((('a', 1),))
|
52
|
+
sd = Dict((('a', 1), (1, 1)))
|
53
53
|
self.assertIsNone(sd.value_spec)
|
54
|
-
self.assertEqual(sd,
|
54
|
+
self.assertEqual(sd, {'a': 1, 1: 1})
|
55
55
|
|
56
56
|
# Schemaless dict created from keyword args.
|
57
57
|
sd = Dict(a=1)
|
@@ -59,9 +59,9 @@ class DictTest(unittest.TestCase):
|
|
59
59
|
self.assertEqual(sd, dict(a=1))
|
60
60
|
|
61
61
|
# Schemaless dict created from both a regular dict and keyword args.
|
62
|
-
sd = Dict({'a': 1}, a=2)
|
62
|
+
sd = Dict({'a': 1, 1: 1}, a=2)
|
63
63
|
self.assertIsNone(sd.value_spec)
|
64
|
-
self.assertEqual(sd,
|
64
|
+
self.assertEqual(sd, {'a': 2, 1: 1})
|
65
65
|
|
66
66
|
# Schematized dict.
|
67
67
|
vs = pg_typing.Dict([('a', pg_typing.Int())])
|
@@ -305,8 +305,8 @@ class DictTest(unittest.TestCase):
|
|
305
305
|
with self.assertRaisesRegex(
|
306
306
|
base.WritePermissionError, 'Cannot modify field of a sealed Dict.'):
|
307
307
|
sd['b'] = 1
|
308
|
-
with self.assertRaisesRegex(KeyError, 'Key must be string type'):
|
309
|
-
sd[0] = 1
|
308
|
+
with self.assertRaisesRegex(KeyError, 'Key must be string or int type'):
|
309
|
+
sd[0.5] = 1
|
310
310
|
|
311
311
|
# Set item in a schematized dict.
|
312
312
|
sd = Dict(value_spec=pg_typing.Dict([('a', pg_typing.Int(default=0))]))
|
@@ -786,35 +786,45 @@ class DictTest(unittest.TestCase):
|
|
786
786
|
sd.missing_values(flatten=False), {'a': MISSING_VALUE})
|
787
787
|
|
788
788
|
# A non-schematized dict has a schematized child.
|
789
|
-
sd = Dict(x
|
789
|
+
sd = Dict({'x': sd, 1: sd})
|
790
790
|
self.assertIsNone(sd.value_spec)
|
791
|
-
self.assertEqual(
|
791
|
+
self.assertEqual(
|
792
|
+
sd.missing_values(), {'x.a': MISSING_VALUE, '[1].a': MISSING_VALUE}
|
793
|
+
)
|
792
794
|
|
793
795
|
def test_sym_has(self):
|
794
|
-
sd = Dict(x
|
796
|
+
sd = Dict({'x': 1, 1: dict(a=3), 'y': Dict({'z': 2, 2: 3})})
|
795
797
|
self.assertTrue(sd.sym_has('x'))
|
798
|
+
self.assertTrue(sd.sym_has(1))
|
796
799
|
self.assertTrue(sd.sym_has('y.z'))
|
800
|
+
self.assertTrue(sd.sym_has('[1].a'))
|
801
|
+
self.assertTrue(sd.sym_has('y[2]'))
|
797
802
|
self.assertTrue(sd.sym_has(object_utils.KeyPath.parse('y.z')))
|
798
803
|
self.assertFalse(sd.sym_has('x.z'))
|
799
804
|
|
800
805
|
def test_sym_get(self):
|
801
|
-
sd = Dict(x
|
806
|
+
sd = Dict({'x': 1, 1: dict(a=3), 'y': Dict({'z': 2, 2: 3})})
|
802
807
|
self.assertEqual(sd.sym_get('x'), 1)
|
808
|
+
self.assertEqual(sd.sym_get(1), dict(a=3))
|
803
809
|
self.assertEqual(sd.sym_get('y.z'), 2)
|
810
|
+
self.assertEqual(sd.sym_get('[1].a'), 3)
|
811
|
+
self.assertEqual(sd.sym_get('y[2]'), 3)
|
804
812
|
self.assertIsNone(sd.sym_get('x.z', None))
|
805
813
|
with self.assertRaisesRegex(
|
806
814
|
KeyError, 'Cannot query sub-key \'z\' of object.'):
|
807
815
|
sd.sym_get('x.z')
|
808
816
|
|
809
817
|
def test_sym_hasattr(self):
|
810
|
-
sd = Dict(x
|
818
|
+
sd = Dict({'x': 1, 1: dict(a=3), 'y': Dict({'z': 2, 2: 3})})
|
811
819
|
self.assertTrue(sd.sym_hasattr('x'))
|
820
|
+
self.assertTrue(sd.sym_hasattr(1))
|
812
821
|
self.assertFalse(sd.sym_hasattr('y.z'))
|
813
822
|
self.assertFalse(sd.sym_hasattr('a'))
|
814
823
|
|
815
824
|
def test_sym_getattr(self):
|
816
|
-
sd = Dict(x
|
825
|
+
sd = Dict({'x': 1, 1: dict(a=3), 'y': Dict({'z': 2, 2: 3})})
|
817
826
|
self.assertEqual(sd.sym_getattr('x'), 1)
|
827
|
+
self.assertEqual(sd.sym_getattr(1), dict(a=3))
|
818
828
|
self.assertIsNone(sd.sym_getattr('a', None))
|
819
829
|
with self.assertRaisesRegex(
|
820
830
|
AttributeError,
|
@@ -832,8 +842,9 @@ class DictTest(unittest.TestCase):
|
|
832
842
|
with self.assertRaisesRegex(AttributeError, 'z'):
|
833
843
|
_ = sd.sym_inferred('z')
|
834
844
|
|
835
|
-
sd = Dict(y=1, x=
|
845
|
+
sd = Dict(y=1, x={'x': Dict(y=inferred.ValueFromParentChain()), 1: 2})
|
836
846
|
self.assertEqual(sd.x.x.y, 1)
|
847
|
+
self.assertEqual(sd.x[1], 2)
|
837
848
|
|
838
849
|
def test_sym_field(self):
|
839
850
|
sd = Dict(x=1, y=Dict(z=2))
|
@@ -859,9 +870,9 @@ class DictTest(unittest.TestCase):
|
|
859
870
|
self.assertIs(sd.sym_attr_field('x'), spec.schema.get_field('x'))
|
860
871
|
|
861
872
|
def test_sym_keys(self):
|
862
|
-
sd = Dict(x
|
873
|
+
sd = Dict({'x': 1, 'y': 2, 1: 3})
|
863
874
|
self.assertEqual(next(sd.sym_keys()), 'x')
|
864
|
-
self.assertEqual(list(sd.sym_keys()), ['x', 'y'])
|
875
|
+
self.assertEqual(list(sd.sym_keys()), ['x', 'y', 1])
|
865
876
|
|
866
877
|
sd = Dict(x=1, z=3, y=2, value_spec=pg_typing.Dict([
|
867
878
|
(pg_typing.StrKey(), pg_typing.Int())
|
@@ -874,9 +885,9 @@ class DictTest(unittest.TestCase):
|
|
874
885
|
self.assertEqual(list(sd.sym_keys()), ['x', 'y'])
|
875
886
|
|
876
887
|
def test_sym_values(self):
|
877
|
-
sd = Dict(x
|
888
|
+
sd = Dict({'x': 1, 'y': 2, 1: 3})
|
878
889
|
self.assertEqual(next(sd.sym_values()), 1)
|
879
|
-
self.assertEqual(list(sd.sym_values()), [1, 2])
|
890
|
+
self.assertEqual(list(sd.sym_values()), [1, 2, 3])
|
880
891
|
|
881
892
|
sd = Dict(x=1, z=3, y=2, value_spec=pg_typing.Dict([
|
882
893
|
(pg_typing.StrKey(), pg_typing.Int())
|
@@ -891,9 +902,9 @@ class DictTest(unittest.TestCase):
|
|
891
902
|
)
|
892
903
|
|
893
904
|
def test_sym_items(self):
|
894
|
-
sd = Dict(x
|
905
|
+
sd = Dict({'x': 1, 'y': 2, 1: 3})
|
895
906
|
self.assertEqual(next(sd.sym_items()), ('x', 1))
|
896
|
-
self.assertEqual(list(sd.sym_items()), [('x', 1), ('y', 2)])
|
907
|
+
self.assertEqual(list(sd.sym_items()), [('x', 1), ('y', 2), (1, 3)])
|
897
908
|
|
898
909
|
sd = Dict(x=1, z=3, y=2, value_spec=pg_typing.Dict([
|
899
910
|
(pg_typing.StrKey(), pg_typing.Int())
|
@@ -917,9 +928,9 @@ class DictTest(unittest.TestCase):
|
|
917
928
|
|
918
929
|
def test_sym_rebind(self):
|
919
930
|
# Refer to RebindTest for more detailed tests.
|
920
|
-
sd = Dict(x
|
921
|
-
sd.sym_rebind(x=2)
|
922
|
-
self.assertEqual(sd,
|
931
|
+
sd = Dict({'x': 1, 'y': 2, 1: 3})
|
932
|
+
sd.sym_rebind({1: 4}, x=2)
|
933
|
+
self.assertEqual(sd, {'x': 2, 'y': 2, 1: 4})
|
923
934
|
|
924
935
|
def test_sym_clone(self):
|
925
936
|
class A:
|
@@ -997,6 +1008,12 @@ class DictTest(unittest.TestCase):
|
|
997
1008
|
y: int = 1
|
998
1009
|
use_symbolic_comparison = True
|
999
1010
|
|
1011
|
+
sd = Dict({'x': A(2), 1: B(2)})
|
1012
|
+
self.assertEqual(sd.sym_nondefault(), {'x.x': 2, '[1].y': 2})
|
1013
|
+
|
1014
|
+
sd = Dict({'x': A(1), 1: B(1)})
|
1015
|
+
self.assertEqual(sd.sym_nondefault(), {'x.x': 1})
|
1016
|
+
|
1000
1017
|
sd = Dict(x=1, y=dict(a1=A(1)), value_spec=pg_typing.Dict([
|
1001
1018
|
('x', pg_typing.Int(default=0)),
|
1002
1019
|
('y', pg_typing.Dict([
|
@@ -1080,7 +1097,7 @@ class DictTest(unittest.TestCase):
|
|
1080
1097
|
self.assertTrue(Dict().sym_eq(Dict()))
|
1081
1098
|
self.assertTrue(base.eq(Dict(), Dict()))
|
1082
1099
|
|
1083
|
-
self.assertEqual(Dict(a
|
1100
|
+
self.assertEqual(Dict({'a': 1, 1: 2}), Dict({'a': 1, 1: 2}))
|
1084
1101
|
self.assertTrue(Dict(a=1).sym_eq(Dict(a=1)))
|
1085
1102
|
self.assertTrue(base.eq(Dict(a=1), Dict(a=1)))
|
1086
1103
|
self.assertTrue(
|
@@ -1130,6 +1147,8 @@ class DictTest(unittest.TestCase):
|
|
1130
1147
|
self.assertTrue(base.ne(Dict(), 1))
|
1131
1148
|
self.assertNotEqual(Dict(), Dict(a=1))
|
1132
1149
|
self.assertTrue(base.ne(Dict(), Dict(a=1)))
|
1150
|
+
self.assertNotEqual(Dict({1: 1}), Dict({'1': 1}))
|
1151
|
+
self.assertTrue(base.ne(Dict({1: 1}), Dict({'1': 1})))
|
1133
1152
|
self.assertNotEqual(Dict(a=0), Dict(a=1))
|
1134
1153
|
self.assertTrue(base.ne(Dict(a=0), Dict(a=1)))
|
1135
1154
|
|
@@ -1192,6 +1211,7 @@ class DictTest(unittest.TestCase):
|
|
1192
1211
|
self.assertEqual(hash(Dict(a=1)), hash(Dict(a=1)))
|
1193
1212
|
self.assertEqual(hash(Dict(a=dict(x=1))), hash(Dict(a=dict(x=1))))
|
1194
1213
|
self.assertNotEqual(hash(Dict()), hash(Dict(a=1)))
|
1214
|
+
self.assertNotEqual(hash(Dict({'1': 1})), hash(Dict({1: 1})))
|
1195
1215
|
self.assertNotEqual(hash(Dict(a=1)), hash(Dict(a=2)))
|
1196
1216
|
|
1197
1217
|
class A:
|
@@ -1845,10 +1865,6 @@ class RebindTest(unittest.TestCase):
|
|
1845
1865
|
ValueError, 'There are no values to rebind.'):
|
1846
1866
|
Dict().rebind({})
|
1847
1867
|
|
1848
|
-
with self.assertRaisesRegex(
|
1849
|
-
KeyError, 'Key must be string type. Encountered 1'):
|
1850
|
-
Dict().rebind({1: 1})
|
1851
|
-
|
1852
1868
|
with self.assertRaisesRegex(
|
1853
1869
|
ValueError, 'Required value is not specified.'):
|
1854
1870
|
Dict(a=1, value_spec=pg_typing.Dict([('a', pg_typing.Int())])).rebind({
|
@@ -1859,13 +1875,17 @@ class SerializationTest(unittest.TestCase):
|
|
1859
1875
|
"""Dedicated tests for `pg.Dict` serialization."""
|
1860
1876
|
|
1861
1877
|
def test_schemaless(self):
|
1862
|
-
sd = Dict()
|
1878
|
+
sd = Dict({1: 2})
|
1863
1879
|
sd.b = 0
|
1864
1880
|
sd.c = None
|
1865
1881
|
sd.a = 'foo'
|
1866
1882
|
|
1867
1883
|
# Key order is preserved.
|
1868
|
-
self.assertEqual(
|
1884
|
+
self.assertEqual(
|
1885
|
+
sd.to_json_str(),
|
1886
|
+
'{"n_:1": 2, "b": 0, "c": null, "a": "foo"}'
|
1887
|
+
)
|
1888
|
+
self.assertEqual(base.from_json_str(sd.to_json_str()), sd)
|
1869
1889
|
|
1870
1890
|
def test_schematized(self):
|
1871
1891
|
sd = Dict.partial(
|
@@ -2188,6 +2208,7 @@ class FormatTest(unittest.TestCase):
|
|
2188
2208
|
}"""))
|
2189
2209
|
|
2190
2210
|
def test_noncompact_verbose_with_extra_blankline_for_field_docstr(self):
|
2211
|
+
self.maxDiff = None
|
2191
2212
|
self.assertEqual(
|
2192
2213
|
self._dict.format(
|
2193
2214
|
compact=False, verbose=True, extra_blankline_for_field_docstr=True),
|
@@ -2306,6 +2327,25 @@ class FormatTest(unittest.TestCase):
|
|
2306
2327
|
'{\n x = 1,\n y = True,\n z = {\n v = 1,\n w = True\n }\n}'
|
2307
2328
|
)
|
2308
2329
|
|
2330
|
+
def test_compact_int_key(self):
|
2331
|
+
d = Dict({1: 1, '1': 2})
|
2332
|
+
self.assertEqual(d.format(compact=True), '{[1]=1, 1=2}')
|
2333
|
+
self.assertEqual(
|
2334
|
+
d.format(compact=True, python_format=True),
|
2335
|
+
'{1: 1, \'1\': 2}'
|
2336
|
+
)
|
2337
|
+
|
2338
|
+
def test_non_compact_int_key(self):
|
2339
|
+
d = Dict({1: 1, '1': 2})
|
2340
|
+
self.assertEqual(
|
2341
|
+
d.format(compact=False, verbose=False),
|
2342
|
+
'{\n [1] = 1,\n 1 = 2\n}'
|
2343
|
+
)
|
2344
|
+
self.assertEqual(
|
2345
|
+
d.format(compact=False, python_format=True),
|
2346
|
+
'{\n 1: 1,\n \'1\': 2\n}'
|
2347
|
+
)
|
2348
|
+
|
2309
2349
|
|
2310
2350
|
def _on_change_callback(updates):
|
2311
2351
|
del updates
|
pyglove/core/symbolic/object.py
CHANGED
@@ -164,11 +164,22 @@ class ObjectMeta(abc.ABCMeta):
|
|
164
164
|
if key is None:
|
165
165
|
continue
|
166
166
|
|
167
|
+
# Skip class-level attributes that are not symbolic fields.
|
168
|
+
if typing.get_origin(attr_annotation) is typing.ClassVar:
|
169
|
+
continue
|
170
|
+
|
167
171
|
field = pg_typing.Field.from_annotation(key, attr_annotation)
|
168
172
|
if isinstance(key, pg_typing.ConstStrKey):
|
169
173
|
attr_value = cls.__dict__.get(attr_name, pg_typing.MISSING_VALUE)
|
170
174
|
if attr_value != pg_typing.MISSING_VALUE:
|
171
175
|
field.value.set_default(attr_value)
|
176
|
+
|
177
|
+
if (field.value.frozen and
|
178
|
+
field.value.default is
|
179
|
+
pg_typing.value_specs._FROZEN_VALUE_PLACEHOLDER): # pylint: disable=protected-access
|
180
|
+
raise TypeError(
|
181
|
+
f'Field {field.key!r} is marked as final but has no default value.'
|
182
|
+
)
|
172
183
|
fields.append(field)
|
173
184
|
|
174
185
|
# Trigger event so subclass could modify the fields.
|
@@ -324,6 +324,26 @@ class ObjectTest(unittest.TestCase):
|
|
324
324
|
self.assertEqual(e.x, 1)
|
325
325
|
self.assertEqual(e.y, 3)
|
326
326
|
|
327
|
+
class F(Object):
|
328
|
+
x: typing.Literal[1, 'a']
|
329
|
+
|
330
|
+
class G(F):
|
331
|
+
x: typing.Final[int] = 1
|
332
|
+
y: typing.ClassVar[int] = 2
|
333
|
+
|
334
|
+
self.assertEqual(G().x, 1)
|
335
|
+
self.assertEqual(G.y, 2)
|
336
|
+
|
337
|
+
with self.assertRaisesRegex(
|
338
|
+
ValueError, 'Frozen field is not assignable'):
|
339
|
+
G(x=2)
|
340
|
+
|
341
|
+
with self.assertRaisesRegex(
|
342
|
+
TypeError, 'Field x is marked as final but has no default value'):
|
343
|
+
|
344
|
+
class H(Object): # pylint: disable=unused-variable
|
345
|
+
x: typing.Final[int] # pylint: disable=invalid-name
|
346
|
+
|
327
347
|
def test_init_arg_list(self):
|
328
348
|
|
329
349
|
def _update_init_arg_list(cls, init_arg_list):
|
@@ -2349,7 +2369,7 @@ class RebindTest(unittest.TestCase):
|
|
2349
2369
|
A(1).rebind({})
|
2350
2370
|
|
2351
2371
|
with self.assertRaisesRegex(
|
2352
|
-
KeyError, 'Key
|
2372
|
+
KeyError, 'Key 1 is not allowed for .*'):
|
2353
2373
|
A(1).rebind({1: 1})
|
2354
2374
|
|
2355
2375
|
with self.assertRaisesRegex(
|