pyglove 0.4.5.dev20240318__py3-none-any.whl → 0.4.5.dev202501132210__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyglove/core/__init__.py +54 -20
- pyglove/core/coding/__init__.py +42 -0
- pyglove/core/coding/errors.py +111 -0
- pyglove/core/coding/errors_test.py +98 -0
- pyglove/core/coding/execution.py +309 -0
- pyglove/core/coding/execution_test.py +333 -0
- pyglove/core/{object_utils/codegen.py → coding/function_generation.py} +10 -4
- pyglove/core/{object_utils/codegen_test.py → coding/function_generation_test.py} +5 -7
- pyglove/core/coding/parsing.py +153 -0
- pyglove/core/coding/parsing_test.py +150 -0
- pyglove/core/coding/permissions.py +100 -0
- pyglove/core/coding/permissions_test.py +93 -0
- pyglove/core/geno/base.py +54 -41
- pyglove/core/geno/base_test.py +2 -4
- pyglove/core/geno/categorical.py +37 -28
- pyglove/core/geno/custom.py +19 -16
- pyglove/core/geno/numerical.py +20 -17
- pyglove/core/geno/space.py +4 -5
- pyglove/core/hyper/base.py +6 -6
- pyglove/core/hyper/categorical.py +94 -55
- pyglove/core/hyper/custom.py +7 -7
- pyglove/core/hyper/custom_test.py +9 -10
- pyglove/core/hyper/derived.py +30 -22
- pyglove/core/hyper/derived_test.py +2 -4
- pyglove/core/hyper/dynamic_evaluation.py +5 -6
- pyglove/core/hyper/evolvable.py +57 -46
- pyglove/core/hyper/numerical.py +48 -24
- pyglove/core/hyper/numerical_test.py +9 -9
- pyglove/core/hyper/object_template.py +58 -46
- pyglove/core/io/__init__.py +1 -0
- pyglove/core/io/file_system.py +17 -7
- pyglove/core/io/file_system_test.py +2 -0
- pyglove/core/io/sequence.py +299 -0
- pyglove/core/io/sequence_test.py +124 -0
- pyglove/core/logging_test.py +0 -2
- pyglove/core/patching/object_factory.py +4 -4
- pyglove/core/patching/pattern_based.py +4 -4
- pyglove/core/patching/rule_based.py +17 -5
- pyglove/core/patching/rule_based_test.py +27 -4
- pyglove/core/symbolic/__init__.py +2 -7
- pyglove/core/symbolic/base.py +320 -183
- pyglove/core/symbolic/base_test.py +123 -19
- pyglove/core/symbolic/boilerplate.py +7 -13
- pyglove/core/symbolic/boilerplate_test.py +25 -23
- pyglove/core/symbolic/class_wrapper.py +48 -45
- pyglove/core/symbolic/class_wrapper_test.py +2 -2
- pyglove/core/symbolic/compounding.py +9 -15
- pyglove/core/symbolic/compounding_test.py +2 -4
- pyglove/core/symbolic/dict.py +154 -110
- pyglove/core/symbolic/dict_test.py +238 -130
- pyglove/core/symbolic/diff.py +199 -10
- pyglove/core/symbolic/diff_test.py +226 -0
- pyglove/core/symbolic/flags.py +1 -1
- pyglove/core/symbolic/functor.py +29 -26
- pyglove/core/symbolic/functor_test.py +102 -50
- pyglove/core/symbolic/inferred.py +2 -2
- pyglove/core/symbolic/list.py +81 -50
- pyglove/core/symbolic/list_test.py +119 -97
- pyglove/core/symbolic/object.py +225 -113
- pyglove/core/symbolic/object_test.py +320 -108
- pyglove/core/symbolic/origin.py +17 -14
- pyglove/core/symbolic/origin_test.py +4 -2
- pyglove/core/symbolic/pure_symbolic.py +4 -3
- pyglove/core/symbolic/ref.py +108 -21
- pyglove/core/symbolic/ref_test.py +93 -0
- pyglove/core/symbolic/symbolize_test.py +10 -2
- pyglove/core/tuning/local_backend.py +2 -2
- pyglove/core/tuning/protocols.py +3 -3
- pyglove/core/tuning/sample_test.py +3 -3
- pyglove/core/typing/__init__.py +14 -5
- pyglove/core/typing/annotation_conversion.py +43 -27
- pyglove/core/typing/annotation_conversion_test.py +23 -0
- pyglove/core/typing/callable_ext.py +241 -3
- pyglove/core/typing/callable_ext_test.py +255 -0
- pyglove/core/typing/callable_signature.py +510 -66
- pyglove/core/typing/callable_signature_test.py +619 -99
- pyglove/core/typing/class_schema.py +229 -154
- pyglove/core/typing/class_schema_test.py +149 -95
- pyglove/core/typing/custom_typing.py +5 -4
- pyglove/core/typing/inspect.py +63 -0
- pyglove/core/typing/inspect_test.py +39 -0
- pyglove/core/typing/key_specs.py +10 -11
- pyglove/core/typing/key_specs_test.py +7 -4
- pyglove/core/typing/type_conversion.py +4 -5
- pyglove/core/typing/type_conversion_test.py +12 -12
- pyglove/core/typing/typed_missing.py +6 -7
- pyglove/core/typing/typed_missing_test.py +7 -8
- pyglove/core/typing/value_specs.py +604 -362
- pyglove/core/typing/value_specs_test.py +328 -90
- pyglove/core/utils/__init__.py +164 -0
- pyglove/core/{object_utils → utils}/common_traits.py +3 -67
- pyglove/core/utils/common_traits_test.py +36 -0
- pyglove/core/{object_utils → utils}/docstr_utils.py +23 -0
- pyglove/core/{object_utils → utils}/docstr_utils_test.py +36 -4
- pyglove/core/{object_utils → utils}/error_utils.py +78 -9
- pyglove/core/{object_utils → utils}/error_utils_test.py +61 -5
- pyglove/core/utils/formatting.py +464 -0
- pyglove/core/utils/formatting_test.py +453 -0
- pyglove/core/{object_utils → utils}/hierarchical.py +23 -25
- pyglove/core/{object_utils → utils}/hierarchical_test.py +3 -5
- pyglove/core/{object_utils → utils}/json_conversion.py +177 -52
- pyglove/core/{object_utils → utils}/json_conversion_test.py +97 -16
- pyglove/core/{object_utils → utils}/missing.py +3 -3
- pyglove/core/{object_utils → utils}/missing_test.py +2 -4
- pyglove/core/utils/text_color.py +128 -0
- pyglove/core/utils/text_color_test.py +94 -0
- pyglove/core/{object_utils → utils}/thread_local_test.py +1 -3
- pyglove/core/utils/timing.py +236 -0
- pyglove/core/utils/timing_test.py +154 -0
- pyglove/core/{object_utils → utils}/value_location.py +275 -6
- pyglove/core/utils/value_location_test.py +707 -0
- pyglove/core/views/__init__.py +32 -0
- pyglove/core/views/base.py +804 -0
- pyglove/core/views/base_test.py +580 -0
- pyglove/core/views/html/__init__.py +27 -0
- pyglove/core/views/html/base.py +547 -0
- pyglove/core/views/html/base_test.py +830 -0
- pyglove/core/views/html/controls/__init__.py +35 -0
- pyglove/core/views/html/controls/base.py +275 -0
- pyglove/core/views/html/controls/label.py +207 -0
- pyglove/core/views/html/controls/label_test.py +157 -0
- pyglove/core/views/html/controls/progress_bar.py +183 -0
- pyglove/core/views/html/controls/progress_bar_test.py +97 -0
- pyglove/core/views/html/controls/tab.py +320 -0
- pyglove/core/views/html/controls/tab_test.py +87 -0
- pyglove/core/views/html/controls/tooltip.py +99 -0
- pyglove/core/views/html/controls/tooltip_test.py +99 -0
- pyglove/core/views/html/tree_view.py +1517 -0
- pyglove/core/views/html/tree_view_test.py +1461 -0
- {pyglove-0.4.5.dev20240318.dist-info → pyglove-0.4.5.dev202501132210.dist-info}/METADATA +18 -4
- pyglove-0.4.5.dev202501132210.dist-info/RECORD +214 -0
- {pyglove-0.4.5.dev20240318.dist-info → pyglove-0.4.5.dev202501132210.dist-info}/WHEEL +1 -1
- pyglove/core/object_utils/__init__.py +0 -154
- pyglove/core/object_utils/common_traits_test.py +0 -82
- pyglove/core/object_utils/formatting.py +0 -234
- pyglove/core/object_utils/formatting_test.py +0 -223
- pyglove/core/object_utils/value_location_test.py +0 -385
- pyglove/core/symbolic/schema_utils.py +0 -327
- pyglove/core/symbolic/schema_utils_test.py +0 -57
- pyglove/core/typing/class_schema_utils.py +0 -202
- pyglove/core/typing/class_schema_utils_test.py +0 -194
- pyglove-0.4.5.dev20240318.dist-info/RECORD +0 -185
- /pyglove/core/{object_utils → utils}/thread_local.py +0 -0
- {pyglove-0.4.5.dev20240318.dist-info → pyglove-0.4.5.dev202501132210.dist-info}/LICENSE +0 -0
- {pyglove-0.4.5.dev20240318.dist-info → pyglove-0.4.5.dev202501132210.dist-info}/top_level.txt +0 -0
@@ -11,15 +11,18 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for pyglove.symbolic.base."""
|
15
|
-
|
16
14
|
import copy
|
15
|
+
import inspect
|
16
|
+
from typing import Any
|
17
17
|
import unittest
|
18
18
|
|
19
|
-
from pyglove.core import object_utils
|
20
19
|
from pyglove.core import typing as pg_typing
|
20
|
+
from pyglove.core import utils
|
21
|
+
from pyglove.core import views
|
21
22
|
from pyglove.core.symbolic import base
|
22
23
|
from pyglove.core.symbolic.dict import Dict
|
24
|
+
from pyglove.core.symbolic.inferred import ValueFromParentChain
|
25
|
+
from pyglove.core.symbolic.object import Object
|
23
26
|
|
24
27
|
|
25
28
|
class FieldUpdateTest(unittest.TestCase):
|
@@ -28,7 +31,7 @@ class FieldUpdateTest(unittest.TestCase):
|
|
28
31
|
def test_basics(self):
|
29
32
|
x = Dict(x=1)
|
30
33
|
f = pg_typing.Field('x', pg_typing.Int())
|
31
|
-
update = base.FieldUpdate(
|
34
|
+
update = base.FieldUpdate(utils.KeyPath('x'), x, f, 1, 2)
|
32
35
|
self.assertEqual(update.path, 'x')
|
33
36
|
self.assertIs(update.target, x)
|
34
37
|
self.assertIs(update.field, f)
|
@@ -37,15 +40,15 @@ class FieldUpdateTest(unittest.TestCase):
|
|
37
40
|
|
38
41
|
def test_format(self):
|
39
42
|
self.assertEqual(
|
40
|
-
base.FieldUpdate(
|
41
|
-
|
42
|
-
)
|
43
|
+
base.FieldUpdate(utils.KeyPath('x'), Dict(x=1), None, 1, 2).format(
|
44
|
+
compact=True
|
45
|
+
),
|
43
46
|
'FieldUpdate(parent_path=, path=x, old_value=1, new_value=2)',
|
44
47
|
)
|
45
48
|
|
46
49
|
self.assertEqual(
|
47
50
|
base.FieldUpdate(
|
48
|
-
|
51
|
+
utils.KeyPath('a'), Dict(x=Dict(a=1)).x, None, 1, 2
|
49
52
|
).format(compact=True),
|
50
53
|
'FieldUpdate(parent_path=x, path=a, old_value=1, new_value=2)',
|
51
54
|
)
|
@@ -54,35 +57,136 @@ class FieldUpdateTest(unittest.TestCase):
|
|
54
57
|
x = Dict()
|
55
58
|
f = pg_typing.Field('x', pg_typing.Int())
|
56
59
|
self.assertEqual(
|
57
|
-
base.FieldUpdate(
|
58
|
-
base.FieldUpdate(
|
60
|
+
base.FieldUpdate(utils.KeyPath('a'), x, f, 1, 2),
|
61
|
+
base.FieldUpdate(utils.KeyPath('a'), x, f, 1, 2),
|
59
62
|
)
|
60
63
|
|
61
64
|
# Targets are not the same instance.
|
62
65
|
self.assertNotEqual(
|
63
|
-
base.FieldUpdate(
|
64
|
-
base.FieldUpdate(
|
66
|
+
base.FieldUpdate(utils.KeyPath('a'), x, f, 1, 2),
|
67
|
+
base.FieldUpdate(utils.KeyPath('a'), Dict(), f, 1, 2),
|
65
68
|
)
|
66
69
|
|
67
70
|
# Fields are not the same instance.
|
68
71
|
self.assertNotEqual(
|
69
|
-
base.FieldUpdate(
|
70
|
-
base.FieldUpdate(
|
72
|
+
base.FieldUpdate(utils.KeyPath('a'), x, f, 1, 2),
|
73
|
+
base.FieldUpdate(utils.KeyPath('b'), x, copy.copy(f), 1, 2),
|
71
74
|
)
|
72
75
|
|
73
76
|
self.assertNotEqual(
|
74
|
-
base.FieldUpdate(
|
75
|
-
base.FieldUpdate(
|
77
|
+
base.FieldUpdate(utils.KeyPath('a'), x, f, 1, 2),
|
78
|
+
base.FieldUpdate(utils.KeyPath('a'), x, f, 0, 2),
|
76
79
|
)
|
77
80
|
|
78
81
|
self.assertNotEqual(
|
79
|
-
base.FieldUpdate(
|
80
|
-
base.FieldUpdate(
|
82
|
+
base.FieldUpdate(utils.KeyPath('a'), x, f, 1, 2),
|
83
|
+
base.FieldUpdate(utils.KeyPath('a'), x, f, 1, 1),
|
81
84
|
)
|
82
85
|
|
83
86
|
self.assertNotEqual(
|
84
|
-
base.FieldUpdate(
|
87
|
+
base.FieldUpdate(utils.KeyPath('a'), x, f, 1, 2), Dict()
|
88
|
+
)
|
89
|
+
|
90
|
+
|
91
|
+
class HtmlTreeViewExtensionTest(unittest.TestCase):
|
92
|
+
|
93
|
+
def assert_content(self, html, expected):
|
94
|
+
expected = inspect.cleandoc(expected).strip()
|
95
|
+
actual = html.content.strip()
|
96
|
+
if actual != expected:
|
97
|
+
print(actual)
|
98
|
+
self.assertEqual(actual.strip(), expected)
|
99
|
+
|
100
|
+
def test_to_html(self):
|
101
|
+
|
102
|
+
class Foo(Object):
|
103
|
+
x: int
|
104
|
+
y: Any = 'foo'
|
105
|
+
z: pg_typing.Int().freeze(1)
|
106
|
+
|
107
|
+
# Disable tooltip.
|
108
|
+
self.assert_content(
|
109
|
+
Foo(x=1, y='foo').to_html(
|
110
|
+
enable_summary_tooltip=False,
|
111
|
+
enable_key_tooltip=False
|
112
|
+
),
|
113
|
+
"""
|
114
|
+
<details open class="pyglove foo"><summary><div class="summary-title">Foo(...)</div></summary><div class="complex-value foo"><details open class="pyglove int"><summary><div class="summary-name">x</div><div class="summary-title">int</div></summary><span class="simple-value int">1</span></details><details open class="pyglove str"><summary><div class="summary-name">y</div><div class="summary-title">str</div></summary><span class="simple-value str">'foo'</span></details></div></details>
|
115
|
+
"""
|
116
|
+
)
|
117
|
+
# Hide frozen and default values.
|
118
|
+
self.assert_content(
|
119
|
+
Foo(x=1, y='foo').to_html(
|
120
|
+
enable_summary_tooltip=False,
|
121
|
+
enable_key_tooltip=False,
|
122
|
+
collapse_level=0,
|
123
|
+
key_style='label',
|
124
|
+
extra_flags=dict(
|
125
|
+
hide_frozen=True,
|
126
|
+
hide_default_values=True
|
127
|
+
)
|
128
|
+
),
|
129
|
+
"""
|
130
|
+
<details class="pyglove foo"><summary><div class="summary-title">Foo(...)</div></summary><div class="complex-value foo"><table><tr><td><span class="object-key str">x</span></td><td><span class="simple-value int">1</span></td></tr></table></div></details>
|
131
|
+
"""
|
132
|
+
)
|
133
|
+
self.assert_content(
|
134
|
+
Foo(x=1, y='foo').to_html(
|
135
|
+
enable_summary_tooltip=False,
|
136
|
+
enable_key_tooltip=False,
|
137
|
+
collapse_level=0,
|
138
|
+
extra_flags=dict(
|
139
|
+
hide_frozen=True,
|
140
|
+
hide_default_values=True
|
141
|
+
)
|
142
|
+
),
|
143
|
+
"""
|
144
|
+
<details class="pyglove foo"><summary><div class="summary-title">Foo(...)</div></summary><div class="complex-value foo"><details open class="pyglove int"><summary><div class="summary-name">x</div><div class="summary-title">int</div></summary><span class="simple-value int">1</span></details></div></details>
|
145
|
+
"""
|
146
|
+
)
|
147
|
+
# Use inferred values.
|
148
|
+
x = Dict(x=Dict(y=ValueFromParentChain()), y=2)
|
149
|
+
self.assert_content(
|
150
|
+
x.x.to_html(
|
151
|
+
enable_summary_tooltip=False,
|
152
|
+
enable_key_tooltip=False,
|
153
|
+
key_style='label',
|
154
|
+
extra_flags=dict(
|
155
|
+
use_inferred=False
|
156
|
+
)
|
157
|
+
),
|
158
|
+
"""
|
159
|
+
<details open class="pyglove dict"><summary><div class="summary-title">Dict(...)</div></summary><div class="complex-value dict"><table><tr><td><span class="object-key str">y</span></td><td><details class="pyglove value-from-parent-chain"><summary><div class="summary-title">ValueFromParentChain(...)</div></summary><div class="complex-value value-from-parent-chain"><span class="empty-container"></span></div></details></td></tr></table></div></details>
|
160
|
+
"""
|
161
|
+
)
|
162
|
+
self.assert_content(
|
163
|
+
x.x.to_html(
|
164
|
+
enable_summary_tooltip=False,
|
165
|
+
enable_key_tooltip=False,
|
166
|
+
key_style='label',
|
167
|
+
extra_flags=dict(
|
168
|
+
use_inferred=True
|
169
|
+
)
|
170
|
+
),
|
171
|
+
"""
|
172
|
+
<details open class="pyglove dict"><summary><div class="summary-title">Dict(...)</div></summary><div class="complex-value dict"><table><tr><td><span class="object-key str">y</span></td><td><span class="simple-value int">2</span></td></tr></table></div></details>
|
173
|
+
"""
|
85
174
|
)
|
175
|
+
# Test collapse level.
|
176
|
+
v = Foo(1, Foo(2, Foo(3, Foo(4))))
|
177
|
+
with views.view_options(key_style='label'):
|
178
|
+
self.assertEqual(
|
179
|
+
v.to_html(collapse_level=0).content.count('open'), 0
|
180
|
+
)
|
181
|
+
self.assertEqual(
|
182
|
+
v.to_html(collapse_level=1).content.count('open'), 1
|
183
|
+
)
|
184
|
+
self.assertEqual(
|
185
|
+
v.to_html(collapse_level=2).content.count('open'), 2
|
186
|
+
)
|
187
|
+
self.assertEqual(
|
188
|
+
v.to_html(collapse_level=None).content.count('open'), 4
|
189
|
+
)
|
86
190
|
|
87
191
|
|
88
192
|
if __name__ == '__main__':
|
@@ -15,14 +15,12 @@
|
|
15
15
|
|
16
16
|
import copy
|
17
17
|
import inspect
|
18
|
-
|
19
18
|
from typing import Any, List, Optional, Type
|
20
19
|
|
21
|
-
from pyglove.core import object_utils
|
22
20
|
from pyglove.core import typing as pg_typing
|
21
|
+
from pyglove.core import utils
|
23
22
|
from pyglove.core.symbolic import flags
|
24
23
|
from pyglove.core.symbolic import object as pg_object
|
25
|
-
from pyglove.core.symbolic import schema_utils
|
26
24
|
|
27
25
|
|
28
26
|
def boilerplate_class(
|
@@ -130,9 +128,9 @@ def boilerplate_class(
|
|
130
128
|
cls.auto_register = True
|
131
129
|
|
132
130
|
allow_partial = value.allow_partial
|
133
|
-
def _freeze_field(
|
134
|
-
|
135
|
-
|
131
|
+
def _freeze_field(
|
132
|
+
path: utils.KeyPath, field: pg_typing.Field, value: Any
|
133
|
+
) -> Any:
|
136
134
|
# We do not do validation since Object is already in valid form.
|
137
135
|
del path
|
138
136
|
if not isinstance(field.key, pg_typing.ListKey):
|
@@ -152,18 +150,14 @@ def boilerplate_class(
|
|
152
150
|
return value
|
153
151
|
|
154
152
|
# NOTE(daiyip): we call `cls.__schema__.apply` to freeze fields that have
|
155
|
-
# default values.
|
156
|
-
# it's copied from the boilerplate object's class which was already
|
157
|
-
# formalized.
|
153
|
+
# default values.
|
158
154
|
with flags.allow_writable_accessors():
|
159
155
|
cls.__schema__.apply(
|
160
156
|
value._sym_attributes, # pylint: disable=protected-access
|
161
157
|
allow_partial=allow_partial,
|
162
158
|
child_transform=_freeze_field,
|
163
159
|
)
|
164
|
-
|
165
|
-
|
166
|
-
schema_utils.validate_init_arg_list(init_arg_list, cls.__schema__)
|
167
|
-
cls.__schema__.metadata['init_arg_list'] = init_arg_list
|
160
|
+
cls.__schema__.metadata['init_arg_list'] = init_arg_list
|
161
|
+
cls.apply_schema(cls.__schema__)
|
168
162
|
cls.register_for_deserialization(serialization_key, additional_keys)
|
169
163
|
return cls
|
@@ -70,35 +70,37 @@ class BoilerplateClassTest(unittest.TestCase):
|
|
70
70
|
pg_boilerplate_class('A', template_object, init_arg_list=['x', 'y'])
|
71
71
|
|
72
72
|
def test_init_arg_list(self):
|
73
|
-
self.assertEqual(B.init_arg_list, ['a'
|
73
|
+
self.assertEqual(B.init_arg_list, ['a'])
|
74
74
|
self.assertEqual(C.init_arg_list, ['a', 'c', 'b'])
|
75
75
|
|
76
76
|
def test_schema(self):
|
77
77
|
# Boilerplate class' schema should carry the default value and be frozen.
|
78
78
|
self.assertEqual(
|
79
|
-
B.__schema__,
|
80
|
-
pg_typing.create_schema(
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
),
|
88
|
-
(
|
89
|
-
'c',
|
90
|
-
pg_typing.Dict([(
|
91
|
-
'd',
|
92
|
-
pg_typing.List(
|
93
|
-
pg_typing.Dict([
|
94
|
-
('e', pg_typing.Float()),
|
95
|
-
('f', pg_typing.Bool()),
|
96
|
-
]),
|
97
|
-
default=List([Dict(e=1.0, f=True)]),
|
79
|
+
list(B.__schema__.fields.values()),
|
80
|
+
list(pg_typing.create_schema(
|
81
|
+
[
|
82
|
+
('a', pg_typing.Int()),
|
83
|
+
(
|
84
|
+
'b',
|
85
|
+
pg_typing.Union(
|
86
|
+
[pg_typing.Int(), pg_typing.Str()], default='foo'
|
98
87
|
).freeze(),
|
99
|
-
)
|
100
|
-
|
101
|
-
|
88
|
+
),
|
89
|
+
(
|
90
|
+
'c',
|
91
|
+
pg_typing.Dict([(
|
92
|
+
'd',
|
93
|
+
pg_typing.List(
|
94
|
+
pg_typing.Dict([
|
95
|
+
('e', pg_typing.Float()),
|
96
|
+
('f', pg_typing.Bool()),
|
97
|
+
]),
|
98
|
+
default=List([Dict(e=1.0, f=True)]),
|
99
|
+
).freeze(),
|
100
|
+
)]).freeze(),
|
101
|
+
),
|
102
|
+
],
|
103
|
+
).fields.values())
|
102
104
|
)
|
103
105
|
|
104
106
|
# Original class' schema should remain unchanged.
|
@@ -23,22 +23,21 @@ import abc
|
|
23
23
|
import functools
|
24
24
|
import inspect
|
25
25
|
import types
|
26
|
-
from typing import Any, Callable, Dict, List, Optional, Sequence,
|
26
|
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
|
27
27
|
|
28
28
|
from pyglove.core import detouring
|
29
|
-
from pyglove.core import object_utils
|
30
29
|
from pyglove.core import typing as pg_typing
|
30
|
+
from pyglove.core import utils
|
31
31
|
|
32
32
|
from pyglove.core.symbolic import dict as pg_dict # pylint: disable=unused-import
|
33
33
|
from pyglove.core.symbolic import list as pg_list # pylint: disable=unused-import
|
34
34
|
from pyglove.core.symbolic import object as pg_object
|
35
|
-
from pyglove.core.symbolic import schema_utils
|
36
35
|
|
37
36
|
|
38
37
|
class ClassWrapperMeta(pg_object.ObjectMeta):
|
39
38
|
"""Metaclass for class wrapper."""
|
40
39
|
|
41
|
-
def __repr__(self) ->
|
40
|
+
def __repr__(self) -> str:
|
42
41
|
wrapped_cls = getattr(self, 'sym_wrapped_cls', None)
|
43
42
|
if wrapped_cls is None:
|
44
43
|
return f'<class {self.__type_name__!r}>'
|
@@ -46,11 +45,7 @@ class ClassWrapperMeta(pg_object.ObjectMeta):
|
|
46
45
|
|
47
46
|
def __getattr__(self, name):
|
48
47
|
"""Pass through attribute requests to sym_wrapped_cls."""
|
49
|
-
|
50
|
-
return super().__getattr__(name)
|
51
|
-
except AttributeError:
|
52
|
-
wrapped_cls = object.__getattribute__(self, 'sym_wrapped_cls')
|
53
|
-
return getattr(wrapped_cls, name)
|
48
|
+
return getattr(object.__getattribute__(self, 'sym_wrapped_cls'), name)
|
54
49
|
|
55
50
|
|
56
51
|
class ClassWrapper(pg_object.Object, metaclass=ClassWrapperMeta):
|
@@ -76,7 +71,7 @@ class _SubclassedWrapperBase(ClassWrapper):
|
|
76
71
|
# the `__init__` method.
|
77
72
|
auto_typing = False
|
78
73
|
|
79
|
-
@
|
74
|
+
@utils.explicit_method_override
|
80
75
|
def __init__(self, *args, **kwargs):
|
81
76
|
"""Overridden __init__ to construct symbolic wrapper only."""
|
82
77
|
# NOTE(daiyip): We avoid `__init__` to be called multiple times.
|
@@ -105,7 +100,7 @@ class _SubclassedWrapperBase(ClassWrapper):
|
|
105
100
|
def __init_subclass__(cls):
|
106
101
|
# Class wrappers inherit `__init__` from the user class. Therefore, we mark
|
107
102
|
# all of them as explicitly overridden.
|
108
|
-
|
103
|
+
utils.explicit_method_override(cls.__init__)
|
109
104
|
|
110
105
|
super().__init_subclass__()
|
111
106
|
if cls.__init__ is _SubclassedWrapperBase.__init__:
|
@@ -131,10 +126,10 @@ class _SubclassedWrapperBase(ClassWrapper):
|
|
131
126
|
# In both cases, we need to generate an __init__ wrapper for
|
132
127
|
# calling the symbolic initialization.
|
133
128
|
setattr(cls, '__orig_init__', cls.__init__)
|
134
|
-
|
129
|
+
init_arg_list, arg_fields = _extract_init_signature(
|
135
130
|
cls, auto_doc=cls.auto_doc, auto_typing=cls.auto_typing)
|
136
131
|
|
137
|
-
@
|
132
|
+
@utils.explicit_method_override
|
138
133
|
@functools.wraps(cls.__init__)
|
139
134
|
def _sym_init(self, *args, **kwargs):
|
140
135
|
_SubclassedWrapperBase.__init__(self, *args, **kwargs)
|
@@ -152,9 +147,7 @@ class _SubclassedWrapperBase(ClassWrapper):
|
|
152
147
|
|
153
148
|
# We do not extend existing schema which is inherited from the base
|
154
149
|
# class.
|
155
|
-
|
156
|
-
cls, arg_fields, init_arg_list=init_arg_list,
|
157
|
-
extend=False, description=description)
|
150
|
+
cls.update_schema(arg_fields, init_arg_list=init_arg_list, extend=False)
|
158
151
|
else:
|
159
152
|
assert hasattr(cls, '__orig_init__')
|
160
153
|
|
@@ -228,8 +221,8 @@ def _subclassed_wrapper(
|
|
228
221
|
use_auto_doc: bool,
|
229
222
|
use_auto_typing: bool,
|
230
223
|
reset_state_fn: Optional[Callable[[Any], None]],
|
231
|
-
class_name: Optional[
|
232
|
-
module_name: Optional[
|
224
|
+
class_name: Optional[str] = None,
|
225
|
+
module_name: Optional[str] = None):
|
233
226
|
"""Class wrapper implementation by regular multi-inheritance."""
|
234
227
|
# NOTE(daiyip): The user class may have a user-defined metaclass, which
|
235
228
|
# conflicts with the metaclass of the symbolic base. Therefore, we detect
|
@@ -312,10 +305,11 @@ def _subclassed_wrapper(
|
|
312
305
|
|
313
306
|
def wrap(
|
314
307
|
cls,
|
315
|
-
init_args:
|
316
|
-
|
317
|
-
|
318
|
-
|
308
|
+
init_args: Union[
|
309
|
+
List[Union[pg_typing.Field, pg_typing.FieldDef]],
|
310
|
+
Dict[pg_typing.FieldKeyDef, pg_typing.FieldValueDef],
|
311
|
+
None
|
312
|
+
] = None,
|
319
313
|
*,
|
320
314
|
reset_state_fn: Optional[Callable[[Any], None]] = None,
|
321
315
|
repr: bool = True, # pylint: disable=redefined-builtin
|
@@ -429,16 +423,10 @@ def wrap(
|
|
429
423
|
if issubclass(cls, ClassWrapper):
|
430
424
|
# Update init argument specifications according to user specified specs.
|
431
425
|
# Replace schema instead of extending it.
|
432
|
-
|
426
|
+
init_arg_list, arg_fields = _extract_init_signature(
|
433
427
|
cls, init_args, auto_doc=auto_doc, auto_typing=auto_typing)
|
434
|
-
|
435
|
-
|
436
|
-
arg_fields,
|
437
|
-
init_arg_list=init_arg_list,
|
438
|
-
extend=False,
|
439
|
-
description=description,
|
440
|
-
serialization_key=serialization_key,
|
441
|
-
additional_keys=additional_keys)
|
428
|
+
cls.update_schema(arg_fields, init_arg_list=init_arg_list, extend=False)
|
429
|
+
cls.register_for_deserialization(serialization_key, additional_keys)
|
442
430
|
|
443
431
|
if override:
|
444
432
|
for k, v in override.items():
|
@@ -448,7 +436,7 @@ def wrap(
|
|
448
436
|
|
449
437
|
def wrap_module(
|
450
438
|
module,
|
451
|
-
names: Optional[Sequence[
|
439
|
+
names: Optional[Sequence[str]] = None,
|
452
440
|
where: Optional[Callable[[Type['ClassWrapper']], bool]] = None,
|
453
441
|
export_to: Optional[types.ModuleType] = None,
|
454
442
|
**kwargs):
|
@@ -534,7 +522,7 @@ def apply_wrappers(
|
|
534
522
|
"""
|
535
523
|
if not wrapper_classes:
|
536
524
|
wrapper_classes = []
|
537
|
-
for _, c in
|
525
|
+
for _, c in utils.JSONConvertible.registered_types():
|
538
526
|
if (issubclass(c, ClassWrapper)
|
539
527
|
and c not in (ClassWrapper, _SubclassedWrapperBase)
|
540
528
|
and (not where or where(c))
|
@@ -547,23 +535,30 @@ def _extract_init_signature(
|
|
547
535
|
cls,
|
548
536
|
arg_specs=None,
|
549
537
|
auto_doc: bool = False,
|
550
|
-
auto_typing: bool = False
|
538
|
+
auto_typing: bool = False
|
539
|
+
) -> Tuple[List[str], List[pg_typing.Field]]:
|
551
540
|
"""Extract argument fields from class __init__ method."""
|
552
541
|
init_method = getattr(cls, '__orig_init__', cls.__init__)
|
553
|
-
|
554
|
-
description = None
|
555
|
-
args_docstr = dict()
|
542
|
+
docstr = None
|
556
543
|
if auto_doc:
|
557
544
|
# Read args docstr from both class doc string and __init__ doc string.
|
545
|
+
args_docstr = dict()
|
558
546
|
if cls.__doc__:
|
559
|
-
cls_docstr =
|
560
|
-
description = schema_utils.schema_description_from_docstr(
|
561
|
-
cls_docstr, include_long_description=True)
|
547
|
+
cls_docstr = utils.DocStr.parse(cls.__doc__)
|
562
548
|
args_docstr = cls_docstr.args
|
563
549
|
if init_method.__doc__:
|
564
|
-
init_docstr =
|
550
|
+
init_docstr = utils.DocStr.parse(init_method.__doc__)
|
565
551
|
args_docstr.update(init_docstr.args)
|
566
|
-
|
552
|
+
docstr = utils.DocStr(
|
553
|
+
utils.DocStrStyle.GOOGLE,
|
554
|
+
short_description=None,
|
555
|
+
long_description=None,
|
556
|
+
examples=[],
|
557
|
+
args=args_docstr,
|
558
|
+
returns=None,
|
559
|
+
raises=[],
|
560
|
+
blank_after_short_description=True,
|
561
|
+
)
|
567
562
|
if init_method is object.__init__:
|
568
563
|
if arg_specs:
|
569
564
|
raise ValueError(
|
@@ -572,13 +567,21 @@ def _extract_init_signature(
|
|
572
567
|
init_arg_list = []
|
573
568
|
arg_fields = []
|
574
569
|
else:
|
575
|
-
signature = pg_typing.
|
570
|
+
signature = pg_typing.Signature.from_signature(
|
571
|
+
inspect.signature(init_method),
|
572
|
+
name=cls.__name__,
|
573
|
+
callable_type=pg_typing.CallableType.METHOD,
|
574
|
+
module_name=cls.__module__,
|
575
|
+
qualname=cls.__qualname__,
|
576
|
+
auto_typing=auto_typing,
|
577
|
+
docstr=docstr,
|
578
|
+
).annotate(arg_specs)
|
576
579
|
if not signature.args or signature.args[0].name != 'self':
|
577
580
|
raise ValueError(
|
578
581
|
f'{cls.__name__}.__init__ must have `self` as the first argument.')
|
579
582
|
# Remove field for 'self'.
|
580
|
-
arg_fields =
|
583
|
+
arg_fields = signature.fields(remove_self=True)
|
581
584
|
init_arg_list = [arg.name for arg in signature.args[1:]]
|
582
585
|
if signature.varargs is not None:
|
583
586
|
init_arg_list.append(f'*{signature.varargs.name}')
|
584
|
-
return (
|
587
|
+
return (init_arg_list, arg_fields)
|
@@ -588,7 +588,7 @@ class ClassWrapperTest(unittest.TestCase):
|
|
588
588
|
self.assertIsInstance(C(1, 2), ClassWrapper)
|
589
589
|
self.assertTrue(pg_eq(C(1, 2), C(1, 2)))
|
590
590
|
self.assertEqual(list(C.__schema__.fields.keys()), ['x', 'y'])
|
591
|
-
self.assertEqual(repr(C), f'<class {C.
|
591
|
+
self.assertEqual(repr(C), f'<class {C.__type_name__!r}>')
|
592
592
|
|
593
593
|
def test_custom_metaclass(self):
|
594
594
|
|
@@ -604,7 +604,7 @@ class ClassWrapperTest(unittest.TestCase):
|
|
604
604
|
A1 = pg_wrap(A) # pylint: disable=invalid-name
|
605
605
|
self.assertTrue(issubclass(A1, ClassWrapper))
|
606
606
|
self.assertTrue(issubclass(A1, A))
|
607
|
-
self.assertEqual(A1.
|
607
|
+
self.assertEqual(A1.__type_name__, 'pyglove.core.symbolic.class_wrapper_test.A')
|
608
608
|
self.assertEqual(A1.__schema__, pg_typing.Schema([]))
|
609
609
|
self.assertEqual(A1.foo, 'foo')
|
610
610
|
self.assertRegex(repr(A1), r'Symbolic\[.*\]')
|
@@ -17,10 +17,9 @@ import abc
|
|
17
17
|
import inspect
|
18
18
|
import sys
|
19
19
|
import types
|
20
|
-
from typing import Any, List, Optional, Tuple, Type, Union
|
20
|
+
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
21
21
|
|
22
|
-
from pyglove.core import
|
23
|
-
from pyglove.core.symbolic import schema_utils
|
22
|
+
from pyglove.core import utils
|
24
23
|
from pyglove.core.symbolic.base import Symbolic
|
25
24
|
from pyglove.core.symbolic.object import Object
|
26
25
|
import pyglove.core.typing as pg_typing
|
@@ -40,7 +39,7 @@ class Compound(Object):
|
|
40
39
|
# from the user class to compound with.
|
41
40
|
Object.__init_subclass__(cls)
|
42
41
|
|
43
|
-
@
|
42
|
+
@utils.explicit_method_override
|
44
43
|
def __init__(self, *args, **kwargs):
|
45
44
|
# `explicit_init` allows the `__init__` of the other classes that sit after
|
46
45
|
# `Compound` to be bypassed.
|
@@ -53,16 +52,11 @@ _COMPOUND_OWNED_ATTR_NAMES = frozenset(dir(Compound))
|
|
53
52
|
def compound_class(
|
54
53
|
factory_fn: types.FunctionType,
|
55
54
|
base_class: Optional[Type[Object]] = None,
|
56
|
-
args:
|
57
|
-
List[
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
Tuple[str, pg_typing.KeySpec], pg_typing.ValueSpec, str, Any
|
62
|
-
],
|
63
|
-
]
|
64
|
-
]
|
65
|
-
] = None, # pylint: disable=bad-continuation
|
55
|
+
args: Union[
|
56
|
+
List[Union[pg_typing.Field, pg_typing.FieldDef]],
|
57
|
+
Dict[pg_typing.FieldKeyDef, pg_typing.FieldValueDef],
|
58
|
+
None
|
59
|
+
] = None,
|
66
60
|
*,
|
67
61
|
lazy_build: bool = True,
|
68
62
|
auto_doc: bool = True,
|
@@ -127,7 +121,7 @@ def compound_class(
|
|
127
121
|
if not inspect.isfunction(factory_fn):
|
128
122
|
raise TypeError('Decorator `compound` is only applicable to functions.')
|
129
123
|
|
130
|
-
schema =
|
124
|
+
schema = pg_typing.schema(
|
131
125
|
factory_fn,
|
132
126
|
args=args,
|
133
127
|
returns=pg_typing.Object(base_class) if base_class else None,
|
@@ -11,15 +11,13 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for pyglove.compounding."""
|
15
|
-
|
16
14
|
import abc
|
17
15
|
import dataclasses
|
18
16
|
import sys
|
19
17
|
import unittest
|
20
18
|
|
21
|
-
from pyglove.core import object_utils
|
22
19
|
from pyglove.core import typing as pg_typing
|
20
|
+
from pyglove.core import utils
|
23
21
|
from pyglove.core.symbolic.compounding import compound as pg_compound
|
24
22
|
from pyglove.core.symbolic.compounding import compound_class as pg_compound_class
|
25
23
|
from pyglove.core.symbolic.dict import Dict
|
@@ -145,7 +143,7 @@ class UserClassTest(unittest.TestCase):
|
|
145
143
|
class A(Object):
|
146
144
|
x: int
|
147
145
|
|
148
|
-
@
|
146
|
+
@utils.explicit_method_override
|
149
147
|
def __init__(self, x):
|
150
148
|
super().__init__(x=x)
|
151
149
|
assert type(self) is A # pylint: disable=unidiomatic-typecheck
|