pyglove 0.4.5.dev202410020809__py3-none-any.whl → 0.4.5.dev202410100808__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyglove/core/__init__.py +9 -0
- pyglove/core/object_utils/__init__.py +1 -0
- pyglove/core/object_utils/value_location.py +10 -0
- pyglove/core/object_utils/value_location_test.py +22 -0
- pyglove/core/symbolic/base.py +40 -2
- pyglove/core/symbolic/base_test.py +67 -0
- pyglove/core/symbolic/diff.py +190 -1
- pyglove/core/symbolic/diff_test.py +290 -0
- pyglove/core/symbolic/object_test.py +3 -8
- pyglove/core/symbolic/ref.py +29 -0
- pyglove/core/symbolic/ref_test.py +143 -0
- pyglove/core/typing/__init__.py +4 -0
- pyglove/core/typing/callable_ext.py +240 -1
- pyglove/core/typing/callable_ext_test.py +255 -0
- pyglove/core/typing/inspect.py +63 -0
- pyglove/core/typing/inspect_test.py +39 -0
- pyglove/core/views/__init__.py +30 -0
- pyglove/core/views/base.py +906 -0
- pyglove/core/views/base_test.py +615 -0
- pyglove/core/views/html/__init__.py +27 -0
- pyglove/core/views/html/base.py +529 -0
- pyglove/core/views/html/base_test.py +804 -0
- pyglove/core/views/html/tree_view.py +1052 -0
- pyglove/core/views/html/tree_view_test.py +748 -0
- {pyglove-0.4.5.dev202410020809.dist-info → pyglove-0.4.5.dev202410100808.dist-info}/METADATA +1 -1
- {pyglove-0.4.5.dev202410020809.dist-info → pyglove-0.4.5.dev202410100808.dist-info}/RECORD +29 -21
- {pyglove-0.4.5.dev202410020809.dist-info → pyglove-0.4.5.dev202410100808.dist-info}/LICENSE +0 -0
- {pyglove-0.4.5.dev202410020809.dist-info → pyglove-0.4.5.dev202410100808.dist-info}/WHEEL +0 -0
- {pyglove-0.4.5.dev202410020809.dist-info → pyglove-0.4.5.dev202410100808.dist-info}/top_level.txt +0 -0
@@ -13,11 +13,250 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Callable extensions."""
|
15
15
|
|
16
|
-
|
16
|
+
import contextlib
|
17
|
+
import functools
|
18
|
+
import inspect
|
19
|
+
import types
|
20
|
+
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
|
17
21
|
|
22
|
+
from pyglove.core import object_utils
|
18
23
|
from pyglove.core.typing import callable_signature
|
19
24
|
|
20
25
|
|
26
|
+
# Thread-local key under which `preset_args` pushes (and later pops) the
# `_ArgPresets` collection that preset-enabled functions peek at call time.
_TLS_KEY_PRESET_KWARGS = '__preset_kwargs__'
|
27
|
+
|
28
|
+
|
29
|
+
class PresetArgValue(object_utils.Formattable):
  """Value placeholder for arguments whose value will be provided by presets.

  Example::

    @pg.typing.enable_preset_args()
    def foo(x, y=pg.typing.PresetArgValue(default=1)):
      return x + y

    with pg.typing.preset_args(dict(y=2)):
      print(foo(x=1))  # 3: y=2
    print(foo(x=1))  # 2: y=1
  """

  def __init__(self, default: Any = object_utils.MISSING_VALUE):
    # `MISSING_VALUE` means there is no fallback: the argument must then be
    # provided by an active preset or passed explicitly by the caller.
    self.default = default

  @property
  def has_default(self) -> bool:
    """Returns True if a fallback default value was specified."""
    return self.default != object_utils.MISSING_VALUE

  def __eq__(self, other: Any) -> bool:
    return isinstance(other, PresetArgValue) and (
        self.default == other.default
    )

  def __ne__(self, other: Any) -> bool:
    return not self.__eq__(other)

  def format(self, *args, **kwargs):
    """Formats this object via `object_utils.kvlist_str`."""
    return object_utils.kvlist_str(
        [
            ('default', self.default, object_utils.MISSING_VALUE),
        ],
        label='PresetArgValue',
        *args,
        **kwargs
    )

  @classmethod
  def inspect(
      cls, func: types.FunctionType) -> Dict[str, 'PresetArgValue']:
    """Gets the PresetArgValue specified in a function's signature.

    Args:
      func: A Python function.

    Returns:
      A dict from parameter name to its `PresetArgValue` default, containing
      only parameters whose default is a `PresetArgValue`.
    """
    assert inspect.isfunction(func), func
    return {
        p.name: p.default
        for p in inspect.signature(func).parameters.values()
        if isinstance(p.default, cls)
    }

  @classmethod
  def resolve_args(
      cls,
      call_args: Tuple[Any, ...],
      call_kwargs: Dict[str, Any],
      positional_arg_names: Sequence[str],
      arg_defaults: Dict[str, Any],
      preset_kwargs: Dict[str, Any],
      include_all_preset_kwargs: bool = False,
  ) -> Tuple[Sequence[Any], Dict[str, Any]]:
    """Resolves calling arguments passed to a method with presets.

    Args:
      call_args: Positional arguments supplied by the caller.
      call_kwargs: Keyword arguments supplied by the caller.
      positional_arg_names: Names of the parameters that can be passed
        positionally, in declaration order.
      arg_defaults: Parameter name to default value, for parameters that have
        a default. `PresetArgValue` defaults are resolved against
        `preset_kwargs`; other defaults are used as-is.
      preset_kwargs: Keyword values from the active preset.
      include_all_preset_kwargs: If True, preset entries that do not match a
        parameter in `arg_defaults` are also forwarded as keyword arguments.

    Returns:
      A tuple of (positional args, keyword args) to call the function with.
      Values supplied by the caller always take precedence over preset values.

    Raises:
      ValueError: If a `PresetArgValue` parameter without a default is absent
        from the preset and is not supplied by the caller either.
    """
    # Step 1: compute values for arguments with defaults. Marked arguments
    # take their value from the preset, falling back to the marker default.
    resolved_kwargs = {}
    unresolved_arg_names = []
    for arg_name, arg_default in arg_defaults.items():
      if not isinstance(arg_default, PresetArgValue):
        resolved_kwargs[arg_name] = arg_default
      elif arg_name in preset_kwargs:
        resolved_kwargs[arg_name] = preset_kwargs[arg_name]
      elif arg_default.has_default:
        resolved_kwargs[arg_name] = arg_default.default
      else:
        # The caller may still supply it explicitly; checked in Step 5.
        unresolved_arg_names.append(arg_name)

    # Step 2: add the remaining preset kwargs if requested.
    if include_all_preset_kwargs:
      for k, v in preset_kwargs.items():
        if k not in resolved_kwargs:
          resolved_kwargs[k] = v

    # Step 3: the caller's keyword arguments take precedence.
    resolved_kwargs.update(call_kwargs)

    # Step 4: drop resolved values for arguments passed positionally.
    num_positional = min(len(call_args), len(positional_arg_names))
    positionally_passed = list(positional_arg_names[:num_positional])
    for arg_name in positionally_passed:
      resolved_kwargs.pop(arg_name, None)

    # Step 5: a marked argument without default must come from somewhere.
    for arg_name in unresolved_arg_names:
      if arg_name not in call_kwargs and arg_name not in positionally_passed:
        raise ValueError(
            f'Argument {arg_name!r} is not present as a keyword argument '
            'from the caller.'
        )

    # Step 6: convert kwargs back to positional arguments where applicable,
    # stopping at the first positional parameter without a value.
    resolved_args = call_args
    if len(positional_arg_names) > len(call_args):
      resolved_args = list(resolved_args)
      for arg_name in positional_arg_names[len(call_args):]:
        if arg_name not in resolved_kwargs:
          break
        resolved_args.append(resolved_kwargs.pop(arg_name))
    return resolved_args, resolved_kwargs
|
131
|
+
|
132
|
+
|
133
|
+
class _ArgPresets:
|
134
|
+
"""Preset argument collection."""
|
135
|
+
|
136
|
+
def __init__(self, presets: Optional[Dict[str, Dict[str, Any]]] = None):
|
137
|
+
self._presets: Dict[str, Dict[str, Any]] = presets or {}
|
138
|
+
|
139
|
+
def derive(
|
140
|
+
self,
|
141
|
+
kwargs: Dict[str, Any],
|
142
|
+
preset_name: str = 'global',
|
143
|
+
inherit_preset: Union[str, bool] = False
|
144
|
+
) -> '_ArgPresets':
|
145
|
+
"""Derives new presets from current presets."""
|
146
|
+
presets = self._presets.copy() # Just do a shallow copy.
|
147
|
+
if isinstance(inherit_preset, bool) and inherit_preset:
|
148
|
+
inherit_preset = preset_name
|
149
|
+
|
150
|
+
if inherit_preset and inherit_preset in presets:
|
151
|
+
current_preset = presets[inherit_preset].copy()
|
152
|
+
current_preset.update(kwargs)
|
153
|
+
else:
|
154
|
+
current_preset = kwargs
|
155
|
+
presets[preset_name] = current_preset
|
156
|
+
return _ArgPresets(presets)
|
157
|
+
|
158
|
+
def get_preset(self, preset_name: str) -> Dict[str, Any]:
|
159
|
+
return self._presets.get(preset_name, {})
|
160
|
+
|
161
|
+
|
162
|
+
@contextlib.contextmanager
def preset_args(
    kwargs: Dict[str, Any],
    *,
    preset_name: str = 'global',
    inherit_preset: Union[str, bool] = False
) -> Iterator['_ArgPresets']:
  """Context manager to enable calling with user kwargs.

  Functions decorated with `enable_preset_args` that are called within this
  context read their `PresetArgValue`-marked arguments from the preset named
  by their own `preset_name`.

  Args:
    kwargs: The preset kwargs to be used by preset-enabled functions within
      the context.
    preset_name: The name of the preset to specify kwargs for.
      `enable_preset_args` allows users to pass a preset name, which will be
      used to identify the preset to be used.
    inherit_preset: The name of the preset defined by the parent context to
      inherit kwargs from, or a boolean indicating whether to inherit the
      parent preset of the same name. Inherited values are overridden by
      `kwargs`.

  Yields:
    The preset collection (an `_ArgPresets` object) active within the
    context. Call `get_preset(name)` on it to read the kwargs of a preset.
  """
  parent_presets = object_utils.thread_local_peek(
      _TLS_KEY_PRESET_KWARGS, _ArgPresets()
  )
  current_presets = parent_presets.derive(kwargs, preset_name, inherit_preset)
  object_utils.thread_local_push(_TLS_KEY_PRESET_KWARGS, current_presets)
  try:
    yield current_presets
  finally:
    # Restore the parent presets even if the body raises.
    object_utils.thread_local_pop(_TLS_KEY_PRESET_KWARGS, None)
|
194
|
+
|
195
|
+
|
196
|
+
def enable_preset_args(
    include_all_preset_kwargs: bool = False,
    preset_name: str = 'global'
) -> Callable[[types.FunctionType], types.FunctionType]:
  """Decorator for functions that maybe use preset argument values.

  Usage::

    @pg.typing.enable_preset_args()
    def foo(x, y=pg.typing.PresetArgValue(default=1)):
      return x + y

    with pg.typing.preset_args(dict(y=2)):
      print(foo(x=1))  # 3: y=2
    print(foo(x=1))  # 2: y=1

  Args:
    include_all_preset_kwargs: Whether to include all preset kwargs (even
      those not marked as `PresetArgValue`) when calling the function. Only
      effective when the function accepts `**kwargs`.
    preset_name: The name of the preset to read kwargs from.

  Returns:
    A decorator. Functions without any `PresetArgValue` default are returned
    unchanged; others are wrapped so that preset values are filled in at call
    time.
  """
  def decorator(func):
    sig = inspect.signature(func)
    # Parameters that can receive positional arguments, in declaration order.
    # Positional-only ones must be included: they cannot be passed by keyword,
    # so their resolved values have to be converted back to positional args.
    positional_arg_names = [
        p.name for p in sig.parameters.values()
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY,
                      inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]

    arg_defaults = {}
    has_preset_value = False
    has_varkw = False
    for p in sig.parameters.values():
      if p.kind == inspect.Parameter.VAR_KEYWORD:
        has_varkw = True
        continue
      # Identity check: `==` is ill-defined for some default values (e.g.
      # array-like objects whose comparison is element-wise).
      if (p.kind == inspect.Parameter.VAR_POSITIONAL
          or p.default is inspect.Parameter.empty):
        continue
      if isinstance(p.default, PresetArgValue):
        has_preset_value = True
      arg_defaults[p.name] = p.default

    if not has_preset_value:
      return func

    @functools.wraps(func)
    def _func(*args, **kwargs):
      # Look up the presets active in the current thread, if any.
      presets = object_utils.thread_local_peek(_TLS_KEY_PRESET_KWARGS, None)
      preset_kwargs = presets.get_preset(preset_name) if presets else {}
      args, kwargs = PresetArgValue.resolve_args(
          args, kwargs, positional_arg_names, arg_defaults, preset_kwargs,
          include_all_preset_kwargs=include_all_preset_kwargs and has_varkw
      )
      return func(*args, **kwargs)
    return _func
  return decorator
|
258
|
+
|
259
|
+
|
21
260
|
class CallableWithOptionalKeywordArgs:
|
22
261
|
"""Helper class for invoking callable objects with optional keyword args.
|
23
262
|
|
@@ -18,6 +18,261 @@ from pyglove.core.typing import annotation_conversion # pylint: disable=unused
|
|
18
18
|
from pyglove.core.typing import callable_ext
|
19
19
|
|
20
20
|
|
21
|
+
class PresetArgValueTest(unittest.TestCase):
  """Tests for typing.PresetArgValue."""

  def test_basics(self):
    without_default = callable_ext.PresetArgValue()
    self.assertFalse(without_default.has_default)

    with_default = callable_ext.PresetArgValue(default=1)
    self.assertTrue(with_default.has_default)
    self.assertEqual(with_default.default, 1)
    self.assertEqual(repr(with_default), 'PresetArgValue(default=1)')
    self.assertEqual(str(with_default), 'PresetArgValue(\n default=1\n)')

    # Equality is determined by the default value.
    self.assertEqual(
        callable_ext.PresetArgValue(), callable_ext.PresetArgValue()
    )
    self.assertEqual(
        callable_ext.PresetArgValue(1), callable_ext.PresetArgValue(1)
    )
    self.assertNotEqual(
        callable_ext.PresetArgValue(), callable_ext.PresetArgValue(default=1)
    )

  def test_inspect(self):

    def with_markers(
        x=callable_ext.PresetArgValue(),
        y=1,
        *,
        z=callable_ext.PresetArgValue(default=2)
    ):
      return x + y + z

    def without_markers(x, y=1):
      return x + y

    self.assertEqual(
        callable_ext.PresetArgValue.inspect(with_markers),
        {
            'x': callable_ext.PresetArgValue(),
            'z': callable_ext.PresetArgValue(default=2),
        },
    )
    self.assertEqual(callable_ext.PresetArgValue.inspect(without_markers), {})

  def test_resolve_args(self):
    marker = callable_ext.PresetArgValue

    def standard_defaults():
      return {'x': marker(), 'y': 0, 'z': marker(default=2)}

    def full_preset():
      return dict(x=1, y=2, z=3, w=4)

    # Each case: (keyword arguments for `resolve_args`, expected result).
    cases = [
        # Both positional and keyword arguments come from the preset.
        (dict(call_args=[], call_kwargs={},
              arg_defaults={'x': marker(), 'z': marker(default=2)},
              preset_kwargs=full_preset()),
         ([1], dict(z=3))),
        # Nothing in the preset: every argument uses its default.
        (dict(call_args=[], call_kwargs={},
              arg_defaults={'x': marker(default=1), 'y': 0,
                            'z': marker(default=2)},
              preset_kwargs={}),
         ([1, 0], dict(z=2))),
        # Marked `x` comes from the preset, unmarked `y` keeps its own
        # default, and `z` falls back to the marker default.
        (dict(call_args=[], call_kwargs={},
              arg_defaults=standard_defaults(),
              preset_kwargs=dict(x=1, y=2)),
         ([1, 0], dict(z=2))),
        # A positional argument from the caller wins over the preset value.
        (dict(call_args=[2], call_kwargs={},
              arg_defaults=standard_defaults(),
              preset_kwargs=full_preset()),
         ([2, 0], dict(z=3))),
        # A positional parameter supplied by keyword.
        (dict(call_args=[], call_kwargs=dict(x=2),
              arg_defaults=standard_defaults(),
              preset_kwargs=full_preset()),
         ([2, 0], dict(z=3))),
        # More call args than positional parameters (as with varargs).
        (dict(call_args=[1, 2, 3], call_kwargs=dict(x=2),
              arg_defaults=standard_defaults(),
              preset_kwargs=full_preset()),
         ([1, 2, 3], dict(z=3))),
        # Unmarked preset entries are forwarded on request.
        (dict(call_args=[], call_kwargs={},
              arg_defaults=standard_defaults(),
              preset_kwargs=full_preset(),
              include_all_preset_kwargs=True),
         ([1, 0], dict(z=3, w=4))),
    ]
    for index, (resolve_kwargs, expected) in enumerate(cases):
      with self.subTest(case=index):
        self.assertEqual(
            marker.resolve_args(
                positional_arg_names=['x', 'y'], **resolve_kwargs
            ),
            expected,
        )

    # A marked argument without default that nobody supplies.
    with self.assertRaisesRegex(ValueError, 'Argument .* is not present.'):
      marker.resolve_args(
          call_args=[],
          call_kwargs={},
          positional_arg_names=['x'],
          arg_defaults={'x': marker()},
          preset_kwargs={},
      )

  def test_preset_args(self):

    @callable_ext.enable_preset_args()
    def add(
        x=callable_ext.PresetArgValue(),
        y=1,
        *args,
        z=callable_ext.PresetArgValue(default=2)
    ):
      del args
      return x + y + z

    # No preset is active and `x` has no marker default.
    with self.assertRaisesRegex(ValueError, 'Argument \'x\' is not present.'):
      add()

    with callable_ext.preset_args(dict(x=1)):
      self.assertEqual(add(), 1 + 1 + 2)

    # A preset value for `y` must not override its non-preset default.
    with callable_ext.preset_args(dict(x=1, y=2)):
      self.assertEqual(add(), 1 + 1 + 2)

    with callable_ext.preset_args(dict(x=1, y=2, z=3)):
      self.assertEqual(add(3), 3 + 1 + 3)
      self.assertEqual(add(3, 3, z=4), 3 + 3 + 4)

  def test_enable_preset_args(self):

    # Functions without any PresetArgValue are returned as-is.
    def plain(x, y):
      return x + y
    self.assertIs(plain, callable_ext.enable_preset_args()(plain))

    # With `include_all_preset_kwargs` off, unmarked preset entries are
    # not forwarded even when the function takes **kwargs.
    @callable_ext.enable_preset_args()
    def without_forwarding(
        x, y=callable_ext.PresetArgValue(default=1), **kwargs):
      return x + y + sum(kwargs.values())

    with callable_ext.preset_args(dict(z=3, p=4)):
      self.assertEqual(without_forwarding(1), 1 + 1)

    # With it on and **kwargs present, they are forwarded.
    @callable_ext.enable_preset_args(include_all_preset_kwargs=True)
    def with_forwarding(
        x, y=callable_ext.PresetArgValue(default=1), **kwargs):
      return x + y + sum(kwargs.values())

    with callable_ext.preset_args(dict(z=3, p=4)):
      self.assertEqual(with_forwarding(1), 1 + 1 + 3 + 4)

    # Without **kwargs the option has no effect.
    @callable_ext.enable_preset_args(include_all_preset_kwargs=True)
    def no_varkw(x, y=callable_ext.PresetArgValue(default=1)):
      return x + y

    with callable_ext.preset_args(dict(y=2, z=3)):
      self.assertEqual(no_varkw(1), 1 + 2)

  def test_preset_args_nesting(self):

    @callable_ext.enable_preset_args()
    def add(
        x=callable_ext.PresetArgValue(),
        y=1,
        *,
        z=callable_ext.PresetArgValue(default=2)
    ):
      return x + y + z

    def call_in_child_context(inherit_preset: bool = False, **kwargs):
      child_kwargs = {k: v + 1 for k, v in kwargs.items()}
      with callable_ext.preset_args(
          child_kwargs, inherit_preset=inherit_preset
      ):
        return add()

    with callable_ext.preset_args(dict(x=1)):
      self.assertEqual(add(), 1 + 1 + 2)

      # By default the child preset replaces the parent preset.
      self.assertEqual(call_in_child_context(x=1), 2 + 1 + 2)
      self.assertEqual(call_in_child_context(x=1, z=2), 2 + 1 + 3)
      self.assertEqual(call_in_child_context(x=1, z=3), 2 + 1 + 4)

      with self.assertRaisesRegex(
          ValueError, 'Argument \'x\' is not present.'):
        call_in_child_context()
      # Inheriting picks `x` up from the parent preset.
      self.assertEqual(call_in_child_context(inherit_preset=True), 1 + 1 + 2)
|
274
|
+
|
275
|
+
|
21
276
|
class CallWithOptionalKeywordArgsTest(unittest.TestCase):
|
22
277
|
"""Tests for typing.CallWithOptionalKeywordArgs."""
|
23
278
|
|
pyglove/core/typing/inspect.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14
14
|
"""Utility module for inspecting generics types."""
|
15
15
|
|
16
16
|
import inspect
|
17
|
+
import sys
|
17
18
|
import typing
|
18
19
|
from typing import Any, Callable, Optional, Tuple, Type, Union
|
19
20
|
|
@@ -116,6 +117,68 @@ def get_type_args(
|
|
116
117
|
return ()
|
117
118
|
|
118
119
|
|
120
|
+
def get_outer_class(
    cls: Type[Any],
    base_cls: Union[Type[Any], Tuple[Type[Any], ...], None] = None,
    immediate: bool = False,
) -> Optional[Type[Any]]:
  """Returns the outer class.

  Example::

    class A:
      pass

    class A1:
      class B:
        class C:
          ...

    get_outer_class(A1.B) is A1
    get_outer_class(A1.B.C) is A1.B
    get_outer_class(A1.B.C, base_cls=A) is None
    get_outer_class(A1.B.C, base_cls=A1) is A1
    get_outer_class(A1.B.C, base_cls=A1, immediate=True) is None

  Args:
    cls: The class to get the outer class for.
    base_cls: The base class (or tuple of base classes) of the outer class.
      If provided, an outer class that is not a subclass of `base_cls` will be
      returned as None.
    immediate: Whether to return the immediate outer class or a class in the
      nesting hierarchy that is a subclass of `base_cls`. Applicable when
      `base_cls` is not None.

  Returns:
    The outer class of `cls`. None if it cannot be found (e.g. the defining
    module is not loaded, or a name in `__qualname__` cannot be resolved to a
    class) or the outer class is not a subclass of `base_cls`.

  Raises:
    ValueError: If `cls` is defined inside a function.
  """
  if '<locals>' in cls.__qualname__:
    raise ValueError(
        'Cannot find the outer class for locally defined class '
        f'{cls.__qualname__!r}'
    )

  names = cls.__qualname__.split('.')
  if len(names) < 2:
    return None

  # Resolve each enclosing name starting from the defining module. Any break
  # in the chain means the outer class cannot be located.
  parent = sys.modules.get(cls.__module__)
  if parent is None:
    return None
  outer_classes = []
  for name in names[:-1]:
    symbol = getattr(parent, name, None)
    if not inspect.isclass(symbol):
      return None
    outer_classes.append(symbol)
    parent = symbol

  # Search from the innermost outer class outward.
  candidates = outer_classes[-1:] if immediate else reversed(outer_classes)
  for symbol in candidates:
    if not base_cls or issubclass(symbol, base_cls):
      return symbol
  return None
|
180
|
+
|
181
|
+
|
119
182
|
def callable_eq(
|
120
183
|
x: Optional[Callable[..., Any]], y: Optional[Callable[..., Any]]
|
121
184
|
) -> bool:
|
@@ -16,8 +16,10 @@
|
|
16
16
|
from typing import Any, Generic, TypeVar
|
17
17
|
import unittest
|
18
18
|
|
19
|
+
from pyglove.core.typing import callable_signature
|
19
20
|
from pyglove.core.typing import inspect
|
20
21
|
|
22
|
+
|
21
23
|
XType = TypeVar('XType')
|
22
24
|
YType = TypeVar('YType')
|
23
25
|
|
@@ -50,6 +52,16 @@ class D(C):
|
|
50
52
|
pass
|
51
53
|
|
52
54
|
|
55
|
+
class AA:
  """Base class used as `base_cls` in outer-class lookups."""


class AA1(AA):
  """Outer class (a subclass of `AA`) hosting a two-level nested hierarchy."""

  class BB1:
    """First-level nested class."""

    class CC1:
      """Second-level nested class."""
|
63
|
+
|
64
|
+
|
53
65
|
class InspectTest(unittest.TestCase):
|
54
66
|
|
55
67
|
def test_issubclass(self):
|
@@ -141,6 +153,33 @@ class InspectTest(unittest.TestCase):
|
|
141
153
|
self.assertEqual(inspect.get_type_args(C, A), (str, int))
|
142
154
|
self.assertEqual(inspect.get_type_args(C, B), (Str,))
|
143
155
|
|
156
|
+
def test_outer_class(self):
  class LocallyDefined:
    pass

  # Classes defined inside a function cannot be resolved from their module.
  with self.assertRaisesRegex(ValueError, '.* locally defined class'):
    inspect.get_outer_class(LocallyDefined)

  # A top-level class has no outer class.
  self.assertIsNone(inspect.get_outer_class(AA))

  # One level of nesting, optionally filtered by `base_cls`.
  self.assertIs(inspect.get_outer_class(AA1.BB1), AA1)
  self.assertIs(inspect.get_outer_class(AA1.BB1, AA), AA1)
  self.assertIs(inspect.get_outer_class(AA1.BB1, A), None)

  # Two levels of nesting: the immediate parent by default; with `base_cls`
  # the search walks outward unless `immediate` is set.
  leaf = AA1.BB1.CC1
  self.assertIs(inspect.get_outer_class(leaf), AA1.BB1)
  self.assertIsNone(
      inspect.get_outer_class(leaf, base_cls=AA, immediate=True)
  )
  self.assertIs(inspect.get_outer_class(leaf, AA), AA1)

  # Nested class defined in another module.
  self.assertIs(
      inspect.get_outer_class(callable_signature.Argument.Kind),
      callable_signature.Argument
  )

  # An enclosing name in `__qualname__` that does not exist in the module.
  class Dangling:
    pass

  Dangling.__qualname__ = 'NonExist.Bar'
  self.assertIsNone(inspect.get_outer_class(Dangling))
|
182
|
+
|
144
183
|
def test_callable_eq(self):
|
145
184
|
def foo(unused_x):
|
146
185
|
pass
|
@@ -0,0 +1,30 @@
|
|
1
|
+
# Copyright 2024 The PyGlove Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""PyGlove views."""
|
15
|
+
|
16
|
+
from pyglove.core.views import base
|
17
|
+
from pyglove.core.views import html
|
18
|
+
|
19
|
+
# Core view abstractions.
View = base.View
view = base.view

# Type used in pytype annotations.
NodeFilter = base.NodeFilter

# HTML view classes and conversion helpers.
Html, HtmlView, HtmlTreeView = html.Html, html.HtmlView, html.HtmlTreeView
to_html, to_html_str = html.to_html, html.to_html_str
|