pyglove 0.4.5.dev202502040809__py3-none-any.whl → 0.4.5.dev202502060809__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,6 +16,7 @@
16
16
  import abc
17
17
  import functools
18
18
  import inspect
19
+ import sys
19
20
  import typing
20
21
  from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
21
22
 
@@ -154,7 +155,9 @@ class ObjectMeta(abc.ABCMeta):
154
155
  if typing.get_origin(attr_annotation) is typing.ClassVar:
155
156
  continue
156
157
 
157
- field = pg_typing.Field.from_annotation(key, attr_annotation)
158
+ field = pg_typing.Field.from_annotation(
159
+ key, attr_annotation, parent_module=sys.modules[cls.__module__]
160
+ )
158
161
  if isinstance(key, pg_typing.ConstStrKey):
159
162
  attr_value = cls.__dict__.get(attr_name, pg_typing.MISSING_VALUE)
160
163
  if attr_value != pg_typing.MISSING_VALUE:
@@ -375,6 +375,9 @@ import pyglove.core.typing.annotation_conversion # pylint: disable=unused-impor
375
375
  # Interface for custom typing.
376
376
  from pyglove.core.typing.custom_typing import CustomTyping
377
377
 
378
+ # Annotation conversion
379
+ from pyglove.core.typing.annotation_conversion import annotation_from_str
380
+
378
381
  # Callable signature.
379
382
  from pyglove.core.typing.callable_signature import Argument
380
383
  from pyglove.core.typing.callable_signature import CallableType
@@ -13,11 +13,13 @@
13
13
  # limitations under the License.
14
14
  """Conversion from annotations to PyGlove value specs."""
15
15
 
16
+ import builtins
16
17
  import collections
17
18
  import inspect
18
19
  import types
19
20
  import typing
20
21
 
22
+ from pyglove.core import coding
21
23
  from pyglove.core import utils
22
24
  from pyglove.core.typing import annotated
23
25
  from pyglove.core.typing import class_schema
@@ -35,6 +37,231 @@ _Annotated = getattr(typing, 'Annotated', None) # pylint: disable=invalid-name
35
37
  _UnionType = getattr(types, 'UnionType', None) # pylint: disable=invalid-name
36
38
 
37
39
 
40
def annotation_from_str(
    annotation_str: str,
    parent_module: typing.Optional[types.ModuleType] = None,
) -> typing.Any:
  """Parses a type annotation from its string form.

  BNF for PyType annotations:

  ```
  <maybe_union> ::= <type> | <type> "|" <maybe_union>
  <type> ::= <literal_type> | <non_literal_type>

  <literal_type> ::= "Literal"<literal_params>
  <literal_params> ::= "["<python_values>"]" (parsed by `pg.coding.evaluate`)

  <non_literal_type> ::= <type_id> | <type_id>"["<type_arg>"]"
  <type_arg> ::= <maybe_type_list> | <maybe_type_list>","<maybe_type_list>
  <maybe_type_list> ::= "["<type_arg>"]" | <maybe_union>
  <type_id> ::= "..." | dot-separated Python identifiers
  ```

  Name resolution rules:

  * `None` and `...` resolve to `None` and `Ellipsis`.
  * A root name is looked up in `parent_module` first, then in `builtins`.
  * Unresolved names become `typing.ForwardRef` objects bound to
    `parent_module`, except children of a module (e.g. `typing.Foo`), which
    raise `TypeError`.
  * While `parent_module.__reloading__` is truthy, every non-builtin,
    non-module-qualified name is kept as a forward reference.

  Args:
    annotation_str: String form of type annotations. E.g. "list[str]"
    parent_module: The module where the annotation was defined.

  Returns:
    Object form of the annotation.

  Raises:
    SyntaxError: If the annotation string is invalid.
    TypeError: If a module-qualified name does not exist in that module.
  """
  s = annotation_str
  context = dict(pos=0)
  # Private sentinel for failed lookups. It is compared by identity so that we
  # never invoke `__eq__` on arbitrary objects found in the parent module.
  missing = object()

  def _eof() -> bool:
    return context['pos'] >= len(s)

  def _pos() -> int:
    return context['pos']

  def _next(n: int = 1, offset: int = 0) -> str:
    # '<EOF>' never equals a real token and is not a substring of ' \t', so
    # callers need no separate end-of-input check.
    if _eof():
      return '<EOF>'
    return s[_pos() + offset:_pos() + offset + n]

  def _advance(n: int) -> None:
    # Clamp so that skipping an escaped character cannot run past the input.
    context['pos'] = min(len(s), context['pos'] + n)

  def _error_illustration() -> str:
    return f'{s}\n{" " * _pos()}' + '^'

  def _match(ch) -> bool:
    if _next(len(ch)) == ch:
      _advance(len(ch))
      return True
    return False

  def _skip_whitespaces() -> None:
    while _next() in ' \t':
      _advance(1)

  def _maybe_union():
    t = _type()
    while not _eof():
      _skip_whitespaces()
      if _match('|'):
        t = t | _type()
      else:
        break
    return t

  def _type():
    type_id = _type_id()
    t = _resolve(type_id)
    if t is typing.Literal:
      return t[_literal_params()]
    elif _match('['):
      arg = _type_arg()
      if not _match(']'):
        raise SyntaxError(
            f'Expected "]" at position {_pos()}.\n\n' + _error_illustration()
        )
      return t[arg]
    return t

  def _literal_params():
    if not _match('['):
      raise SyntaxError(
          f'Expected "[" at position {_pos()}.\n\n' + _error_illustration()
      )
    arg_start = _pos()
    # Quote character of the string literal being scanned, or None when the
    # scanner is outside of any string literal.
    quote = None
    num_open_bracket = 1

    while num_open_bracket > 0:
      if _eof():
        raise SyntaxError(
            f'Unexpected end of annotation at position {_pos()}.\n\n'
            + _error_illustration()
        )
      ch = _next()
      step = 1
      if quote is not None:
        if ch == '\\':
          # Skip the escaped character so that an escaped quote does not
          # terminate the string.
          step = 2
        elif ch == quote:
          quote = None
      elif ch in ('"', "'"):
        quote = ch
      elif ch == '[':
        num_open_bracket += 1
      elif ch == ']':
        num_open_bracket -= 1
      _advance(step)

    # Exclude the closing bracket that ended the loop.
    arg_str = s[arg_start:_pos() - 1]
    return coding.evaluate(
        '(' + arg_str + ')', permission=coding.CodePermission.BASIC
    )

  def _type_arg():
    t_args = []
    t_args.append(_maybe_type_list())
    while _match(','):
      t_args.append(_maybe_type_list())
    return tuple(t_args) if len(t_args) > 1 else t_args[0]

  def _maybe_type_list():
    # Allow whitespace before a nested list, e.g. `Callable[ [int], int]`.
    _skip_whitespaces()
    if _match('['):
      ret = _type_arg()
      if not _match(']'):
        raise SyntaxError(
            f'Expected "]" at position {_pos()}.\n\n' + _error_illustration()
        )
      return list(ret) if isinstance(ret, tuple) else [ret]
    return _maybe_union()

  def _type_id() -> str:
    _skip_whitespaces()
    if _match('...'):
      return '...'
    start = _pos()
    while not _eof():
      c = _next()
      if c.isalnum() or c in '_.':
        _advance(1)
      else:
        break
    t_id = s[start:_pos()]
    if not all(x.isidentifier() for x in t_id.split('.')):
      raise SyntaxError(
          f'Expected type identifier, got {t_id!r} at position {start}.\n\n'
          + _error_illustration()
      )
    return t_id

  def _resolve(type_id: str):
    # Ellipsis is never a forward reference, even when the module is being
    # reloaded.
    if type_id == '...':
      return ...

    def _as_forward_ref() -> typing.ForwardRef:
      # Keyword arguments keep this call valid on Python versions where
      # `ForwardRef` only accepts `is_argument`/`module` by keyword.
      return typing.ForwardRef(  # pytype: disable=not-callable
          type_id, is_argument=False, module=parent_module
      )

    def _resolve_root(name: str):
      """Returns (object, is_builtin) for the first name of a type id."""
      if name == 'None':
        return None, True
      if parent_module is not None and hasattr(parent_module, name):
        return getattr(parent_module, name), False
      if hasattr(builtins, name):
        return getattr(builtins, name), True
      return missing, False

    reloading = getattr(parent_module, '__reloading__', False)
    names = type_id.split('.')
    root_obj, is_builtin = _resolve_root(names[0])

    # When root object is not found, we should treat it as a forward reference.
    if root_obj is missing:
      return _as_forward_ref()

    if len(names) == 1:
      # When module is being reloaded, we should treat all non-builtin
      # references as forward references.
      if is_builtin or not reloading:
        return root_obj
      return _as_forward_ref()

    # When root object is a module, its children must exist. Otherwise the
    # root is a variable of the current module, whose children may not be
    # defined yet (or may be stale during reloading).
    root_is_module = inspect.ismodule(root_obj)
    if not root_is_module and reloading:
      return _as_forward_ref()

    parent_obj = root_obj
    for name in names[1:]:
      # Children are looked up on their parent only; falling back to builtins
      # here would resolve `typing.int` to `int`.
      parent_obj = getattr(parent_obj, name, missing)
      if parent_obj is missing:
        if root_is_module:
          raise TypeError(f'{type_id!r} does not exist.')
        return _as_forward_ref()
    return parent_obj

  root = _maybe_union()
  if _pos() != len(s):
    raise SyntaxError(
        'Unexpected end of annotation.\n\n' + _error_illustration()
    )
  return root
263
+
264
+
38
265
  def _field_from_annotation(
39
266
  key: typing.Union[str, class_schema.KeySpec],
40
267
  annotation: typing.Any,
@@ -91,7 +318,12 @@ def _value_spec_from_type_annotation(
91
318
  parent_module: typing.Optional[types.ModuleType] = None
92
319
  ) -> class_schema.ValueSpec:
93
320
  """Creates a value spec from type annotation."""
94
- if annotation is bool:
321
+ if isinstance(annotation, str) and not accept_value_as_annotation:
322
+ annotation = annotation_from_str(annotation, parent_module)
323
+
324
+ if annotation is None:
325
+ return vs.Object(type(None))
326
+ elif annotation is bool:
95
327
  return vs.Bool()
96
328
  elif annotation is int:
97
329
  return vs.Int()
@@ -193,10 +425,7 @@ def _value_spec_from_type_annotation(
193
425
  elif (
194
426
  inspect.isclass(annotation)
195
427
  or pg_inspect.is_generic(annotation)
196
- or (isinstance(annotation, str) and not accept_value_as_annotation)
197
428
  ):
198
- if isinstance(annotation, str) and parent_module is not None:
199
- annotation = class_schema.ForwardRef(parent_module, annotation)
200
429
  return vs.Object(annotation)
201
430
 
202
431
  if accept_value_as_annotation:
@@ -227,11 +456,12 @@ def _value_spec_from_annotation(
227
456
  return annotation
228
457
  elif annotation == inspect.Parameter.empty:
229
458
  return vs.Any()
230
- elif annotation is None:
459
+
460
+ if annotation is None:
231
461
  if accept_value_as_annotation:
232
462
  return vs.Any().noneable()
233
463
  else:
234
- return vs.Any().freeze(None)
464
+ return vs.Object(type(None))
235
465
 
236
466
  if auto_typing:
237
467
  return _value_spec_from_type_annotation(
@@ -1,4 +1,4 @@
1
- # Copyright 2023 The PyGlove Authors
1
+ # Copyright 2025 The PyGlove Authors
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -11,13 +11,12 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Tests for pyglove.core.typing.annotation_conversion."""
15
-
16
14
  import inspect
17
15
  import sys
18
16
  import typing
19
17
  import unittest
20
18
 
19
+ from pyglove.core import coding
21
20
  from pyglove.core.typing import annotated
22
21
  from pyglove.core.typing import annotation_conversion
23
22
  from pyglove.core.typing import key_specs as ks
@@ -28,7 +27,168 @@ from pyglove.core.typing.class_schema import ValueSpec
28
27
 
29
28
 
30
29
  class Foo:
31
- pass
30
+ class Bar:
31
+ pass
32
+
33
+
34
+ _MODULE = sys.modules[__name__]
35
+
36
+
37
class AnnotationFromStrTest(unittest.TestCase):
  """Tests for annotation_from_str."""

  def test_basic_types(self):
    """Builtins, subscripted builtins and names from the parent module."""
    self.assertIsNone(annotation_conversion.annotation_from_str('None'))
    self.assertEqual(annotation_conversion.annotation_from_str('str'), str)
    self.assertEqual(annotation_conversion.annotation_from_str('int'), int)
    self.assertEqual(annotation_conversion.annotation_from_str('float'), float)
    self.assertEqual(annotation_conversion.annotation_from_str('bool'), bool)
    self.assertEqual(annotation_conversion.annotation_from_str('list'), list)
    self.assertEqual(
        annotation_conversion.annotation_from_str('list[int]'), list[int]
    )
    self.assertEqual(annotation_conversion.annotation_from_str('tuple'), tuple)
    self.assertEqual(
        annotation_conversion.annotation_from_str('tuple[int]'), tuple[int]
    )
    self.assertEqual(
        annotation_conversion.annotation_from_str('tuple[int, ...]'),
        tuple[int, ...]
    )
    self.assertEqual(
        annotation_conversion.annotation_from_str('tuple[int, str]'),
        tuple[int, str]
    )
    self.assertEqual(
        annotation_conversion.annotation_from_str('list[Foo]', _MODULE),
        list[Foo]
    )
    self.assertEqual(
        annotation_conversion.annotation_from_str('list[Foo.Bar]', _MODULE),
        list[Foo.Bar]
    )
    # An unresolved child of a non-module object stays a forward reference.
    self.assertEqual(
        annotation_conversion.annotation_from_str('list[Foo.Baz]', _MODULE),
        list[typing.ForwardRef('Foo.Baz', False, _MODULE)]
    )

  def test_generic_types(self):
    self.assertEqual(
        annotation_conversion.annotation_from_str('typing.List[str]', _MODULE),
        typing.List[str]
    )

  def test_union(self):
    self.assertEqual(
        annotation_conversion.annotation_from_str(
            'typing.Union[str, typing.Union[int, float]]', _MODULE),
        typing.Union[str, int, float]
    )
    # The `X | Y` syntax is evaluated with the `|` operator on types, which
    # requires Python 3.10+.
    if sys.version_info >= (3, 10):
      self.assertEqual(
          annotation_conversion.annotation_from_str(
              'str | int | float', _MODULE),
          typing.Union[str, int, float]
      )

  def test_literal(self):
    self.assertEqual(
        annotation_conversion.annotation_from_str(
            'typing.Literal[1, True, "a", \'"b"\', "\\"c\\"", "\\\\"]',
            _MODULE
        ),
        typing.Literal[1, True, 'a', '"b"', '"c"', '\\']
    )
    self.assertEqual(
        annotation_conversion.annotation_from_str(
            'typing.Literal[(1, 1), f"A {[1]}"]', _MODULE),
        typing.Literal[(1, 1), 'A [1]']
    )
    with self.assertRaisesRegex(SyntaxError, 'Expected "\\["'):
      annotation_conversion.annotation_from_str('typing.Literal', _MODULE)

    with self.assertRaisesRegex(SyntaxError, 'Unexpected end of annotation'):
      annotation_conversion.annotation_from_str('typing.Literal[1', _MODULE)

    # Literal parameters are evaluated with basic permission only.
    with self.assertRaisesRegex(
        coding.CodeError, 'Function definition is not allowed'
    ):
      annotation_conversion.annotation_from_str(
          'typing.Literal[lambda x: x]', _MODULE
      )

  def test_callable(self):
    self.assertEqual(
        annotation_conversion.annotation_from_str(
            'typing.Callable[int, int]', _MODULE),
        typing.Callable[[int], int]
    )
    self.assertEqual(
        annotation_conversion.annotation_from_str(
            'typing.Callable[[int], int]', _MODULE),
        typing.Callable[[int], int]
    )
    self.assertEqual(
        annotation_conversion.annotation_from_str(
            'typing.Callable[..., None]', _MODULE),
        typing.Callable[..., None]
    )

  def test_forward_ref(self):
    self.assertEqual(
        annotation_conversion.annotation_from_str(
            'AAA', _MODULE),
        typing.ForwardRef(
            'AAA', False, _MODULE
        )
    )
    self.assertEqual(
        annotation_conversion.annotation_from_str(
            'typing.List[AAA]', _MODULE),
        typing.List[
            typing.ForwardRef(
                'AAA', False, _MODULE
            )
        ]
    )

  def test_reloading(self):
    """Non-builtin names stay forward references while reloading."""
    setattr(_MODULE, '__reloading__', True)
    # Always clear the flag: a failed assertion must not leave the module
    # marked as reloading for the tests that run afterwards.
    try:
      self.assertEqual(
          annotation_conversion.annotation_from_str(
              'typing.List[Foo]', _MODULE),
          typing.List[
              typing.ForwardRef(
                  'Foo', False, _MODULE
              )
          ]
      )
      self.assertEqual(
          annotation_conversion.annotation_from_str(
              'typing.List[Foo.Bar]', _MODULE),
          typing.List[
              typing.ForwardRef(
                  'Foo.Bar', False, _MODULE
              )
          ]
      )
    finally:
      delattr(_MODULE, '__reloading__')

  def test_bad_annotation(self):
    with self.assertRaisesRegex(SyntaxError, 'Expected type identifier'):
      annotation_conversion.annotation_from_str('typing.List[]')

    with self.assertRaisesRegex(SyntaxError, 'Expected "]"'):
      annotation_conversion.annotation_from_str('typing.List[int')

    with self.assertRaisesRegex(SyntaxError, 'Unexpected end of annotation'):
      annotation_conversion.annotation_from_str('typing.List[int]1', _MODULE)

    with self.assertRaisesRegex(SyntaxError, 'Expected "]"'):
      annotation_conversion.annotation_from_str('typing.Callable[[x')

    with self.assertRaisesRegex(TypeError, '.* does not exist'):
      annotation_conversion.annotation_from_str('typing.Foo', _MODULE)
32
192
 
33
193
 
34
194
  class FieldFromAnnotationTest(unittest.TestCase):
@@ -132,17 +292,24 @@ class ValueSpecFromAnnotationTest(unittest.TestCase):
132
292
 
133
293
  def test_none(self):
134
294
  self.assertEqual(
135
- ValueSpec.from_annotation(None, False), vs.Any().freeze(None))
295
+ ValueSpec.from_annotation(None, False), vs.Object(type(None)))
136
296
  self.assertEqual(
137
- ValueSpec.from_annotation(None, True), vs.Any().freeze(None))
297
+ ValueSpec.from_annotation('None', True), vs.Object(type(None)))
138
298
  self.assertEqual(
139
- ValueSpec.from_annotation(
140
- None, accept_value_as_annotation=True), vs.Any().noneable())
299
+ ValueSpec.from_annotation(None, True), vs.Object(type(None)))
300
+ self.assertEqual(
301
+ ValueSpec.from_annotation(None, accept_value_as_annotation=True),
302
+ vs.Any().noneable()
303
+ )
141
304
 
142
305
  def test_any(self):
143
306
  self.assertEqual(
144
307
  ValueSpec.from_annotation(typing.Any, False),
145
308
  vs.Any(annotation=typing.Any))
309
+ self.assertEqual(
310
+ ValueSpec.from_annotation('typing.Any', True, parent_module=_MODULE),
311
+ vs.Any(annotation=typing.Any)
312
+ )
146
313
  self.assertEqual(
147
314
  ValueSpec.from_annotation(typing.Any, True),
148
315
  vs.Any(annotation=typing.Any))
@@ -152,6 +319,7 @@ class ValueSpecFromAnnotationTest(unittest.TestCase):
152
319
 
153
320
  def test_bool(self):
154
321
  self.assertEqual(ValueSpec.from_annotation(bool, True), vs.Bool())
322
+ self.assertEqual(ValueSpec.from_annotation('bool', True), vs.Bool())
155
323
  self.assertEqual(
156
324
  ValueSpec.from_annotation(bool, False), vs.Any(annotation=bool))
157
325
  self.assertEqual(
@@ -159,6 +327,7 @@ class ValueSpecFromAnnotationTest(unittest.TestCase):
159
327
 
160
328
  def test_int(self):
161
329
  self.assertEqual(ValueSpec.from_annotation(int, True), vs.Int())
330
+ self.assertEqual(ValueSpec.from_annotation('int', True), vs.Int())
162
331
  self.assertEqual(ValueSpec.from_annotation(int, True, True), vs.Int())
163
332
  self.assertEqual(
164
333
  ValueSpec.from_annotation(int, False), vs.Any(annotation=int))
@@ -182,7 +351,9 @@ class ValueSpecFromAnnotationTest(unittest.TestCase):
182
351
  ValueSpec.from_annotation(str, False, True), vs.Any(annotation=str))
183
352
 
184
353
  self.assertEqual(
185
- ValueSpec.from_annotation('A', False, False), vs.Any(annotation='A'))
354
+ ValueSpec.from_annotation('A', False, False),
355
+ vs.Any(annotation='A')
356
+ )
186
357
  self.assertEqual(
187
358
  ValueSpec.from_annotation('A', False, True), vs.Str('A'))
188
359
  self.assertEqual(
@@ -0,0 +1,135 @@
1
+ # Copyright 2025 The PyGlove Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+ import sys
17
+ import typing
18
+ from typing import List, Literal, Union
19
+ import unittest
20
+
21
+ from pyglove.core import symbolic as pg
22
+ from pyglove.core.typing import key_specs as ks
23
+ from pyglove.core.typing import value_specs as vs
24
+
25
+
26
class AnnotationFutureConversionTest(unittest.TestCase):
  """Tests schema inference when annotations are stored as strings.

  This module enables `from __future__ import annotations`, so every
  annotation below reaches PyGlove in its string form.
  """

  # Class with forward declaration must not be defined in functions.
  class A(pg.Object):
    a: typing.Optional[AnnotationFutureConversionTest.A]
    b: List[AnnotationFutureConversionTest.A]

  def assert_value_spec(self, cls, field_name, expected_value_spec):
    """Asserts that `cls` declares `field_name` with the expected spec."""
    self.assertEqual(cls.__schema__[field_name].value, expected_value_spec)

  def test_basics(self):

    class Foo(pg.Object):
      a: int
      b: float
      c: bool
      d: str
      e: typing.Any
      f: typing.Dict[str, typing.Any]
      g: typing.List[str]
      h: typing.Tuple[int, int]
      i: typing.Callable[[int, int], None]

    self.assert_value_spec(Foo, 'a', vs.Int())
    self.assert_value_spec(Foo, 'b', vs.Float())
    self.assert_value_spec(Foo, 'c', vs.Bool())
    self.assert_value_spec(Foo, 'd', vs.Str())
    self.assert_value_spec(Foo, 'e', vs.Any(annotation=typing.Any))
    self.assert_value_spec(
        Foo, 'f', vs.Dict([(ks.StrKey(), vs.Any(annotation=typing.Any))])
    )
    self.assert_value_spec(Foo, 'g', vs.List(vs.Str()))
    self.assert_value_spec(Foo, 'h', vs.Tuple([vs.Int(), vs.Int()]))
    self.assert_value_spec(
        Foo, 'i',
        vs.Callable([vs.Int(), vs.Int()], returns=vs.Object(type(None)))
    )

  # Report a skip instead of silently passing on older interpreters.
  @unittest.skipIf(
      sys.version_info < (3, 10), 'The `X | Y` syntax requires Python 3.10+.'
  )
  def test_list(self):

    class Bar(pg.Object):
      x: list[int | None]

    self.assert_value_spec(Bar, 'x', vs.List(vs.Int().noneable()))

  def test_var_length_tuple(self):

    class Foo(pg.Object):
      x: typing.Tuple[int, ...]

    self.assert_value_spec(Foo, 'x', vs.Tuple(vs.Int()))

    if sys.version_info >= (3, 10):

      class Bar(pg.Object):
        x: tuple[int, ...]

      self.assert_value_spec(Bar, 'x', vs.Tuple(vs.Int()))

  def test_optional(self):

    class Foo(pg.Object):
      x: typing.Optional[int]

    self.assert_value_spec(Foo, 'x', vs.Int().noneable())

    if sys.version_info >= (3, 10):

      class Bar(pg.Object):
        x: int | None

      self.assert_value_spec(Bar, 'x', vs.Int().noneable())

  def test_union(self):

    class Foo(pg.Object):
      x: Union[int, typing.Union[str, bool], None]

    self.assert_value_spec(
        Foo, 'x', vs.Union([vs.Int(), vs.Str(), vs.Bool()]).noneable()
    )

    if sys.version_info >= (3, 10):

      class Bar(pg.Object):
        x: int | str | bool

      self.assert_value_spec(
          Bar, 'x', vs.Union([vs.Int(), vs.Str(), vs.Bool()])
      )

  def test_literal(self):

    class Foo(pg.Object):
      x: Literal[1, True, 'abc']

    self.assert_value_spec(
        Foo, 'x', vs.Enum(vs.MISSING_VALUE, [1, True, 'abc'])
    )

  def test_self_referential(self):
    self.assert_value_spec(
        self.A, 'a', vs.Object(self.A).noneable()
    )
    self.assert_value_spec(
        self.A, 'b', vs.List(vs.Object(self.A))
    )


if __name__ == '__main__':
  unittest.main()
@@ -97,9 +97,9 @@ class KeySpec(utils.Formattable, utils.JSONConvertible):
97
97
  class ForwardRef(utils.Formattable):
98
98
  """Forward type reference."""
99
99
 
100
- def __init__(self, module: types.ModuleType, name: str):
100
+ def __init__(self, module: types.ModuleType, qualname: str):
101
101
  self._module = module
102
- self._name = name
102
+ self._qualname = qualname
103
103
 
104
104
  @property
105
105
  def module(self) -> types.ModuleType:
@@ -109,33 +109,49 @@ class ForwardRef(utils.Formattable):
109
109
  @property
110
110
  def name(self) -> str:
111
111
  """Returns the name of the type reference."""
112
- return self._name
112
+ return self._qualname.split('.')[-1]
113
113
 
114
114
  @property
115
115
  def qualname(self) -> str:
116
116
  """Returns the qualified name of the reference."""
117
- return f'{self.module.__name__}.{self.name}'
117
+ return self._qualname
118
+
119
+ @property
120
+ def type_id(self) -> str:
121
+ """Returns the type id of the reference."""
122
+ return f'{self.module.__name__}.{self.qualname}'
118
123
 
119
124
  def as_annotation(self) -> Union[Type[Any], str]:
120
125
  """Returns the forward reference as an annotation."""
121
- return self.cls if self.resolved else self.name
126
+ return self.cls if self.resolved else self.qualname
122
127
 
123
128
  @property
124
129
  def resolved(self) -> bool:
125
130
  """Returns True if the symbol for the name is resolved.."""
126
- return hasattr(self.module, self.name)
131
+ return self._resolve() is not None
132
+
133
def _resolve(self) -> Optional[Any]:
  """Looks up the qualified name, one attribute at a time, from the module.

  Returns:
    The referenced class, or None if any name along the path does not exist.

  Raises:
    TypeError: If the qualified name exists but does not refer to a class.
  """
  # Private sentinel compared by identity: an equality check would call
  # `__eq__` on whatever object the attribute lookup returns, and that
  # method may raise or return a non-boolean value.
  missing = object()
  target = self.module
  for name in self._qualname.split('.'):
    target = getattr(target, name, missing)
    if target is missing:
      return None
  if not inspect.isclass(target):
    raise TypeError(
        f'{self.qualname!r} from module {self.module.__name__!r} '
        'is not a class.'
    )
  return target
127
146
 
128
147
  @property
129
148
  def cls(self) -> Type[Any]:
130
149
  """Returns the resolved reference class.."""
131
- reference = getattr(self.module, self.name, None)
150
+ reference = self._resolve()
132
151
  if reference is None:
133
152
  raise TypeError(
134
- f'{self.name!r} does not exist in module {self.module.__name__!r}'
135
- )
136
- elif not inspect.isclass(reference):
137
- raise TypeError(
138
- f'{self.name!r} from module {self.module.__name__!r} is not a class.'
153
+ f'{self.qualname!r} does not exist in '
154
+ f'module {self.module.__name__!r}'
139
155
  )
140
156
  return reference
141
157
 
@@ -150,7 +166,7 @@ class ForwardRef(utils.Formattable):
150
166
  return utils.kvlist_str(
151
167
  [
152
168
  ('module', self.module.__name__, None),
153
- ('name', self.name, None),
169
+ ('name', self.qualname, None),
154
170
  ],
155
171
  label=self.__class__.__name__,
156
172
  compact=compact,
@@ -164,7 +180,7 @@ class ForwardRef(utils.Formattable):
164
180
  if self is other:
165
181
  return True
166
182
  elif isinstance(other, ForwardRef):
167
- return self.module is other.module and self.name == other.name
183
+ return self.module is other.module and self.qualname == other.qualname
168
184
  elif inspect.isclass(other):
169
185
  return self.resolved and self.cls is other # pytype: disable=bad-return-type
170
186
 
@@ -173,11 +189,11 @@ class ForwardRef(utils.Formattable):
173
189
  return not self.__eq__(other)
174
190
 
175
191
  def __hash__(self) -> int:
176
- return hash((self.module, self.name))
192
+ return hash((self.module, self.qualname))
177
193
 
178
194
  def __deepcopy__(self, memo) -> 'ForwardRef':
179
195
  """Override deep copy to avoid copying module."""
180
- return ForwardRef(self.module, self.name)
196
+ return ForwardRef(self.module, self.qualname)
181
197
 
182
198
 
183
199
  class ValueSpec(utils.Formattable, utils.JSONConvertible):
@@ -628,9 +644,11 @@ class Field(utils.Formattable, utils.JSONConvertible):
628
644
  annotation: Any,
629
645
  description: Optional[str] = None,
630
646
  metadata: Optional[Dict[str, Any]] = None,
631
- auto_typing=True) -> 'Field':
647
+ auto_typing=True,
648
+ parent_module: Optional[types.ModuleType] = None
649
+ ) -> 'Field':
632
650
  """Gets a Field from annotation."""
633
- del key, annotation, description, metadata, auto_typing
651
+ del key, annotation, description, metadata, auto_typing, parent_module
634
652
  assert False, 'Overridden in `annotation_conversion.py`.'
635
653
 
636
654
  @property
@@ -31,24 +31,31 @@ from pyglove.core.typing.class_schema import Schema
31
31
  class ForwardRefTest(unittest.TestCase):
32
32
  """Test for `ForwardRef` class."""
33
33
 
34
+ class A:
35
+ pass
36
+
34
37
  def setUp(self):
35
38
  super().setUp()
36
39
  self._module = sys.modules[__name__]
37
40
 
38
41
  def test_basics(self):
39
- r = class_schema.ForwardRef(self._module, 'FieldTest')
42
+ r = class_schema.ForwardRef(self._module, 'ForwardRefTest.A')
40
43
  self.assertIs(r.module, self._module)
41
- self.assertEqual(r.name, 'FieldTest')
42
- self.assertEqual(r.qualname, f'{self._module.__name__}.FieldTest')
44
+ self.assertEqual(r.name, 'A')
45
+ self.assertEqual(r.qualname, 'ForwardRefTest.A')
46
+ self.assertEqual(r.type_id, f'{self._module.__name__}.ForwardRefTest.A')
43
47
 
44
48
  def test_resolved(self):
45
- self.assertTrue(class_schema.ForwardRef(self._module, 'FieldTest').resolved)
49
+ self.assertTrue(
50
+ class_schema.ForwardRef(self._module, 'ForwardRefTest.A').resolved
51
+ )
46
52
  self.assertFalse(class_schema.ForwardRef(self._module, 'Foo').resolved)
47
53
 
48
54
  def test_as_annotation(self):
49
55
  self.assertEqual(
50
- class_schema.ForwardRef(self._module, 'FieldTest').as_annotation(),
51
- FieldTest,
56
+ class_schema.ForwardRef(
57
+ self._module, 'ForwardRefTest.A').as_annotation(),
58
+ ForwardRefTest.A,
52
59
  )
53
60
  self.assertEqual(
54
61
  class_schema.ForwardRef(self._module, 'Foo').as_annotation(), 'Foo'
@@ -56,7 +63,8 @@ class ForwardRefTest(unittest.TestCase):
56
63
 
57
64
  def test_cls(self):
58
65
  self.assertIs(
59
- class_schema.ForwardRef(self._module, 'FieldTest').cls, FieldTest
66
+ class_schema.ForwardRef(self._module, 'ForwardRefTest.A').cls,
67
+ ForwardRefTest.A
60
68
  )
61
69
 
62
70
  with self.assertRaisesRegex(TypeError, '.* does not exist in module'):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: pyglove
3
- Version: 0.4.5.dev202502040809
3
+ Version: 0.4.5.dev202502060809
4
4
  Summary: PyGlove: A library for manipulating Python objects.
5
5
  Home-page: https://github.com/google/pyglove
6
6
  Author: PyGlove Authors
@@ -88,7 +88,7 @@ pyglove/core/symbolic/inferred.py,sha256=E4zgphg6NNZad9Fl3jdHQOMZeqEp9XHq5OUYqXE
88
88
  pyglove/core/symbolic/inferred_test.py,sha256=G6uPykONcChvs6vZujXHSWaYfjewLTVBscMqzzKNty0,1270
89
89
  pyglove/core/symbolic/list.py,sha256=z8goU0ntd-Q5ADaCGiKsJwPhdRdQb0Kd_p-ZekXaLy4,30303
90
90
  pyglove/core/symbolic/list_test.py,sha256=IAyFQ48nyczKUcPNZFKHBkX5oh7Xuxbnv3rRkONhbHw,61146
91
- pyglove/core/symbolic/object.py,sha256=T1HtGYNA1qLp0iTZVmgR90nfMLuOQVNopkBPYw_j8B8,42021
91
+ pyglove/core/symbolic/object.py,sha256=6bCg20r76sfHfVA-BBQYFxTYBzOM_c8BihHEKMKt3i8,42093
92
92
  pyglove/core/symbolic/object_test.py,sha256=vIL_ymGpPPXamIljEBdufpw-p82kTeWNy2IbgzGQjig,93811
93
93
  pyglove/core/symbolic/origin.py,sha256=OSWMKjvPcISOXrzuX3lCQC8m_qaGl-9INsIB81erUnU,6124
94
94
  pyglove/core/symbolic/origin_test.py,sha256=dU_ZGrGDetM_lYVMn3wQO0d367_t_t8eESe3NrKPBNE,3159
@@ -106,17 +106,18 @@ pyglove/core/tuning/protocols.py,sha256=10Iukt1rqh05caURTZffSsb3CcHo7epBQnNtnyMy
106
106
  pyglove/core/tuning/protocols_test.py,sha256=Cbzvz3EacaW2sbm1rTSQXEt_VucMoQbeQ6AeN-GV5Vc,1883
107
107
  pyglove/core/tuning/sample.py,sha256=UzsCY8kiqnzH_mR94zLXhOloyvvEwfmBWluBjmefUFA,12975
108
108
  pyglove/core/tuning/sample_test.py,sha256=JqwDPy3EPC_VjU9dipk90jj1kovZB3Zb9hAjAlZ-U1U,17551
109
- pyglove/core/typing/__init__.py,sha256=Z_jnwUvDQ3wA16a5TiuWbuof9hW0Xm6YoTNwgG4QGqI,14395
109
+ pyglove/core/typing/__init__.py,sha256=MwraQ-6K8NYI_xk4V3oaVSkrPqiU9InDtAgXPSRBvik,14494
110
110
  pyglove/core/typing/annotated.py,sha256=llaajIDj9GK-4kUGJoO4JsHU6ESPOra2SZ-jG6xmsOQ,3203
111
111
  pyglove/core/typing/annotated_test.py,sha256=p1qid3R-jeiOTTxOVq6hXW8XFvn-h1cUzJWISPst2l8,2484
112
- pyglove/core/typing/annotation_conversion.py,sha256=OhBQRCaSh7zJKfb_IQKF7sD32E_YNTekz00qZaSIVnw,9007
113
- pyglove/core/typing/annotation_conversion_test.py,sha256=tZheqbLWbr76WBIDOplLtY3yznMc4m9u7KCznWEJdEs,11660
112
+ pyglove/core/typing/annotation_conversion.py,sha256=7aMxv5AhC5oYteBtVTRp1no16dFgfGFCLbmxVasrdzQ,15557
113
+ pyglove/core/typing/annotation_conversion_test.py,sha256=tOUM7k_VNiGfBlSy_r6vMKk0lD8gRNT6dqnzBzvhul4,17422
114
+ pyglove/core/typing/annotation_future_test.py,sha256=ZVLU3kheO2V-nrp-m5jrGMLgGMs4S6RWnJbyD9KbHl4,3746
114
115
  pyglove/core/typing/callable_ext.py,sha256=PiBQWPeUAH7Lgmf2xKCZqgK7N0OSrTdbnEkV8Ph31OA,9127
115
116
  pyglove/core/typing/callable_ext_test.py,sha256=TnWKU4_ZjvpbHZFtFHgFvCMDiCos8VmLlODcM_7Xg8M,10156
116
117
  pyglove/core/typing/callable_signature.py,sha256=DRpt7aShfkn8pb3SCiZzS_27eHbkQ_d2UB8BUhJjs0Q,27176
117
118
  pyglove/core/typing/callable_signature_test.py,sha256=iQmHsKPhJPQlMikDhEyxKyq7yWyXI9juKCLYgKhrH3U,25145
118
- pyglove/core/typing/class_schema.py,sha256=xfPrLYpqntejMLvSL-7rNiXHTQOgAr2kmxWeR890qLs,53940
119
- pyglove/core/typing/class_schema_test.py,sha256=UWANPqhu9v_FHNo3cVe05P-bO-HliBmrSBywKrlWep0,29204
119
+ pyglove/core/typing/class_schema.py,sha256=dKmF1dwv-7RmCjnAb412AciYZL3LJCToU5WTJWYLj3Y,54482
120
+ pyglove/core/typing/class_schema_test.py,sha256=RurRdCyPuypKJ7izgcq9zW3JNHgODiJdQvDn0BDZDjU,29353
120
121
  pyglove/core/typing/custom_typing.py,sha256=qdnIKHWNt5kZAAFdpQXra8bBu6RljMbbJ_YDG2mhAUA,2205
121
122
  pyglove/core/typing/inspect.py,sha256=VLSz1KAunNm2hx0eEMjiwxKLl9FHlKr9nHelLT25iEA,7726
122
123
  pyglove/core/typing/inspect_test.py,sha256=xclevobF0X8c_B5b1q1dkBJZN1TsVA1RUhk5l25DUCM,10248
@@ -211,8 +212,8 @@ pyglove/ext/scalars/randoms.py,sha256=LkMIIx7lOq_lvJvVS3BrgWGuWl7Pi91-lA-O8x_gZs
211
212
  pyglove/ext/scalars/randoms_test.py,sha256=nEhiqarg8l_5EOucp59CYrpO2uKxS1pe0hmBdZUzRNM,2000
212
213
  pyglove/ext/scalars/step_wise.py,sha256=IDw3tuTpv0KVh7AN44W43zqm1-E0HWPUlytWOQC9w3Y,3789
213
214
  pyglove/ext/scalars/step_wise_test.py,sha256=TL1vJ19xVx2t5HKuyIzGoogF7N3Rm8YhLE6JF7i0iy8,2540
214
- pyglove-0.4.5.dev202502040809.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
215
- pyglove-0.4.5.dev202502040809.dist-info/METADATA,sha256=g47xUh8StCt5ojg7Arw96L0nds0Kl8hrGazZx7IsU10,7067
216
- pyglove-0.4.5.dev202502040809.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
217
- pyglove-0.4.5.dev202502040809.dist-info/top_level.txt,sha256=wITzJSKcj8GZUkbq-MvUQnFadkiuAv_qv5qQMw0fIow,8
218
- pyglove-0.4.5.dev202502040809.dist-info/RECORD,,
215
+ pyglove-0.4.5.dev202502060809.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
216
+ pyglove-0.4.5.dev202502060809.dist-info/METADATA,sha256=r1aB5PMkBiDYY5J67WcefYKfO6WHmnF7nGNB5y6gLVY,7067
217
+ pyglove-0.4.5.dev202502060809.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
218
+ pyglove-0.4.5.dev202502060809.dist-info/top_level.txt,sha256=wITzJSKcj8GZUkbq-MvUQnFadkiuAv_qv5qQMw0fIow,8
219
+ pyglove-0.4.5.dev202502060809.dist-info/RECORD,,