pyglove 0.4.5.dev202502040809__py3-none-any.whl → 0.4.5.dev202502050809__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyglove/core/symbolic/object.py +4 -1
- pyglove/core/typing/__init__.py +3 -0
- pyglove/core/typing/annotation_conversion.py +196 -6
- pyglove/core/typing/annotation_conversion_test.py +141 -8
- pyglove/core/typing/annotation_future_test.py +135 -0
- pyglove/core/typing/class_schema.py +44 -22
- pyglove/core/typing/class_schema_test.py +15 -7
- {pyglove-0.4.5.dev202502040809.dist-info → pyglove-0.4.5.dev202502050809.dist-info}/METADATA +1 -1
- {pyglove-0.4.5.dev202502040809.dist-info → pyglove-0.4.5.dev202502050809.dist-info}/RECORD +12 -11
- {pyglove-0.4.5.dev202502040809.dist-info → pyglove-0.4.5.dev202502050809.dist-info}/LICENSE +0 -0
- {pyglove-0.4.5.dev202502040809.dist-info → pyglove-0.4.5.dev202502050809.dist-info}/WHEEL +0 -0
- {pyglove-0.4.5.dev202502040809.dist-info → pyglove-0.4.5.dev202502050809.dist-info}/top_level.txt +0 -0
pyglove/core/symbolic/object.py
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
import abc
|
17
17
|
import functools
|
18
18
|
import inspect
|
19
|
+
import sys
|
19
20
|
import typing
|
20
21
|
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
|
21
22
|
|
@@ -154,7 +155,9 @@ class ObjectMeta(abc.ABCMeta):
|
|
154
155
|
if typing.get_origin(attr_annotation) is typing.ClassVar:
|
155
156
|
continue
|
156
157
|
|
157
|
-
field = pg_typing.Field.from_annotation(
|
158
|
+
field = pg_typing.Field.from_annotation(
|
159
|
+
key, attr_annotation, parent_module=sys.modules[cls.__module__]
|
160
|
+
)
|
158
161
|
if isinstance(key, pg_typing.ConstStrKey):
|
159
162
|
attr_value = cls.__dict__.get(attr_name, pg_typing.MISSING_VALUE)
|
160
163
|
if attr_value != pg_typing.MISSING_VALUE:
|
pyglove/core/typing/__init__.py
CHANGED
@@ -375,6 +375,9 @@ import pyglove.core.typing.annotation_conversion # pylint: disable=unused-impor
|
|
375
375
|
# Interface for custom typing.
|
376
376
|
from pyglove.core.typing.custom_typing import CustomTyping
|
377
377
|
|
378
|
+
# Annotation conversion
|
379
|
+
from pyglove.core.typing.annotation_conversion import annotation_from_str
|
380
|
+
|
378
381
|
# Callable signature.
|
379
382
|
from pyglove.core.typing.callable_signature import Argument
|
380
383
|
from pyglove.core.typing.callable_signature import CallableType
|
@@ -13,11 +13,13 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Conversion from annotations to PyGlove value specs."""
|
15
15
|
|
16
|
+
import builtins
|
16
17
|
import collections
|
17
18
|
import inspect
|
18
19
|
import types
|
19
20
|
import typing
|
20
21
|
|
22
|
+
from pyglove.core import coding
|
21
23
|
from pyglove.core import utils
|
22
24
|
from pyglove.core.typing import annotated
|
23
25
|
from pyglove.core.typing import class_schema
|
@@ -35,6 +37,191 @@ _Annotated = getattr(typing, 'Annotated', None) # pylint: disable=invalid-name
|
|
35
37
|
_UnionType = getattr(types, 'UnionType', None) # pylint: disable=invalid-name
|
36
38
|
|
37
39
|
|
40
|
+
def annotation_from_str(
|
41
|
+
annotation_str: str,
|
42
|
+
parent_module: typing.Optional[types.ModuleType] = None,
|
43
|
+
) -> typing.Any:
|
44
|
+
"""Parses annotations from str.
|
45
|
+
|
46
|
+
BNF for PyType annotations:
|
47
|
+
|
48
|
+
```
|
49
|
+
<maybe_union> ::= <type> | <type> "|" <maybe_union>
|
50
|
+
<type> ::= <literal_type> | <non_literal_type>
|
51
|
+
|
52
|
+
<literal_type> ::= "Literal"<literal_params>
|
53
|
+
<literal_params> ::= "["<python_values>"]" (parsed by `pg.coding.evaluate`)
|
54
|
+
|
55
|
+
<non_literal_type> ::= <type_id> | <type_id>"["<type_arg>"]"
|
56
|
+
<type_arg> ::= <maybe_type_list> | <maybe_type_list>","<maybe_type_list>
|
57
|
+
<maybe_type_list> ::= "["<type_arg>"]" | <maybe_union>
|
58
|
+
<type_id> ::= 'aAz_.1-9'
|
59
|
+
```
|
60
|
+
|
61
|
+
Args:
|
62
|
+
annotation_str: String form of type annotations. E.g. "list[str]"
|
63
|
+
parent_module: The module where the annotation was defined.
|
64
|
+
|
65
|
+
Returns:
|
66
|
+
Object form of the annotation.
|
67
|
+
|
68
|
+
Raises:
|
69
|
+
SyntaxError: If the annotation string is invalid.
|
70
|
+
"""
|
71
|
+
s = annotation_str
|
72
|
+
context = dict(pos=0)
|
73
|
+
|
74
|
+
def _eof() -> bool:
|
75
|
+
return context['pos'] == len(s)
|
76
|
+
|
77
|
+
def _pos() -> int:
|
78
|
+
return context['pos']
|
79
|
+
|
80
|
+
def _next(n: int = 1, offset: int = 0) -> str:
|
81
|
+
if _eof():
|
82
|
+
return '<EOF>'
|
83
|
+
return s[_pos() + offset:_pos() + offset + n]
|
84
|
+
|
85
|
+
def _advance(n: int) -> None:
|
86
|
+
context['pos'] += n
|
87
|
+
|
88
|
+
def _error_illustration() -> str:
|
89
|
+
return f'{s}\n{" " * _pos()}' + '^'
|
90
|
+
|
91
|
+
def _match(ch) -> bool:
|
92
|
+
if _next(len(ch)) == ch:
|
93
|
+
_advance(len(ch))
|
94
|
+
return True
|
95
|
+
return False
|
96
|
+
|
97
|
+
def _skip_whitespaces() -> None:
|
98
|
+
while _next() in ' \t':
|
99
|
+
_advance(1)
|
100
|
+
|
101
|
+
def _maybe_union():
|
102
|
+
t = _type()
|
103
|
+
while not _eof():
|
104
|
+
_skip_whitespaces()
|
105
|
+
if _match('|'):
|
106
|
+
t = t | _type()
|
107
|
+
else:
|
108
|
+
break
|
109
|
+
return t
|
110
|
+
|
111
|
+
def _type():
|
112
|
+
type_id = _type_id()
|
113
|
+
t = _resolve(type_id)
|
114
|
+
if t is typing.Literal:
|
115
|
+
return t[_literal_params()]
|
116
|
+
elif _match('['):
|
117
|
+
arg = _type_arg()
|
118
|
+
if not _match(']'):
|
119
|
+
raise SyntaxError(
|
120
|
+
f'Expected "]" at position {_pos()}.\n\n' + _error_illustration()
|
121
|
+
)
|
122
|
+
return t[arg]
|
123
|
+
return t
|
124
|
+
|
125
|
+
def _literal_params():
|
126
|
+
if not _match('['):
|
127
|
+
raise SyntaxError(
|
128
|
+
f'Expected "[" at position {_pos()}.\n\n' + _error_illustration()
|
129
|
+
)
|
130
|
+
arg_start = _pos()
|
131
|
+
in_str = False
|
132
|
+
escape_mode = False
|
133
|
+
num_open_bracket = 1
|
134
|
+
|
135
|
+
while num_open_bracket > 0:
|
136
|
+
ch = _next()
|
137
|
+
if _eof():
|
138
|
+
raise SyntaxError(
|
139
|
+
f'Unexpected end of annotation at position {_pos()}.\n\n'
|
140
|
+
+ _error_illustration()
|
141
|
+
)
|
142
|
+
if ch == '\\':
|
143
|
+
escape_mode = not escape_mode
|
144
|
+
else:
|
145
|
+
escape_mode = False
|
146
|
+
|
147
|
+
if ch == "'" and not escape_mode:
|
148
|
+
in_str = not in_str
|
149
|
+
elif not in_str:
|
150
|
+
if ch == '[':
|
151
|
+
num_open_bracket += 1
|
152
|
+
elif ch == ']':
|
153
|
+
num_open_bracket -= 1
|
154
|
+
_advance(1)
|
155
|
+
|
156
|
+
arg_str = s[arg_start:_pos() - 1]
|
157
|
+
return coding.evaluate(
|
158
|
+
'(' + arg_str + ')', permission=coding.CodePermission.BASIC
|
159
|
+
)
|
160
|
+
|
161
|
+
def _type_arg():
|
162
|
+
t_args = []
|
163
|
+
t_args.append(_maybe_type_list())
|
164
|
+
while _match(','):
|
165
|
+
t_args.append(_maybe_type_list())
|
166
|
+
return tuple(t_args) if len(t_args) > 1 else t_args[0]
|
167
|
+
|
168
|
+
def _maybe_type_list():
|
169
|
+
if _match('['):
|
170
|
+
ret = _type_arg()
|
171
|
+
if not _match(']'):
|
172
|
+
raise SyntaxError(
|
173
|
+
f'Expected "]" at position {_pos()}.\n\n' + _error_illustration()
|
174
|
+
)
|
175
|
+
return list(ret) if isinstance(ret, tuple) else [ret]
|
176
|
+
return _maybe_union()
|
177
|
+
|
178
|
+
def _type_id() -> str:
|
179
|
+
_skip_whitespaces()
|
180
|
+
if _match('...'):
|
181
|
+
return '...'
|
182
|
+
start = _pos()
|
183
|
+
while not _eof():
|
184
|
+
c = _next()
|
185
|
+
if c.isalnum() or c in '_.':
|
186
|
+
_advance(1)
|
187
|
+
else:
|
188
|
+
break
|
189
|
+
t_id = s[start:_pos()]
|
190
|
+
if not all(x.isidentifier() for x in t_id.split('.')):
|
191
|
+
raise SyntaxError(
|
192
|
+
f'Expected type identifier, got {t_id!r} at position {start}.\n\n'
|
193
|
+
+ _error_illustration()
|
194
|
+
)
|
195
|
+
return t_id
|
196
|
+
|
197
|
+
def _resolve(type_id: str):
|
198
|
+
def _resolve_name(name: str, parent_obj: typing.Any):
|
199
|
+
if name == 'None':
|
200
|
+
return None
|
201
|
+
if parent_obj is not None and hasattr(parent_obj, name):
|
202
|
+
return getattr(parent_obj, name)
|
203
|
+
if hasattr(builtins, name):
|
204
|
+
return getattr(builtins, name)
|
205
|
+
if type_id == '...':
|
206
|
+
return ...
|
207
|
+
return utils.MISSING_VALUE
|
208
|
+
parent_obj = parent_module
|
209
|
+
for name in type_id.split('.'):
|
210
|
+
parent_obj = _resolve_name(name, parent_obj)
|
211
|
+
if parent_obj == utils.MISSING_VALUE:
|
212
|
+
return typing.ForwardRef( # pytype: disable=not-callable
|
213
|
+
type_id, False, parent_module
|
214
|
+
)
|
215
|
+
return parent_obj
|
216
|
+
|
217
|
+
root = _maybe_union()
|
218
|
+
if _pos() != len(s):
|
219
|
+
raise SyntaxError(
|
220
|
+
'Unexpected end of annotation.\n\n' + _error_illustration()
|
221
|
+
)
|
222
|
+
return root
|
223
|
+
|
224
|
+
|
38
225
|
def _field_from_annotation(
|
39
226
|
key: typing.Union[str, class_schema.KeySpec],
|
40
227
|
annotation: typing.Any,
|
@@ -91,7 +278,12 @@ def _value_spec_from_type_annotation(
|
|
91
278
|
parent_module: typing.Optional[types.ModuleType] = None
|
92
279
|
) -> class_schema.ValueSpec:
|
93
280
|
"""Creates a value spec from type annotation."""
|
94
|
-
if annotation
|
281
|
+
if isinstance(annotation, str) and not accept_value_as_annotation:
|
282
|
+
annotation = annotation_from_str(annotation, parent_module)
|
283
|
+
|
284
|
+
if annotation is None:
|
285
|
+
return vs.Object(type(None))
|
286
|
+
elif annotation is bool:
|
95
287
|
return vs.Bool()
|
96
288
|
elif annotation is int:
|
97
289
|
return vs.Int()
|
@@ -193,10 +385,7 @@ def _value_spec_from_type_annotation(
|
|
193
385
|
elif (
|
194
386
|
inspect.isclass(annotation)
|
195
387
|
or pg_inspect.is_generic(annotation)
|
196
|
-
or (isinstance(annotation, str) and not accept_value_as_annotation)
|
197
388
|
):
|
198
|
-
if isinstance(annotation, str) and parent_module is not None:
|
199
|
-
annotation = class_schema.ForwardRef(parent_module, annotation)
|
200
389
|
return vs.Object(annotation)
|
201
390
|
|
202
391
|
if accept_value_as_annotation:
|
@@ -227,11 +416,12 @@ def _value_spec_from_annotation(
|
|
227
416
|
return annotation
|
228
417
|
elif annotation == inspect.Parameter.empty:
|
229
418
|
return vs.Any()
|
230
|
-
|
419
|
+
|
420
|
+
if annotation is None:
|
231
421
|
if accept_value_as_annotation:
|
232
422
|
return vs.Any().noneable()
|
233
423
|
else:
|
234
|
-
return vs.
|
424
|
+
return vs.Object(type(None))
|
235
425
|
|
236
426
|
if auto_typing:
|
237
427
|
return _value_spec_from_type_annotation(
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 The PyGlove Authors
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -11,13 +11,12 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for pyglove.core.typing.annotation_conversion."""
|
15
|
-
|
16
14
|
import inspect
|
17
15
|
import sys
|
18
16
|
import typing
|
19
17
|
import unittest
|
20
18
|
|
19
|
+
from pyglove.core import coding
|
21
20
|
from pyglove.core.typing import annotated
|
22
21
|
from pyglove.core.typing import annotation_conversion
|
23
22
|
from pyglove.core.typing import key_specs as ks
|
@@ -31,6 +30,129 @@ class Foo:
|
|
31
30
|
pass
|
32
31
|
|
33
32
|
|
33
|
+
_MODULE = sys.modules[__name__]
|
34
|
+
|
35
|
+
|
36
|
+
class AnnotationFromStrTest(unittest.TestCase):
|
37
|
+
"""Tests for annotation_from_str."""
|
38
|
+
|
39
|
+
def test_basic_types(self):
|
40
|
+
self.assertIsNone(annotation_conversion.annotation_from_str('None'))
|
41
|
+
self.assertEqual(annotation_conversion.annotation_from_str('str'), str)
|
42
|
+
self.assertEqual(annotation_conversion.annotation_from_str('int'), int)
|
43
|
+
self.assertEqual(annotation_conversion.annotation_from_str('float'), float)
|
44
|
+
self.assertEqual(annotation_conversion.annotation_from_str('bool'), bool)
|
45
|
+
self.assertEqual(annotation_conversion.annotation_from_str('list'), list)
|
46
|
+
self.assertEqual(
|
47
|
+
annotation_conversion.annotation_from_str('list[int]'), list[int]
|
48
|
+
)
|
49
|
+
self.assertEqual(annotation_conversion.annotation_from_str('tuple'), tuple)
|
50
|
+
self.assertEqual(
|
51
|
+
annotation_conversion.annotation_from_str('tuple[int]'), tuple[int]
|
52
|
+
)
|
53
|
+
self.assertEqual(
|
54
|
+
annotation_conversion.annotation_from_str('tuple[int, ...]'),
|
55
|
+
tuple[int, ...]
|
56
|
+
)
|
57
|
+
self.assertEqual(
|
58
|
+
annotation_conversion.annotation_from_str('tuple[int, str]'),
|
59
|
+
tuple[int, str]
|
60
|
+
)
|
61
|
+
|
62
|
+
def test_generic_types(self):
|
63
|
+
self.assertEqual(
|
64
|
+
annotation_conversion.annotation_from_str('typing.List[str]', _MODULE),
|
65
|
+
typing.List[str]
|
66
|
+
)
|
67
|
+
|
68
|
+
def test_union(self):
|
69
|
+
self.assertEqual(
|
70
|
+
annotation_conversion.annotation_from_str(
|
71
|
+
'typing.Union[str, typing.Union[int, float]]', _MODULE),
|
72
|
+
typing.Union[str, int, float]
|
73
|
+
)
|
74
|
+
if sys.version_info >= (3, 10):
|
75
|
+
self.assertEqual(
|
76
|
+
annotation_conversion.annotation_from_str(
|
77
|
+
'str | int | float', _MODULE),
|
78
|
+
typing.Union[str, int, float]
|
79
|
+
)
|
80
|
+
|
81
|
+
def test_literal(self):
|
82
|
+
self.assertEqual(
|
83
|
+
annotation_conversion.annotation_from_str(
|
84
|
+
'typing.Literal[1, True, "a", \'"b"\', "\\"c\\"", "\\\\"]',
|
85
|
+
_MODULE
|
86
|
+
),
|
87
|
+
typing.Literal[1, True, 'a', '"b"', '"c"', '\\']
|
88
|
+
)
|
89
|
+
self.assertEqual(
|
90
|
+
annotation_conversion.annotation_from_str(
|
91
|
+
'typing.Literal[(1, 1), f"A {[1]}"]', _MODULE),
|
92
|
+
typing.Literal[(1, 1), 'A [1]']
|
93
|
+
)
|
94
|
+
with self.assertRaisesRegex(SyntaxError, 'Expected "\\["'):
|
95
|
+
annotation_conversion.annotation_from_str('typing.Literal', _MODULE)
|
96
|
+
|
97
|
+
with self.assertRaisesRegex(SyntaxError, 'Unexpected end of annotation'):
|
98
|
+
annotation_conversion.annotation_from_str('typing.Literal[1', _MODULE)
|
99
|
+
|
100
|
+
with self.assertRaisesRegex(
|
101
|
+
coding.CodeError, 'Function definition is not allowed'
|
102
|
+
):
|
103
|
+
annotation_conversion.annotation_from_str(
|
104
|
+
'typing.Literal[lambda x: x]', _MODULE
|
105
|
+
)
|
106
|
+
|
107
|
+
def test_callable(self):
|
108
|
+
self.assertEqual(
|
109
|
+
annotation_conversion.annotation_from_str(
|
110
|
+
'typing.Callable[int, int]', _MODULE),
|
111
|
+
typing.Callable[[int], int]
|
112
|
+
)
|
113
|
+
self.assertEqual(
|
114
|
+
annotation_conversion.annotation_from_str(
|
115
|
+
'typing.Callable[[int], int]', _MODULE),
|
116
|
+
typing.Callable[[int], int]
|
117
|
+
)
|
118
|
+
self.assertEqual(
|
119
|
+
annotation_conversion.annotation_from_str(
|
120
|
+
'typing.Callable[..., None]', _MODULE),
|
121
|
+
typing.Callable[..., None]
|
122
|
+
)
|
123
|
+
|
124
|
+
def test_forward_ref(self):
|
125
|
+
self.assertEqual(
|
126
|
+
annotation_conversion.annotation_from_str(
|
127
|
+
'AAA', _MODULE),
|
128
|
+
typing.ForwardRef(
|
129
|
+
'AAA', False, _MODULE
|
130
|
+
)
|
131
|
+
)
|
132
|
+
self.assertEqual(
|
133
|
+
annotation_conversion.annotation_from_str(
|
134
|
+
'typing.List[AAA]', _MODULE),
|
135
|
+
typing.List[
|
136
|
+
typing.ForwardRef(
|
137
|
+
'AAA', False, _MODULE
|
138
|
+
)
|
139
|
+
]
|
140
|
+
)
|
141
|
+
|
142
|
+
def test_bad_annotation(self):
|
143
|
+
with self.assertRaisesRegex(SyntaxError, 'Expected type identifier'):
|
144
|
+
annotation_conversion.annotation_from_str('typing.List[]')
|
145
|
+
|
146
|
+
with self.assertRaisesRegex(SyntaxError, 'Expected "]"'):
|
147
|
+
annotation_conversion.annotation_from_str('typing.List[int')
|
148
|
+
|
149
|
+
with self.assertRaisesRegex(SyntaxError, 'Unexpected end of annotation'):
|
150
|
+
annotation_conversion.annotation_from_str('typing.List[int]1', _MODULE)
|
151
|
+
|
152
|
+
with self.assertRaisesRegex(SyntaxError, 'Expected "]"'):
|
153
|
+
annotation_conversion.annotation_from_str('typing.Callable[[x')
|
154
|
+
|
155
|
+
|
34
156
|
class FieldFromAnnotationTest(unittest.TestCase):
|
35
157
|
"""Tests for Field.fromAnnotation."""
|
36
158
|
|
@@ -132,17 +254,24 @@ class ValueSpecFromAnnotationTest(unittest.TestCase):
|
|
132
254
|
|
133
255
|
def test_none(self):
|
134
256
|
self.assertEqual(
|
135
|
-
ValueSpec.from_annotation(None, False), vs.
|
257
|
+
ValueSpec.from_annotation(None, False), vs.Object(type(None)))
|
136
258
|
self.assertEqual(
|
137
|
-
ValueSpec.from_annotation(None, True), vs.
|
259
|
+
ValueSpec.from_annotation('None', True), vs.Object(type(None)))
|
138
260
|
self.assertEqual(
|
139
|
-
ValueSpec.from_annotation(
|
140
|
-
|
261
|
+
ValueSpec.from_annotation(None, True), vs.Object(type(None)))
|
262
|
+
self.assertEqual(
|
263
|
+
ValueSpec.from_annotation(None, accept_value_as_annotation=True),
|
264
|
+
vs.Any().noneable()
|
265
|
+
)
|
141
266
|
|
142
267
|
def test_any(self):
|
143
268
|
self.assertEqual(
|
144
269
|
ValueSpec.from_annotation(typing.Any, False),
|
145
270
|
vs.Any(annotation=typing.Any))
|
271
|
+
self.assertEqual(
|
272
|
+
ValueSpec.from_annotation('typing.Any', True, parent_module=_MODULE),
|
273
|
+
vs.Any(annotation=typing.Any)
|
274
|
+
)
|
146
275
|
self.assertEqual(
|
147
276
|
ValueSpec.from_annotation(typing.Any, True),
|
148
277
|
vs.Any(annotation=typing.Any))
|
@@ -152,6 +281,7 @@ class ValueSpecFromAnnotationTest(unittest.TestCase):
|
|
152
281
|
|
153
282
|
def test_bool(self):
|
154
283
|
self.assertEqual(ValueSpec.from_annotation(bool, True), vs.Bool())
|
284
|
+
self.assertEqual(ValueSpec.from_annotation('bool', True), vs.Bool())
|
155
285
|
self.assertEqual(
|
156
286
|
ValueSpec.from_annotation(bool, False), vs.Any(annotation=bool))
|
157
287
|
self.assertEqual(
|
@@ -159,6 +289,7 @@ class ValueSpecFromAnnotationTest(unittest.TestCase):
|
|
159
289
|
|
160
290
|
def test_int(self):
|
161
291
|
self.assertEqual(ValueSpec.from_annotation(int, True), vs.Int())
|
292
|
+
self.assertEqual(ValueSpec.from_annotation('int', True), vs.Int())
|
162
293
|
self.assertEqual(ValueSpec.from_annotation(int, True, True), vs.Int())
|
163
294
|
self.assertEqual(
|
164
295
|
ValueSpec.from_annotation(int, False), vs.Any(annotation=int))
|
@@ -182,7 +313,9 @@ class ValueSpecFromAnnotationTest(unittest.TestCase):
|
|
182
313
|
ValueSpec.from_annotation(str, False, True), vs.Any(annotation=str))
|
183
314
|
|
184
315
|
self.assertEqual(
|
185
|
-
ValueSpec.from_annotation('A', False, False),
|
316
|
+
ValueSpec.from_annotation('A', False, False),
|
317
|
+
vs.Any(annotation='A')
|
318
|
+
)
|
186
319
|
self.assertEqual(
|
187
320
|
ValueSpec.from_annotation('A', False, True), vs.Str('A'))
|
188
321
|
self.assertEqual(
|
@@ -0,0 +1,135 @@
|
|
1
|
+
# Copyright 2025 The PyGlove Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from __future__ import annotations
|
16
|
+
import sys
|
17
|
+
import typing
|
18
|
+
from typing import List, Literal, Union
|
19
|
+
import unittest
|
20
|
+
|
21
|
+
from pyglove.core import symbolic as pg
|
22
|
+
from pyglove.core.typing import key_specs as ks
|
23
|
+
from pyglove.core.typing import value_specs as vs
|
24
|
+
|
25
|
+
|
26
|
+
class AnnotationFutureConversionTest(unittest.TestCase):
|
27
|
+
|
28
|
+
# Class with forward declaration must not be defined in functions.
|
29
|
+
class A(pg.Object):
|
30
|
+
a: typing.Optional[AnnotationFutureConversionTest.A]
|
31
|
+
b: List[AnnotationFutureConversionTest.A]
|
32
|
+
|
33
|
+
def assert_value_spec(self, cls, field_name, expected_value_spec):
|
34
|
+
self.assertEqual(cls.__schema__[field_name].value, expected_value_spec)
|
35
|
+
|
36
|
+
def test_basics(self):
|
37
|
+
|
38
|
+
class Foo(pg.Object):
|
39
|
+
a: int
|
40
|
+
b: float
|
41
|
+
c: bool
|
42
|
+
d: str
|
43
|
+
e: typing.Any
|
44
|
+
f: typing.Dict[str, typing.Any]
|
45
|
+
g: typing.List[str]
|
46
|
+
h: typing.Tuple[int, int]
|
47
|
+
i: typing.Callable[[int, int], None]
|
48
|
+
|
49
|
+
self.assert_value_spec(Foo, 'a', vs.Int())
|
50
|
+
self.assert_value_spec(Foo, 'b', vs.Float())
|
51
|
+
self.assert_value_spec(Foo, 'c', vs.Bool())
|
52
|
+
self.assert_value_spec(Foo, 'd', vs.Str())
|
53
|
+
self.assert_value_spec(Foo, 'e', vs.Any(annotation=typing.Any))
|
54
|
+
self.assert_value_spec(
|
55
|
+
Foo, 'f', vs.Dict([(ks.StrKey(), vs.Any(annotation=typing.Any))])
|
56
|
+
)
|
57
|
+
self.assert_value_spec(Foo, 'g', vs.List(vs.Str()))
|
58
|
+
self.assert_value_spec(Foo, 'h', vs.Tuple([vs.Int(), vs.Int()]))
|
59
|
+
self.assert_value_spec(
|
60
|
+
Foo, 'i',
|
61
|
+
vs.Callable([vs.Int(), vs.Int()], returns=vs.Object(type(None)))
|
62
|
+
)
|
63
|
+
|
64
|
+
def test_list(self):
|
65
|
+
if sys.version_info >= (3, 10):
|
66
|
+
|
67
|
+
class Bar(pg.Object):
|
68
|
+
x: list[int | None]
|
69
|
+
|
70
|
+
self.assert_value_spec(Bar, 'x', vs.List(vs.Int().noneable()))
|
71
|
+
|
72
|
+
def test_var_length_tuple(self):
|
73
|
+
|
74
|
+
class Foo(pg.Object):
|
75
|
+
x: typing.Tuple[int, ...]
|
76
|
+
|
77
|
+
self.assert_value_spec(Foo, 'x', vs.Tuple(vs.Int()))
|
78
|
+
|
79
|
+
if sys.version_info >= (3, 10):
|
80
|
+
|
81
|
+
class Bar(pg.Object):
|
82
|
+
x: tuple[int, ...]
|
83
|
+
|
84
|
+
self.assert_value_spec(Bar, 'x', vs.Tuple(vs.Int()))
|
85
|
+
|
86
|
+
def test_optional(self):
|
87
|
+
|
88
|
+
class Foo(pg.Object):
|
89
|
+
x: typing.Optional[int]
|
90
|
+
|
91
|
+
self.assert_value_spec(Foo, 'x', vs.Int().noneable())
|
92
|
+
|
93
|
+
if sys.version_info >= (3, 10):
|
94
|
+
class Bar(pg.Object):
|
95
|
+
x: int | None
|
96
|
+
|
97
|
+
self.assert_value_spec(Bar, 'x', vs.Int().noneable())
|
98
|
+
|
99
|
+
def test_union(self):
|
100
|
+
|
101
|
+
class Foo(pg.Object):
|
102
|
+
x: Union[int, typing.Union[str, bool], None]
|
103
|
+
|
104
|
+
self.assert_value_spec(
|
105
|
+
Foo, 'x', vs.Union([vs.Int(), vs.Str(), vs.Bool()]).noneable()
|
106
|
+
)
|
107
|
+
|
108
|
+
if sys.version_info >= (3, 10):
|
109
|
+
|
110
|
+
class Bar(pg.Object):
|
111
|
+
x: int | str | bool
|
112
|
+
|
113
|
+
self.assert_value_spec(
|
114
|
+
Bar, 'x', vs.Union([vs.Int(), vs.Str(), vs.Bool()])
|
115
|
+
)
|
116
|
+
|
117
|
+
def test_literal(self):
|
118
|
+
|
119
|
+
class Foo(pg.Object):
|
120
|
+
x: Literal[1, True, 'abc']
|
121
|
+
|
122
|
+
self.assert_value_spec(
|
123
|
+
Foo, 'x', vs.Enum(vs.MISSING_VALUE, [1, True, 'abc'])
|
124
|
+
)
|
125
|
+
|
126
|
+
def test_self_referencial(self):
|
127
|
+
self.assert_value_spec(
|
128
|
+
self.A, 'a', vs.Object(self.A).noneable()
|
129
|
+
)
|
130
|
+
self.assert_value_spec(
|
131
|
+
self.A, 'b', vs.List(vs.Object(self.A))
|
132
|
+
)
|
133
|
+
|
134
|
+
if __name__ == '__main__':
|
135
|
+
unittest.main()
|
@@ -97,9 +97,10 @@ class KeySpec(utils.Formattable, utils.JSONConvertible):
|
|
97
97
|
class ForwardRef(utils.Formattable):
|
98
98
|
"""Forward type reference."""
|
99
99
|
|
100
|
-
def __init__(self, module: types.ModuleType,
|
100
|
+
def __init__(self, module: types.ModuleType, qualname: str):
|
101
101
|
self._module = module
|
102
|
-
self.
|
102
|
+
self._qualname = qualname
|
103
|
+
self._resolved_value = None
|
103
104
|
|
104
105
|
@property
|
105
106
|
def module(self) -> types.ModuleType:
|
@@ -109,35 +110,54 @@ class ForwardRef(utils.Formattable):
|
|
109
110
|
@property
|
110
111
|
def name(self) -> str:
|
111
112
|
"""Returns the name of the type reference."""
|
112
|
-
return self.
|
113
|
+
return self._qualname.split('.')[-1]
|
113
114
|
|
114
115
|
@property
|
115
116
|
def qualname(self) -> str:
|
116
117
|
"""Returns the qualified name of the reference."""
|
117
|
-
return
|
118
|
+
return self._qualname
|
119
|
+
|
120
|
+
@property
|
121
|
+
def type_id(self) -> str:
|
122
|
+
"""Returns the type id of the reference."""
|
123
|
+
return f'{self.module.__name__}.{self.qualname}'
|
118
124
|
|
119
125
|
def as_annotation(self) -> Union[Type[Any], str]:
|
120
126
|
"""Returns the forward reference as an annotation."""
|
121
|
-
return self.cls if self.resolved else self.
|
127
|
+
return self.cls if self.resolved else self.qualname
|
122
128
|
|
123
129
|
@property
|
124
130
|
def resolved(self) -> bool:
|
125
131
|
"""Returns True if the symbol for the name is resolved.."""
|
126
|
-
|
132
|
+
if self._resolved_value is None:
|
133
|
+
self._resolved_value = self._resolve()
|
134
|
+
return self._resolved_value is not None
|
135
|
+
|
136
|
+
def _resolve(self) -> Optional[Any]:
|
137
|
+
names = self._qualname.split('.')
|
138
|
+
parent_obj = self.module
|
139
|
+
for name in names:
|
140
|
+
parent_obj = getattr(parent_obj, name, utils.MISSING_VALUE)
|
141
|
+
if parent_obj == utils.MISSING_VALUE:
|
142
|
+
return None
|
143
|
+
if not inspect.isclass(parent_obj):
|
144
|
+
raise TypeError(
|
145
|
+
f'{self.qualname!r} from module {self.module.__name__!r} '
|
146
|
+
'is not a class.'
|
147
|
+
)
|
148
|
+
return parent_obj
|
127
149
|
|
128
150
|
@property
|
129
151
|
def cls(self) -> Type[Any]:
|
130
152
|
"""Returns the resolved reference class.."""
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
)
|
140
|
-
return reference
|
153
|
+
if self._resolved_value is None:
|
154
|
+
self._resolved_value = self._resolve()
|
155
|
+
if self._resolved_value is None:
|
156
|
+
raise TypeError(
|
157
|
+
f'{self.qualname!r} does not exist in '
|
158
|
+
f'module {self.module.__name__!r}'
|
159
|
+
)
|
160
|
+
return self._resolved_value
|
141
161
|
|
142
162
|
def format(
|
143
163
|
self,
|
@@ -150,7 +170,7 @@ class ForwardRef(utils.Formattable):
|
|
150
170
|
return utils.kvlist_str(
|
151
171
|
[
|
152
172
|
('module', self.module.__name__, None),
|
153
|
-
('name', self.
|
173
|
+
('name', self.qualname, None),
|
154
174
|
],
|
155
175
|
label=self.__class__.__name__,
|
156
176
|
compact=compact,
|
@@ -164,7 +184,7 @@ class ForwardRef(utils.Formattable):
|
|
164
184
|
if self is other:
|
165
185
|
return True
|
166
186
|
elif isinstance(other, ForwardRef):
|
167
|
-
return self.module is other.module and self.
|
187
|
+
return self.module is other.module and self.qualname == other.qualname
|
168
188
|
elif inspect.isclass(other):
|
169
189
|
return self.resolved and self.cls is other # pytype: disable=bad-return-type
|
170
190
|
|
@@ -173,11 +193,11 @@ class ForwardRef(utils.Formattable):
|
|
173
193
|
return not self.__eq__(other)
|
174
194
|
|
175
195
|
def __hash__(self) -> int:
|
176
|
-
return hash((self.module, self.
|
196
|
+
return hash((self.module, self.qualname))
|
177
197
|
|
178
198
|
def __deepcopy__(self, memo) -> 'ForwardRef':
|
179
199
|
"""Override deep copy to avoid copying module."""
|
180
|
-
return ForwardRef(self.module, self.
|
200
|
+
return ForwardRef(self.module, self.qualname)
|
181
201
|
|
182
202
|
|
183
203
|
class ValueSpec(utils.Formattable, utils.JSONConvertible):
|
@@ -628,9 +648,11 @@ class Field(utils.Formattable, utils.JSONConvertible):
|
|
628
648
|
annotation: Any,
|
629
649
|
description: Optional[str] = None,
|
630
650
|
metadata: Optional[Dict[str, Any]] = None,
|
631
|
-
auto_typing=True
|
651
|
+
auto_typing=True,
|
652
|
+
parent_module: Optional[types.ModuleType] = None
|
653
|
+
) -> 'Field':
|
632
654
|
"""Gets a Field from annotation."""
|
633
|
-
del key, annotation, description, metadata, auto_typing
|
655
|
+
del key, annotation, description, metadata, auto_typing, parent_module
|
634
656
|
assert False, 'Overridden in `annotation_conversion.py`.'
|
635
657
|
|
636
658
|
@property
|
@@ -31,24 +31,31 @@ from pyglove.core.typing.class_schema import Schema
|
|
31
31
|
class ForwardRefTest(unittest.TestCase):
|
32
32
|
"""Test for `ForwardRef` class."""
|
33
33
|
|
34
|
+
class A:
|
35
|
+
pass
|
36
|
+
|
34
37
|
def setUp(self):
|
35
38
|
super().setUp()
|
36
39
|
self._module = sys.modules[__name__]
|
37
40
|
|
38
41
|
def test_basics(self):
|
39
|
-
r = class_schema.ForwardRef(self._module, '
|
42
|
+
r = class_schema.ForwardRef(self._module, 'ForwardRefTest.A')
|
40
43
|
self.assertIs(r.module, self._module)
|
41
|
-
self.assertEqual(r.name, '
|
42
|
-
self.assertEqual(r.qualname,
|
44
|
+
self.assertEqual(r.name, 'A')
|
45
|
+
self.assertEqual(r.qualname, 'ForwardRefTest.A')
|
46
|
+
self.assertEqual(r.type_id, f'{self._module.__name__}.ForwardRefTest.A')
|
43
47
|
|
44
48
|
def test_resolved(self):
|
45
|
-
self.assertTrue(
|
49
|
+
self.assertTrue(
|
50
|
+
class_schema.ForwardRef(self._module, 'ForwardRefTest.A').resolved
|
51
|
+
)
|
46
52
|
self.assertFalse(class_schema.ForwardRef(self._module, 'Foo').resolved)
|
47
53
|
|
48
54
|
def test_as_annotation(self):
|
49
55
|
self.assertEqual(
|
50
|
-
class_schema.ForwardRef(
|
51
|
-
|
56
|
+
class_schema.ForwardRef(
|
57
|
+
self._module, 'ForwardRefTest.A').as_annotation(),
|
58
|
+
ForwardRefTest.A,
|
52
59
|
)
|
53
60
|
self.assertEqual(
|
54
61
|
class_schema.ForwardRef(self._module, 'Foo').as_annotation(), 'Foo'
|
@@ -56,7 +63,8 @@ class ForwardRefTest(unittest.TestCase):
|
|
56
63
|
|
57
64
|
def test_cls(self):
|
58
65
|
self.assertIs(
|
59
|
-
class_schema.ForwardRef(self._module, '
|
66
|
+
class_schema.ForwardRef(self._module, 'ForwardRefTest.A').cls,
|
67
|
+
ForwardRefTest.A
|
60
68
|
)
|
61
69
|
|
62
70
|
with self.assertRaisesRegex(TypeError, '.* does not exist in module'):
|
@@ -88,7 +88,7 @@ pyglove/core/symbolic/inferred.py,sha256=E4zgphg6NNZad9Fl3jdHQOMZeqEp9XHq5OUYqXE
|
|
88
88
|
pyglove/core/symbolic/inferred_test.py,sha256=G6uPykONcChvs6vZujXHSWaYfjewLTVBscMqzzKNty0,1270
|
89
89
|
pyglove/core/symbolic/list.py,sha256=z8goU0ntd-Q5ADaCGiKsJwPhdRdQb0Kd_p-ZekXaLy4,30303
|
90
90
|
pyglove/core/symbolic/list_test.py,sha256=IAyFQ48nyczKUcPNZFKHBkX5oh7Xuxbnv3rRkONhbHw,61146
|
91
|
-
pyglove/core/symbolic/object.py,sha256=
|
91
|
+
pyglove/core/symbolic/object.py,sha256=6bCg20r76sfHfVA-BBQYFxTYBzOM_c8BihHEKMKt3i8,42093
|
92
92
|
pyglove/core/symbolic/object_test.py,sha256=vIL_ymGpPPXamIljEBdufpw-p82kTeWNy2IbgzGQjig,93811
|
93
93
|
pyglove/core/symbolic/origin.py,sha256=OSWMKjvPcISOXrzuX3lCQC8m_qaGl-9INsIB81erUnU,6124
|
94
94
|
pyglove/core/symbolic/origin_test.py,sha256=dU_ZGrGDetM_lYVMn3wQO0d367_t_t8eESe3NrKPBNE,3159
|
@@ -106,17 +106,18 @@ pyglove/core/tuning/protocols.py,sha256=10Iukt1rqh05caURTZffSsb3CcHo7epBQnNtnyMy
|
|
106
106
|
pyglove/core/tuning/protocols_test.py,sha256=Cbzvz3EacaW2sbm1rTSQXEt_VucMoQbeQ6AeN-GV5Vc,1883
|
107
107
|
pyglove/core/tuning/sample.py,sha256=UzsCY8kiqnzH_mR94zLXhOloyvvEwfmBWluBjmefUFA,12975
|
108
108
|
pyglove/core/tuning/sample_test.py,sha256=JqwDPy3EPC_VjU9dipk90jj1kovZB3Zb9hAjAlZ-U1U,17551
|
109
|
-
pyglove/core/typing/__init__.py,sha256=
|
109
|
+
pyglove/core/typing/__init__.py,sha256=MwraQ-6K8NYI_xk4V3oaVSkrPqiU9InDtAgXPSRBvik,14494
|
110
110
|
pyglove/core/typing/annotated.py,sha256=llaajIDj9GK-4kUGJoO4JsHU6ESPOra2SZ-jG6xmsOQ,3203
|
111
111
|
pyglove/core/typing/annotated_test.py,sha256=p1qid3R-jeiOTTxOVq6hXW8XFvn-h1cUzJWISPst2l8,2484
|
112
|
-
pyglove/core/typing/annotation_conversion.py,sha256=
|
113
|
-
pyglove/core/typing/annotation_conversion_test.py,sha256=
|
112
|
+
pyglove/core/typing/annotation_conversion.py,sha256=8q4-7uo12DZj1_AnDzF4UWQoplmHHtZ0lYIOlY8xVXo,13889
|
113
|
+
pyglove/core/typing/annotation_conversion_test.py,sha256=Zl0LSL6z8exwgflb4kDJW2QfrvbixJMW4_7lXcHMl14,16271
|
114
|
+
pyglove/core/typing/annotation_future_test.py,sha256=ZVLU3kheO2V-nrp-m5jrGMLgGMs4S6RWnJbyD9KbHl4,3746
|
114
115
|
pyglove/core/typing/callable_ext.py,sha256=PiBQWPeUAH7Lgmf2xKCZqgK7N0OSrTdbnEkV8Ph31OA,9127
|
115
116
|
pyglove/core/typing/callable_ext_test.py,sha256=TnWKU4_ZjvpbHZFtFHgFvCMDiCos8VmLlODcM_7Xg8M,10156
|
116
117
|
pyglove/core/typing/callable_signature.py,sha256=DRpt7aShfkn8pb3SCiZzS_27eHbkQ_d2UB8BUhJjs0Q,27176
|
117
118
|
pyglove/core/typing/callable_signature_test.py,sha256=iQmHsKPhJPQlMikDhEyxKyq7yWyXI9juKCLYgKhrH3U,25145
|
118
|
-
pyglove/core/typing/class_schema.py,sha256=
|
119
|
-
pyglove/core/typing/class_schema_test.py,sha256=
|
119
|
+
pyglove/core/typing/class_schema.py,sha256=Xf3koEnx150Prqw3WDOP04xr1H8bEkUfMb4AKZcPg3M,54683
|
120
|
+
pyglove/core/typing/class_schema_test.py,sha256=RurRdCyPuypKJ7izgcq9zW3JNHgODiJdQvDn0BDZDjU,29353
|
120
121
|
pyglove/core/typing/custom_typing.py,sha256=qdnIKHWNt5kZAAFdpQXra8bBu6RljMbbJ_YDG2mhAUA,2205
|
121
122
|
pyglove/core/typing/inspect.py,sha256=VLSz1KAunNm2hx0eEMjiwxKLl9FHlKr9nHelLT25iEA,7726
|
122
123
|
pyglove/core/typing/inspect_test.py,sha256=xclevobF0X8c_B5b1q1dkBJZN1TsVA1RUhk5l25DUCM,10248
|
@@ -211,8 +212,8 @@ pyglove/ext/scalars/randoms.py,sha256=LkMIIx7lOq_lvJvVS3BrgWGuWl7Pi91-lA-O8x_gZs
|
|
211
212
|
pyglove/ext/scalars/randoms_test.py,sha256=nEhiqarg8l_5EOucp59CYrpO2uKxS1pe0hmBdZUzRNM,2000
|
212
213
|
pyglove/ext/scalars/step_wise.py,sha256=IDw3tuTpv0KVh7AN44W43zqm1-E0HWPUlytWOQC9w3Y,3789
|
213
214
|
pyglove/ext/scalars/step_wise_test.py,sha256=TL1vJ19xVx2t5HKuyIzGoogF7N3Rm8YhLE6JF7i0iy8,2540
|
214
|
-
pyglove-0.4.5.
|
215
|
-
pyglove-0.4.5.
|
216
|
-
pyglove-0.4.5.
|
217
|
-
pyglove-0.4.5.
|
218
|
-
pyglove-0.4.5.
|
215
|
+
pyglove-0.4.5.dev202502050809.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
216
|
+
pyglove-0.4.5.dev202502050809.dist-info/METADATA,sha256=IiiuFK9RRBh8I9xh9WW4GM8el-96GLZI9uDK6IAVwho,7067
|
217
|
+
pyglove-0.4.5.dev202502050809.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
218
|
+
pyglove-0.4.5.dev202502050809.dist-info/top_level.txt,sha256=wITzJSKcj8GZUkbq-MvUQnFadkiuAv_qv5qQMw0fIow,8
|
219
|
+
pyglove-0.4.5.dev202502050809.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{pyglove-0.4.5.dev202502040809.dist-info → pyglove-0.4.5.dev202502050809.dist-info}/top_level.txt
RENAMED
File without changes
|