pyglove 0.4.5.dev202410020809__py3-none-any.whl → 0.4.5.dev202410100808__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. pyglove/core/__init__.py +9 -0
  2. pyglove/core/object_utils/__init__.py +1 -0
  3. pyglove/core/object_utils/value_location.py +10 -0
  4. pyglove/core/object_utils/value_location_test.py +22 -0
  5. pyglove/core/symbolic/base.py +40 -2
  6. pyglove/core/symbolic/base_test.py +67 -0
  7. pyglove/core/symbolic/diff.py +190 -1
  8. pyglove/core/symbolic/diff_test.py +290 -0
  9. pyglove/core/symbolic/object_test.py +3 -8
  10. pyglove/core/symbolic/ref.py +29 -0
  11. pyglove/core/symbolic/ref_test.py +143 -0
  12. pyglove/core/typing/__init__.py +4 -0
  13. pyglove/core/typing/callable_ext.py +240 -1
  14. pyglove/core/typing/callable_ext_test.py +255 -0
  15. pyglove/core/typing/inspect.py +63 -0
  16. pyglove/core/typing/inspect_test.py +39 -0
  17. pyglove/core/views/__init__.py +30 -0
  18. pyglove/core/views/base.py +906 -0
  19. pyglove/core/views/base_test.py +615 -0
  20. pyglove/core/views/html/__init__.py +27 -0
  21. pyglove/core/views/html/base.py +529 -0
  22. pyglove/core/views/html/base_test.py +804 -0
  23. pyglove/core/views/html/tree_view.py +1052 -0
  24. pyglove/core/views/html/tree_view_test.py +748 -0
  25. {pyglove-0.4.5.dev202410020809.dist-info → pyglove-0.4.5.dev202410100808.dist-info}/METADATA +1 -1
  26. {pyglove-0.4.5.dev202410020809.dist-info → pyglove-0.4.5.dev202410100808.dist-info}/RECORD +29 -21
  27. {pyglove-0.4.5.dev202410020809.dist-info → pyglove-0.4.5.dev202410100808.dist-info}/LICENSE +0 -0
  28. {pyglove-0.4.5.dev202410020809.dist-info → pyglove-0.4.5.dev202410100808.dist-info}/WHEEL +0 -0
  29. {pyglove-0.4.5.dev202410020809.dist-info → pyglove-0.4.5.dev202410100808.dist-info}/top_level.txt +0 -0
@@ -13,11 +13,250 @@
13
13
  # limitations under the License.
14
14
  """Callable extensions."""
15
15
 
16
- from typing import Any, Callable, List
16
+ import contextlib
17
+ import functools
18
+ import inspect
19
+ import types
20
+ from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
17
21
 
22
+ from pyglove.core import object_utils
18
23
  from pyglove.core.typing import callable_signature
19
24
 
20
25
 
26
# Thread-local key under which `preset_args` pushes (and later pops) the
# active `_ArgPresets`; `enable_preset_args`-wrapped functions peek it.
_TLS_KEY_PRESET_KWARGS = '__preset_kwargs__'
27
+
28
+
29
class PresetArgValue(object_utils.Formattable):
  """Value placeholder for arguments whose value will be provided by presets.

  A parameter whose default is a `PresetArgValue` takes its value from the
  active preset (see `preset_args`) when the caller does not pass it. The
  function must be decorated with `enable_preset_args` for this to happen.

  Example::

    @pg.typing.enable_preset_args()
    def foo(x, y=pg.typing.PresetArgValue(default=1)):
      return x + y

    with pg.typing.preset_args(dict(y=2)):
      print(foo(x=1))   # 3: y=2
    print(foo(x=1))     # 2: y=1
  """

  def __init__(self, default: Any = object_utils.MISSING_VALUE):
    """Constructor.

    Args:
      default: The value to use when the active preset does not provide one.
        If omitted, the argument must be supplied by the preset or the caller.
    """
    self.default = default

  @property
  def has_default(self) -> bool:
    """Returns True if a fallback default value was specified."""
    # Use a type check rather than `self.default != MISSING_VALUE`: the latter
    # dispatches to the default value's own `__ne__`, which for array-like
    # values returns an elementwise result that cannot be used as a bool.
    return not isinstance(self.default, type(object_utils.MISSING_VALUE))

  def __eq__(self, other: Any) -> bool:
    return isinstance(other, PresetArgValue) and (
        self.default == other.default
    )

  def __ne__(self, other: Any) -> bool:
    return not self.__eq__(other)

  def format(self, *args, **kwargs):
    """Formats this object, e.g. `PresetArgValue(default=1)`."""
    return object_utils.kvlist_str(
        [
            # A MISSING_VALUE default is passed as the value to omit.
            ('default', self.default, object_utils.MISSING_VALUE),
        ],
        *args,
        label='PresetArgValue',
        **kwargs
    )

  @classmethod
  def inspect(
      cls, func: types.FunctionType) -> Dict[str, 'PresetArgValue']:
    """Gets the PresetArgValue specified in a function's signature.

    Args:
      func: A plain Python function.

    Returns:
      A dict from parameter name to its default value, for each parameter
      whose default is an instance of `cls`.
    """
    assert inspect.isfunction(func), func
    return {
        p.name: p.default
        for p in inspect.signature(func).parameters.values()
        if isinstance(p.default, cls)
    }

  @classmethod
  def resolve_args(
      cls,
      call_args: Tuple[Any, ...],
      call_kwargs: Dict[str, Any],
      positional_arg_names: Sequence[str],
      arg_defaults: Dict[str, Any],
      preset_kwargs: Dict[str, Any],
      include_all_preset_kwargs: bool = False,
  ) -> Tuple[Sequence[Any], Dict[str, Any]]:
    """Resolves calling arguments passed to a method with presets.

    Args:
      call_args: Positional arguments supplied by the caller.
      call_kwargs: Keyword arguments supplied by the caller.
      positional_arg_names: Names of the parameters that can be bound
        positionally, in declaration order.
      arg_defaults: Default values of the parameters that have one, keyed by
        parameter name. `PresetArgValue` defaults are resolved from
        `preset_kwargs`, falling back to their own `default`.
      preset_kwargs: Keyword arguments from the active preset.
      include_all_preset_kwargs: If True, preset kwargs that are not already
        resolved are passed through as keyword arguments as well.

    Returns:
      A tuple of (positional args, keyword args) to call the function with.
      The positional args are `call_args` itself when nothing needs to be
      appended, otherwise a new list.

    Raises:
      ValueError: If a `PresetArgValue` parameter without a default is
        provided neither by the caller nor by the preset.
    """
    num_call_args = len(call_args)
    # Parameters bound by the caller's positional arguments.
    positionally_bound = positional_arg_names[:num_call_args]
    provided_by_caller = set(call_kwargs).union(positionally_bound)

    # Step 1: compute default values, resolving `PresetArgValue` markers.
    resolved_kwargs = {}
    for arg_name, arg_default in arg_defaults.items():
      if not isinstance(arg_default, PresetArgValue):
        resolved_kwargs[arg_name] = arg_default
      elif arg_name in preset_kwargs:
        resolved_kwargs[arg_name] = preset_kwargs[arg_name]
      elif arg_default.has_default:
        resolved_kwargs[arg_name] = arg_default.default
      elif arg_name not in provided_by_caller:
        raise ValueError(
            f'Argument {arg_name!r} is not present as a keyword argument '
            'from the caller.'
        )
      # Otherwise the caller passes the value explicitly; nothing to resolve.

    # Step 2: optionally pass through all remaining preset kwargs.
    if include_all_preset_kwargs:
      for k, v in preset_kwargs.items():
        if k not in resolved_kwargs:
          resolved_kwargs[k] = v

    # Step 3: caller-provided kwargs take precedence over resolved values.
    resolved_kwargs.update(call_kwargs)

    # Step 4: drop entries for parameters already bound by positional args.
    for arg_name in positionally_bound:
      resolved_kwargs.pop(arg_name, None)

    # Step 5: convert kwargs back to positional arguments where possible,
    # stopping at the first positional parameter without a resolved value.
    resolved_args = call_args
    if len(positional_arg_names) > num_call_args:
      resolved_args = list(call_args)
      for arg_name in positional_arg_names[num_call_args:]:
        if arg_name not in resolved_kwargs:
          break
        resolved_args.append(resolved_kwargs.pop(arg_name))
    return resolved_args, resolved_kwargs
131
+
132
+
133
class _ArgPresets:
  """Preset argument collection.

  Maps preset names to dicts of keyword arguments. `derive` returns a new
  collection and leaves the current one unchanged.
  """

  def __init__(self, presets: Optional[Dict[str, Dict[str, Any]]] = None):
    self._presets: Dict[str, Dict[str, Any]] = presets or {}

  def derive(
      self,
      kwargs: Dict[str, Any],
      preset_name: str = 'global',
      inherit_preset: Union[str, bool] = False
  ) -> '_ArgPresets':
    """Derives new presets from current presets.

    Args:
      kwargs: Keyword arguments for the derived preset.
      preset_name: Name of the preset to (re)define.
      inherit_preset: Name of an existing preset whose kwargs serve as the
        base that `kwargs` is layered on. True means `preset_name` itself.
        False disables inheritance.

    Returns:
      A new `_ArgPresets`. The other presets are shared via a shallow copy.
    """
    if inherit_preset is True:
      inherit_preset = preset_name

    base = self._presets.get(inherit_preset, {}) if inherit_preset else {}
    presets = dict(self._presets)  # Shallow copy of the name -> kwargs map.
    # Always build a fresh dict so that later mutation of the caller's
    # `kwargs` cannot silently change the active preset.
    presets[preset_name] = {**base, **kwargs}
    return _ArgPresets(presets)

  def get_preset(self, preset_name: str) -> Dict[str, Any]:
    """Returns the kwargs of the named preset, or {} if it is undefined."""
    return self._presets.get(preset_name, {})
160
+
161
+
162
@contextlib.contextmanager
def preset_args(
    kwargs: Dict[str, Any],
    *,
    preset_name: str = 'global',
    inherit_preset: Union[str, bool] = False
) -> Iterator['_ArgPresets']:
  """Context manager that makes `kwargs` available as preset arguments.

  Functions decorated with `enable_preset_args` and called within the context
  read their `PresetArgValue` parameters from the preset named by their
  decorator's `preset_name`.

  Args:
    kwargs: The preset kwargs to be used by preset-enabled functions within
      the context.
    preset_name: The name of the preset to specify kwargs for.
      `enable_preset_args` allows users to pass a preset name, which will be
      used to identify the preset to be used.
    inherit_preset: The name of the preset defined by the parent context to
      inherit kwargs from. Or a boolean to indicate whether to inherit a
      parent preset of the same name.

  Yields:
    The `_ArgPresets` collection active within the context. It holds every
    named preset, not only `preset_name`.
  """
  parent_presets = object_utils.thread_local_peek(
      _TLS_KEY_PRESET_KWARGS, _ArgPresets()
  )
  current_presets = parent_presets.derive(kwargs, preset_name, inherit_preset)
  object_utils.thread_local_push(_TLS_KEY_PRESET_KWARGS, current_presets)
  try:
    yield current_presets
  finally:
    # Restore the parent presets when leaving the context, even on error.
    object_utils.thread_local_pop(_TLS_KEY_PRESET_KWARGS, None)
194
+
195
+
196
def enable_preset_args(
    include_all_preset_kwargs: bool = False,
    preset_name: str = 'global'
) -> Callable[[types.FunctionType], types.FunctionType]:
  """Decorator for functions that may use preset argument values.

  Usage::

    @pg.typing.enable_preset_args()
    def foo(x, y=pg.typing.PresetArgValue(default=1)):
      return x + y

    with pg.typing.preset_args(dict(y=2)):
      print(foo(x=1))   # 3: y=2
    print(foo(x=1))     # 2: y=1

  Args:
    include_all_preset_kwargs: Whether to include all preset kwargs (even
      those not marked as `PresetArgValue`) when calling the function. Only
      effective when the function accepts `**kwargs`.
    preset_name: The name of the preset to read kwargs from.

  Returns:
    A decorator. Functions without any `PresetArgValue` default are returned
    unchanged. Other functions are wrapped so that their preset arguments are
    resolved at call time.
  """
  def decorator(func):
    sig = inspect.signature(func)
    # Parameters that can be bound positionally, in declaration order.
    # Positional-only parameters must be included so that `resolve_args`
    # maps each positional call argument to the right parameter name.
    positional_arg_names = [
        p.name for p in sig.parameters.values()
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY,
                      inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]

    arg_defaults = {}
    has_preset_value = False
    has_varkw = False
    for p in sig.parameters.values():
      if p.kind == inspect.Parameter.VAR_KEYWORD:
        has_varkw = True
        continue
      if p.kind == inspect.Parameter.VAR_POSITIONAL:
        continue
      # Identity check: `==` would invoke the default value's own comparison,
      # which is elementwise (and not a bool) for array-like defaults.
      if p.default is inspect.Parameter.empty:
        continue
      if isinstance(p.default, PresetArgValue):
        has_preset_value = True
      arg_defaults[p.name] = p.default

    if not has_preset_value:
      return func

    @functools.wraps(func)
    def _func(*args, **kwargs):
      presets = object_utils.thread_local_peek(
          _TLS_KEY_PRESET_KWARGS, None
      )
      preset_kwargs = presets.get_preset(preset_name) if presets else {}
      args, kwargs = PresetArgValue.resolve_args(
          args, kwargs, positional_arg_names, arg_defaults, preset_kwargs,
          include_all_preset_kwargs=include_all_preset_kwargs and has_varkw
      )
      return func(*args, **kwargs)
    return _func
  return decorator
258
+
259
+
21
260
  class CallableWithOptionalKeywordArgs:
22
261
  """Helper class for invoking callable objects with optional keyword args.
23
262
 
@@ -18,6 +18,261 @@ from pyglove.core.typing import annotation_conversion # pylint: disable=unused
18
18
  from pyglove.core.typing import callable_ext
19
19
 
20
20
 
21
class PresetArgValueTest(unittest.TestCase):
  """Tests for typing.PresetArgValue."""

  def test_basics(self):
    marker = callable_ext.PresetArgValue

    self.assertFalse(marker().has_default)

    with_default = marker(default=1)
    self.assertTrue(with_default.has_default)
    self.assertEqual(with_default.default, 1)
    self.assertEqual(repr(with_default), 'PresetArgValue(default=1)')
    self.assertEqual(str(with_default), 'PresetArgValue(\n default=1\n)')

    # Equality is determined by the default value.
    self.assertEqual(marker(), marker())
    self.assertEqual(marker(1), marker(1))
    self.assertNotEqual(marker(), marker(default=1))

  def test_inspect(self):
    marker = callable_ext.PresetArgValue

    def with_markers(
        x=marker(),
        y=1,
        *,
        z=marker(default=2)
    ):
      return x + y + z

    def without_markers(x, y=1):
      return x + y

    self.assertEqual(
        marker.inspect(with_markers),
        {'x': marker(), 'z': marker(default=2)},
    )
    self.assertEqual(marker.inspect(without_markers), {})

  def test_resolve_args(self):
    marker = callable_ext.PresetArgValue

    def resolve(call_args, call_kwargs, preset_kwargs, arg_defaults=None,
                positional_arg_names=None, **options):
      # Most cases share the signature `(x=<preset>, y=0, *, z=<preset:2>)`.
      if arg_defaults is None:
        arg_defaults = {
            'x': marker(),
            'y': 0,
            'z': marker(default=2),
        }
      if positional_arg_names is None:
        positional_arg_names = ['x', 'y']
      return marker.resolve_args(
          call_args=call_args,
          call_kwargs=call_kwargs,
          positional_arg_names=positional_arg_names,
          arg_defaults=arg_defaults,
          preset_kwargs=preset_kwargs,
          **options,
      )

    cases = [
        (
            'positional and keyword arguments both come from the preset',
            dict(
                call_args=[],
                call_kwargs=dict(),
                arg_defaults={
                    'x': marker(),
                    'z': marker(default=2),
                },
                preset_kwargs=dict(x=1, y=2, z=3, w=4),
            ),
            ([1], dict(z=3)),
        ),
        (
            'absent from the preset, marker defaults are used',
            dict(
                call_args=[],
                call_kwargs=dict(),
                arg_defaults={
                    'x': marker(default=1),
                    'y': 0,
                    'z': marker(default=2),
                },
                preset_kwargs=dict(),
            ),
            ([1, 0], dict(z=2)),
        ),
        (
            'positional from preset, keyword falls back to marker default',
            dict(call_args=[], call_kwargs=dict(),
                 preset_kwargs=dict(x=1, y=2)),
            ([1, 0], dict(z=2)),
        ),
        (
            'positional argument from the caller beats the preset',
            dict(call_args=[2], call_kwargs={},
                 preset_kwargs=dict(x=1, y=2, z=3, w=4)),
            ([2, 0], dict(z=3)),
        ),
        (
            'positional argument passed by keyword',
            dict(call_args=[], call_kwargs=dict(x=2),
                 preset_kwargs=dict(x=1, y=2, z=3, w=4)),
            ([2, 0], dict(z=3)),
        ),
        (
            'positional argument passed by keyword, extra varargs',
            dict(call_args=[1, 2, 3], call_kwargs=dict(x=2),
                 preset_kwargs=dict(x=1, y=2, z=3, w=4)),
            ([1, 2, 3], dict(z=3)),
        ),
        (
            'all preset kwargs are included',
            dict(call_args=[], call_kwargs=dict(),
                 preset_kwargs=dict(x=1, y=2, z=3, w=4),
                 include_all_preset_kwargs=True),
            ([1, 0], dict(z=3, w=4)),
        ),
    ]
    for description, kwargs, expected in cases:
      with self.subTest(description):
        self.assertEqual(resolve(**kwargs), expected)

    # A required preset argument that nobody provides.
    with self.assertRaisesRegex(ValueError, 'Argument .* is not present.'):
      resolve(
          call_args=[],
          call_kwargs=dict(),
          positional_arg_names=['x'],
          arg_defaults={'x': marker()},
          preset_kwargs=dict(),
      )

  def test_preset_args(self):
    marker = callable_ext.PresetArgValue

    @callable_ext.enable_preset_args()
    def foo(x=marker(), y=1, *args, z=marker(default=2)):
      del args
      return x + y + z

    with self.assertRaisesRegex(ValueError, 'Argument \'x\' is not present.'):
      foo()

    scenarios = [
        # (preset kwargs, call args, call kwargs, expected result)
        (dict(x=1), (), {}, 1 + 1 + 2),
        # A preset value for `y` must not override its regular default.
        (dict(x=1, y=2), (), {}, 1 + 1 + 2),
        (dict(x=1, y=2, z=3), (3,), {}, 3 + 1 + 3),
        (dict(x=1, y=2, z=3), (3, 3), dict(z=4), 3 + 3 + 4),
    ]
    for presets, args, kwargs, expected in scenarios:
      with callable_ext.preset_args(presets):
        self.assertEqual(foo(*args, **kwargs), expected)

  def test_enable_preset_args(self):
    marker = callable_ext.PresetArgValue

    # Functions without any PresetArgValue are returned as-is.
    def bar(x, y):
      return x + y
    self.assertIs(callable_ext.enable_preset_args()(bar), bar)

    # With `include_all_preset_kwargs=False`, unmarked preset kwargs are not
    # forwarded to `**kwargs`.
    @callable_ext.enable_preset_args()
    def baz(x, y=marker(default=1), **kwargs):
      return x + y + sum(kwargs.values())

    with callable_ext.preset_args(dict(z=3, p=4)):
      self.assertEqual(baz(1), 1 + 1)

    # With `include_all_preset_kwargs=True` and `**kwargs`, every preset kwarg
    # is forwarded.
    @callable_ext.enable_preset_args(include_all_preset_kwargs=True)
    def foo(x, y=marker(default=1), **kwargs):
      return x + y + sum(kwargs.values())

    with callable_ext.preset_args(dict(z=3, p=4)):
      self.assertEqual(foo(1), 1 + 1 + 3 + 4)

    # Without `**kwargs`, `include_all_preset_kwargs` is ignored.
    @callable_ext.enable_preset_args(include_all_preset_kwargs=True)
    def fuz(x, y=marker(default=1)):
      return x + y

    with callable_ext.preset_args(dict(y=2, z=3)):
      self.assertEqual(fuz(1), 1 + 2)

  def test_preset_args_nesting(self):
    marker = callable_ext.PresetArgValue

    @callable_ext.enable_preset_args()
    def foo(x=marker(), y=1, *, z=marker(default=2)):
      return x + y + z

    def run_nested(inherit_preset: bool = False, **kwargs):
      incremented = {name: value + 1 for name, value in kwargs.items()}
      with callable_ext.preset_args(
          incremented, inherit_preset=inherit_preset
      ):
        return foo()

    with callable_ext.preset_args(dict(x=1)):
      self.assertEqual(foo(), 1 + 1 + 2)

      self.assertEqual(run_nested(x=1), 2 + 1 + 2)
      self.assertEqual(run_nested(x=1, z=2), 2 + 1 + 3)
      self.assertEqual(run_nested(x=1, z=3), 2 + 1 + 4)

      # Without inheritance the nested preset lacks `x`.
      with self.assertRaisesRegex(ValueError, 'Argument \'x\' is not present.'):
        run_nested()
      self.assertEqual(run_nested(inherit_preset=True), 1 + 1 + 2)
+
275
+
21
276
  class CallWithOptionalKeywordArgsTest(unittest.TestCase):
22
277
  """Tests for typing.CallWithOptionalKeywordArgs."""
23
278
 
@@ -14,6 +14,7 @@
14
14
  """Utility module for inspecting generics types."""
15
15
 
16
16
  import inspect
17
+ import sys
17
18
  import typing
18
19
  from typing import Any, Callable, Optional, Tuple, Type, Union
19
20
 
@@ -116,6 +117,68 @@ def get_type_args(
116
117
  return ()
117
118
 
118
119
 
120
def get_outer_class(
    cls: Type[Any],
    base_cls: Union[Type[Any], Tuple[Type[Any], ...], None] = None,
    immediate: bool = False,
) -> Optional[Type[Any]]:
  """Returns the outer class.

  Example::

    class A:
      pass

    class A1(A):
      class B:
        class C:
          ...

    get_outer_class(A1.B) is A1
    get_outer_class(A1.B.C) is A1.B
    get_outer_class(A1.B.C, base_cls=A) is A1
    get_outer_class(A1.B.C, base_cls=A, immediate=True) is None

  Args:
    cls: The class to get the outer class for.
    base_cls: The base class of the outer class. If provided, an outer class
      that is not a subclass of `base_cls` will not be returned.
    immediate: Whether to consider only the immediate outer class, rather
      than the innermost class in the nesting hierarchy that is a subclass of
      `base_cls`. Applicable when `base_cls` is not None.

  Returns:
    The outer class of `cls`. Returns None in any of these cases:
    - `cls` is not nested.
    - Its enclosing classes cannot be resolved from its module.
    - No qualifying outer class is a subclass of `base_cls`.

  Raises:
    ValueError: If `cls` is defined inside a function (its qualified name
      contains '<locals>'), since such classes cannot be looked up.
  """
  qualname = cls.__qualname__
  if '<locals>' in qualname:
    raise ValueError(
        'Cannot find the outer class for locally defined class '
        f'{qualname!r}'
    )

  names = qualname.split('.')
  if len(names) < 2:
    return None

  # The defining module may be absent from `sys.modules` (e.g. dynamically
  # created classes); treat that as "cannot find one".
  parent = sys.modules.get(cls.__module__)
  if parent is None:
    return None

  # Walk down the qualified name, collecting enclosing classes outermost first.
  outer_classes = []
  for name in names[:-1]:
    symbol = getattr(parent, name, None)
    # The path may not resolve to classes, e.g. if `__qualname__` was
    # rewritten or a module attribute shadows the name.
    if not inspect.isclass(symbol):
      return None
    outer_classes.append(symbol)
    parent = symbol

  def _qualifies(candidate: Type[Any]) -> bool:
    return base_cls is None or issubclass(candidate, base_cls)

  if immediate:
    closest = outer_classes[-1]
    return closest if _qualifies(closest) else None
  return next((c for c in reversed(outer_classes) if _qualifies(c)), None)
180
+
181
+
119
182
  def callable_eq(
120
183
  x: Optional[Callable[..., Any]], y: Optional[Callable[..., Any]]
121
184
  ) -> bool:
@@ -16,8 +16,10 @@
16
16
  from typing import Any, Generic, TypeVar
17
17
  import unittest
18
18
 
19
+ from pyglove.core.typing import callable_signature
19
20
  from pyglove.core.typing import inspect
20
21
 
22
+
21
23
  XType = TypeVar('XType')
22
24
  YType = TypeVar('YType')
23
25
 
@@ -50,6 +52,16 @@ class D(C):
50
52
  pass
51
53
 
52
54
 
55
class AA:
  """Base class used as `base_cls` in `get_outer_class` tests."""
  pass
57
+
58
+
59
class AA1(AA):
  """Outer class (a subclass of `AA`) with two levels of nested classes."""

  class BB1:
    """Intermediate nested class; not a subclass of `AA`."""

    class CC1:
      """Innermost nested class."""
      pass
63
+
64
+
53
65
  class InspectTest(unittest.TestCase):
54
66
 
55
67
  def test_issubclass(self):
@@ -141,6 +153,33 @@ class InspectTest(unittest.TestCase):
141
153
  self.assertEqual(inspect.get_type_args(C, A), (str, int))
142
154
  self.assertEqual(inspect.get_type_args(C, B), (Str,))
143
155
 
156
def test_outer_class(self):
  get_outer = inspect.get_outer_class

  # Classes defined inside a function cannot be looked up.
  class Foo:
    pass

  with self.assertRaisesRegex(ValueError, '.* locally defined class'):
    get_outer(Foo)

  # A top-level class has no outer class.
  self.assertIsNone(get_outer(AA))

  # Immediate outer class, with and without `base_cls` filtering.
  self.assertIs(get_outer(AA1.BB1), AA1)
  self.assertIs(get_outer(AA1.BB1, AA), AA1)
  self.assertIsNone(get_outer(AA1.BB1, A))
  self.assertIs(get_outer(AA1.BB1.CC1), AA1.BB1)

  # Deeper nesting: `immediate=True` only checks the closest outer class.
  self.assertIsNone(get_outer(AA1.BB1.CC1, base_cls=AA, immediate=True))
  self.assertIs(get_outer(AA1.BB1.CC1, AA), AA1)

  # Nested classes defined in another module.
  self.assertIs(
      get_outer(callable_signature.Argument.Kind),
      callable_signature.Argument,
  )

  # `__qualname__` names an outer class that does not exist in the module.
  class Bar:
    pass

  Bar.__qualname__ = 'NonExist.Bar'
  self.assertIsNone(get_outer(Bar))
182
+
144
183
  def test_callable_eq(self):
145
184
  def foo(unused_x):
146
185
  pass
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 The PyGlove Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """PyGlove views."""
15
+
16
+ from pyglove.core.views import base
17
+ from pyglove.core.views import html
18
+
19
# Core view abstractions, re-exported from `pyglove.core.views.base`.
View, view = base.View, base.view

# Exported for use in (pytype) type annotations.
NodeFilter = base.NodeFilter

# HTML rendering, re-exported from `pyglove.core.views.html`.
Html, HtmlView, HtmlTreeView = html.Html, html.HtmlView, html.HtmlTreeView
to_html, to_html_str = html.to_html, html.to_html_str