pyglove 0.4.5.dev202501050808__py3-none-any.whl → 0.4.5.dev202501060809__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyglove/core/__init__.py +24 -21
- pyglove/core/geno/base.py +53 -38
- pyglove/core/geno/base_test.py +2 -4
- pyglove/core/geno/categorical.py +36 -27
- pyglove/core/geno/custom.py +18 -15
- pyglove/core/geno/numerical.py +19 -16
- pyglove/core/geno/space.py +3 -4
- pyglove/core/hyper/base.py +6 -6
- pyglove/core/hyper/categorical.py +91 -52
- pyglove/core/hyper/custom.py +7 -7
- pyglove/core/hyper/custom_test.py +9 -10
- pyglove/core/hyper/derived.py +30 -22
- pyglove/core/hyper/derived_test.py +2 -4
- pyglove/core/hyper/dynamic_evaluation.py +3 -4
- pyglove/core/hyper/evolvable.py +57 -46
- pyglove/core/hyper/numerical.py +48 -24
- pyglove/core/hyper/numerical_test.py +9 -9
- pyglove/core/hyper/object_template.py +58 -46
- pyglove/core/logging_test.py +0 -2
- pyglove/core/patching/object_factory.py +4 -4
- pyglove/core/patching/pattern_based.py +4 -4
- pyglove/core/patching/rule_based.py +4 -3
- pyglove/core/symbolic/base.py +167 -131
- pyglove/core/symbolic/base_test.py +17 -19
- pyglove/core/symbolic/boilerplate.py +4 -5
- pyglove/core/symbolic/class_wrapper.py +9 -9
- pyglove/core/symbolic/compounding.py +2 -2
- pyglove/core/symbolic/compounding_test.py +2 -4
- pyglove/core/symbolic/dict.py +70 -54
- pyglove/core/symbolic/dict_test.py +117 -100
- pyglove/core/symbolic/diff.py +12 -12
- pyglove/core/symbolic/flags.py +1 -1
- pyglove/core/symbolic/functor.py +16 -15
- pyglove/core/symbolic/functor_test.py +2 -4
- pyglove/core/symbolic/inferred.py +2 -2
- pyglove/core/symbolic/list.py +70 -47
- pyglove/core/symbolic/list_test.py +117 -98
- pyglove/core/symbolic/object.py +42 -40
- pyglove/core/symbolic/object_test.py +95 -88
- pyglove/core/symbolic/origin.py +5 -7
- pyglove/core/symbolic/pure_symbolic.py +4 -3
- pyglove/core/symbolic/ref.py +12 -8
- pyglove/core/tuning/local_backend.py +2 -2
- pyglove/core/tuning/protocols.py +3 -3
- pyglove/core/typing/annotation_conversion.py +3 -3
- pyglove/core/typing/callable_ext.py +11 -13
- pyglove/core/typing/callable_signature.py +19 -18
- pyglove/core/typing/callable_signature_test.py +3 -5
- pyglove/core/typing/class_schema.py +48 -44
- pyglove/core/typing/class_schema_test.py +3 -5
- pyglove/core/typing/custom_typing.py +5 -4
- pyglove/core/typing/key_specs.py +5 -7
- pyglove/core/typing/key_specs_test.py +4 -4
- pyglove/core/typing/type_conversion.py +4 -5
- pyglove/core/typing/type_conversion_test.py +12 -12
- pyglove/core/typing/typed_missing.py +6 -7
- pyglove/core/typing/typed_missing_test.py +7 -8
- pyglove/core/typing/value_specs.py +210 -141
- pyglove/core/typing/value_specs_test.py +12 -13
- pyglove/core/utils/__init__.py +159 -0
- pyglove/core/{object_utils → utils}/common_traits_test.py +1 -3
- pyglove/core/{object_utils → utils}/docstr_utils_test.py +1 -3
- pyglove/core/{object_utils → utils}/error_utils.py +3 -3
- pyglove/core/{object_utils → utils}/error_utils_test.py +1 -1
- pyglove/core/{object_utils → utils}/formatting.py +1 -1
- pyglove/core/{object_utils → utils}/formatting_test.py +1 -2
- pyglove/core/{object_utils → utils}/hierarchical.py +23 -25
- pyglove/core/{object_utils → utils}/hierarchical_test.py +3 -5
- pyglove/core/{object_utils → utils}/json_conversion_test.py +1 -3
- pyglove/core/{object_utils → utils}/missing.py +2 -2
- pyglove/core/{object_utils → utils}/missing_test.py +2 -4
- pyglove/core/{object_utils → utils}/thread_local_test.py +1 -3
- pyglove/core/{object_utils → utils}/timing.py +3 -3
- pyglove/core/{object_utils → utils}/timing_test.py +2 -3
- pyglove/core/{object_utils → utils}/value_location.py +2 -2
- pyglove/core/{object_utils → utils}/value_location_test.py +2 -4
- pyglove/core/views/base.py +25 -29
- pyglove/core/views/html/base.py +14 -15
- pyglove/core/views/html/controls/base.py +5 -5
- pyglove/core/views/html/controls/progress_bar.py +3 -5
- pyglove/core/views/html/tree_view.py +37 -35
- {pyglove-0.4.5.dev202501050808.dist-info → pyglove-0.4.5.dev202501060809.dist-info}/METADATA +1 -1
- {pyglove-0.4.5.dev202501050808.dist-info → pyglove-0.4.5.dev202501060809.dist-info}/RECORD +90 -90
- {pyglove-0.4.5.dev202501050808.dist-info → pyglove-0.4.5.dev202501060809.dist-info}/WHEEL +1 -1
- pyglove/core/object_utils/__init__.py +0 -161
- /pyglove/core/{object_utils → utils}/common_traits.py +0 -0
- /pyglove/core/{object_utils → utils}/docstr_utils.py +0 -0
- /pyglove/core/{object_utils → utils}/json_conversion.py +0 -0
- /pyglove/core/{object_utils → utils}/thread_local.py +0 -0
- {pyglove-0.4.5.dev202501050808.dist-info → pyglove-0.4.5.dev202501060809.dist-info}/LICENSE +0 -0
- {pyglove-0.4.5.dev202501050808.dist-info → pyglove-0.4.5.dev202501060809.dist-info}/top_level.txt +0 -0
pyglove/core/hyper/evolvable.py
CHANGED
@@ -20,9 +20,9 @@ import types
|
|
20
20
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
21
21
|
|
22
22
|
from pyglove.core import geno
|
23
|
-
from pyglove.core import object_utils
|
24
23
|
from pyglove.core import symbolic
|
25
24
|
from pyglove.core import typing as pg_typing
|
25
|
+
from pyglove.core import utils
|
26
26
|
from pyglove.core.hyper import custom
|
27
27
|
|
28
28
|
|
@@ -44,7 +44,7 @@ class MutationPoint:
|
|
44
44
|
parent: The parent node of the mutation point.
|
45
45
|
"""
|
46
46
|
mutation_type: 'MutationType'
|
47
|
-
location:
|
47
|
+
location: utils.KeyPath
|
48
48
|
old_value: Any
|
49
49
|
parent: Optional[symbolic.Symbolic]
|
50
50
|
|
@@ -71,9 +71,9 @@ class Evolvable(custom.CustomHyper):
|
|
71
71
|
mutation_points: List[MutationPoint] = []
|
72
72
|
mutation_weights: List[float] = []
|
73
73
|
|
74
|
-
def _choose_mutation_point(
|
75
|
-
|
76
|
-
|
74
|
+
def _choose_mutation_point(
|
75
|
+
k: utils.KeyPath, v: Any, p: Optional[symbolic.Symbolic]
|
76
|
+
):
|
77
77
|
"""Visiting function for a symbolic node."""
|
78
78
|
def _add_point(mt: MutationType, k=k, v=v, p=p):
|
79
79
|
mutation_points.append(MutationPoint(mt, k, v, p))
|
@@ -98,10 +98,9 @@ class Evolvable(custom.CustomHyper):
|
|
98
98
|
reached_min_size = False
|
99
99
|
|
100
100
|
for i, cv in enumerate(v):
|
101
|
-
ck =
|
101
|
+
ck = utils.KeyPath(i, parent=k)
|
102
102
|
if not reached_max_size:
|
103
|
-
_add_point(MutationType.INSERT,
|
104
|
-
k=ck, v=object_utils.MISSING_VALUE, p=v)
|
103
|
+
_add_point(MutationType.INSERT, k=ck, v=utils.MISSING_VALUE, p=v)
|
105
104
|
|
106
105
|
if not reached_min_size:
|
107
106
|
_add_point(MutationType.DELETE, k=ck, v=cv, p=v)
|
@@ -109,10 +108,12 @@ class Evolvable(custom.CustomHyper):
|
|
109
108
|
# Replace type and value will be added in traverse.
|
110
109
|
symbolic.traverse(cv, _choose_mutation_point, root_path=ck, parent=v)
|
111
110
|
if not reached_max_size and i == len(v) - 1:
|
112
|
-
_add_point(
|
113
|
-
|
114
|
-
|
115
|
-
|
111
|
+
_add_point(
|
112
|
+
MutationType.INSERT,
|
113
|
+
k=utils.KeyPath(i + 1, parent=k),
|
114
|
+
v=utils.MISSING_VALUE,
|
115
|
+
p=v,
|
116
|
+
)
|
116
117
|
return symbolic.TraverseAction.CONTINUE
|
117
118
|
return symbolic.TraverseAction.ENTER
|
118
119
|
|
@@ -157,7 +158,7 @@ class Evolvable(custom.CustomHyper):
|
|
157
158
|
point.location, point.old_value, point.parent)
|
158
159
|
elif point.mutation_type == MutationType.INSERT:
|
159
160
|
assert isinstance(point.parent, symbolic.List), point
|
160
|
-
assert point.old_value ==
|
161
|
+
assert point.old_value == utils.MISSING_VALUE, point
|
161
162
|
assert isinstance(point.location.key, int), point
|
162
163
|
with symbolic.allow_writable_accessors():
|
163
164
|
point.parent.insert(
|
@@ -175,24 +176,31 @@ class Evolvable(custom.CustomHyper):
|
|
175
176
|
# We defer members declaration for Evolvable since the weights will reference
|
176
177
|
# the definition of MutationType.
|
177
178
|
symbolic.members([
|
178
|
-
(
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
(
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
179
|
+
(
|
180
|
+
'initial_value',
|
181
|
+
pg_typing.Object(symbolic.Symbolic),
|
182
|
+
'Symbolic value to involve.',
|
183
|
+
),
|
184
|
+
('node_transform', pg_typing.Callable([], returns=pg_typing.Any()), ''),
|
185
|
+
(
|
186
|
+
'weights',
|
187
|
+
pg_typing.Callable(
|
188
|
+
[
|
189
|
+
pg_typing.Object(MutationType),
|
190
|
+
pg_typing.Object(utils.KeyPath),
|
191
|
+
pg_typing.Any().noneable(),
|
192
|
+
pg_typing.Object(symbolic.Symbolic),
|
193
|
+
],
|
194
|
+
returns=pg_typing.Float(min_value=0.0),
|
195
|
+
).noneable(),
|
196
|
+
(
|
197
|
+
'An optional callable object that returns the unnormalized (e.g.'
|
198
|
+
' the sum of all probabilities do not have to sum to 1.0) mutation'
|
199
|
+
' probabilities for all the nodes in the symbolic tree, based on'
|
200
|
+
' (mutation type, location, old value, parent node). If None, all'
|
201
|
+
' the locations and mutation types will be sampled uniformly.'
|
202
|
+
),
|
203
|
+
),
|
196
204
|
])(Evolvable)
|
197
205
|
|
198
206
|
|
@@ -200,25 +208,28 @@ def evolve(
|
|
200
208
|
initial_value: symbolic.Symbolic,
|
201
209
|
node_transform: Callable[
|
202
210
|
[
|
203
|
-
|
204
|
-
Any,
|
205
|
-
|
206
|
-
symbolic.Symbolic,
|
211
|
+
utils.KeyPath, # Location.
|
212
|
+
Any, # Old value.
|
213
|
+
# pg.MISSING_VALUE for insertion.
|
214
|
+
symbolic.Symbolic, # Parent node.
|
207
215
|
],
|
208
|
-
Any
|
216
|
+
Any, # Replacement.
|
209
217
|
],
|
210
218
|
*,
|
211
|
-
weights: Optional[
|
212
|
-
[
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
219
|
+
weights: Optional[
|
220
|
+
Callable[
|
221
|
+
[
|
222
|
+
MutationType, # Mutation type.
|
223
|
+
utils.KeyPath, # Location.
|
224
|
+
Any, # Value.
|
225
|
+
symbolic.Symbolic, # Parent.
|
226
|
+
],
|
227
|
+
float, # Mutation weight.
|
228
|
+
]
|
229
|
+
] = None, # pylint: disable=bad-whitespace
|
220
230
|
name: Optional[str] = None,
|
221
|
-
hints: Optional[Any] = None
|
231
|
+
hints: Optional[Any] = None
|
232
|
+
) -> Evolvable:
|
222
233
|
"""An evolvable symbolic value.
|
223
234
|
|
224
235
|
Example::
|
pyglove/core/hyper/numerical.py
CHANGED
@@ -17,9 +17,9 @@ import typing
|
|
17
17
|
from typing import Any, Callable, Optional, Tuple
|
18
18
|
|
19
19
|
from pyglove.core import geno
|
20
|
-
from pyglove.core import object_utils
|
21
20
|
from pyglove.core import symbolic
|
22
21
|
from pyglove.core import typing as pg_typing
|
22
|
+
from pyglove.core import utils
|
23
23
|
from pyglove.core.hyper import base
|
24
24
|
|
25
25
|
|
@@ -62,8 +62,7 @@ class Float(base.HyperPrimitive):
|
|
62
62
|
f'\'min_value\' must be positive when `scale` is {self.scale!r}. '
|
63
63
|
f'encountered: {self.min_value}.')
|
64
64
|
|
65
|
-
def dna_spec(self,
|
66
|
-
location: Optional[object_utils.KeyPath] = None) -> geno.Float:
|
65
|
+
def dna_spec(self, location: Optional[utils.KeyPath] = None) -> geno.Float:
|
67
66
|
"""Returns corresponding DNASpec."""
|
68
67
|
return geno.Float(
|
69
68
|
min_value=self.min_value,
|
@@ -71,55 +70,74 @@ class Float(base.HyperPrimitive):
|
|
71
70
|
scale=self.scale,
|
72
71
|
hints=self.hints,
|
73
72
|
name=self.name,
|
74
|
-
location=location or
|
73
|
+
location=location or utils.KeyPath(),
|
74
|
+
)
|
75
75
|
|
76
76
|
def _decode(self) -> float:
|
77
77
|
"""Decode a DNA into a float value."""
|
78
78
|
dna = self._dna
|
79
79
|
if not isinstance(dna.value, float):
|
80
80
|
raise ValueError(
|
81
|
-
|
82
|
-
f'Expect float value. Encountered: {dna.value}.', self.sym_path
|
81
|
+
utils.message_on_path(
|
82
|
+
f'Expect float value. Encountered: {dna.value}.', self.sym_path
|
83
|
+
)
|
84
|
+
)
|
83
85
|
if dna.value < self.min_value:
|
84
86
|
raise ValueError(
|
85
|
-
|
87
|
+
utils.message_on_path(
|
86
88
|
f'DNA value should be no less than {self.min_value}. '
|
87
|
-
f'Encountered {dna.value}.',
|
89
|
+
f'Encountered {dna.value}.',
|
90
|
+
self.sym_path,
|
91
|
+
)
|
92
|
+
)
|
88
93
|
|
89
94
|
if dna.value > self.max_value:
|
90
95
|
raise ValueError(
|
91
|
-
|
96
|
+
utils.message_on_path(
|
92
97
|
f'DNA value should be no greater than {self.max_value}. '
|
93
|
-
f'Encountered {dna.value}.',
|
98
|
+
f'Encountered {dna.value}.',
|
99
|
+
self.sym_path,
|
100
|
+
)
|
101
|
+
)
|
94
102
|
return dna.value
|
95
103
|
|
96
104
|
def encode(self, value: float) -> geno.DNA:
|
97
105
|
"""Encode a float value into a DNA."""
|
98
106
|
if not isinstance(value, float):
|
99
107
|
raise ValueError(
|
100
|
-
|
108
|
+
utils.message_on_path(
|
101
109
|
f'Value should be float to be encoded for {self!r}. '
|
102
|
-
f'Encountered {value}.',
|
110
|
+
f'Encountered {value}.',
|
111
|
+
self.sym_path,
|
112
|
+
)
|
113
|
+
)
|
103
114
|
if value < self.min_value:
|
104
115
|
raise ValueError(
|
105
|
-
|
116
|
+
utils.message_on_path(
|
106
117
|
f'Value should be no less than {self.min_value}. '
|
107
|
-
f'Encountered {value}.',
|
118
|
+
f'Encountered {value}.',
|
119
|
+
self.sym_path,
|
120
|
+
)
|
121
|
+
)
|
108
122
|
if value > self.max_value:
|
109
123
|
raise ValueError(
|
110
|
-
|
124
|
+
utils.message_on_path(
|
111
125
|
f'Value should be no greater than {self.max_value}. '
|
112
|
-
f'Encountered {value}.',
|
126
|
+
f'Encountered {value}.',
|
127
|
+
self.sym_path,
|
128
|
+
)
|
129
|
+
)
|
113
130
|
return geno.DNA(value)
|
114
131
|
|
115
132
|
def custom_apply(
|
116
133
|
self,
|
117
|
-
path:
|
134
|
+
path: utils.KeyPath,
|
118
135
|
value_spec: pg_typing.ValueSpec,
|
119
136
|
allow_partial: bool = False,
|
120
|
-
child_transform: Optional[
|
121
|
-
[
|
122
|
-
|
137
|
+
child_transform: Optional[
|
138
|
+
Callable[[utils.KeyPath, pg_typing.Field, Any], Any]
|
139
|
+
] = None,
|
140
|
+
) -> Tuple[bool, 'Float']:
|
123
141
|
"""Validate candidates during value_spec binding time."""
|
124
142
|
del allow_partial
|
125
143
|
del child_transform
|
@@ -134,17 +152,23 @@ class Float(base.HyperPrimitive):
|
|
134
152
|
if (float_spec.min_value is not None
|
135
153
|
and self.min_value < float_spec.min_value):
|
136
154
|
raise ValueError(
|
137
|
-
|
155
|
+
utils.message_on_path(
|
138
156
|
f'Float.min_value ({self.min_value}) should be no less than '
|
139
157
|
f'the min value ({float_spec.min_value}) of value spec: '
|
140
|
-
f'{float_spec}.',
|
158
|
+
f'{float_spec}.',
|
159
|
+
path,
|
160
|
+
)
|
161
|
+
)
|
141
162
|
if (float_spec.max_value is not None
|
142
163
|
and self.max_value > float_spec.max_value):
|
143
164
|
raise ValueError(
|
144
|
-
|
165
|
+
utils.message_on_path(
|
145
166
|
f'Float.max_value ({self.max_value}) should be no greater than '
|
146
167
|
f'the max value ({float_spec.max_value}) of value spec: '
|
147
|
-
f'{float_spec}.',
|
168
|
+
f'{float_spec}.',
|
169
|
+
path,
|
170
|
+
)
|
171
|
+
)
|
148
172
|
return (False, self)
|
149
173
|
|
150
174
|
def is_leaf(self) -> bool:
|
@@ -11,14 +11,12 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for pyglove.hyper.Float."""
|
15
|
-
|
16
14
|
import unittest
|
17
15
|
|
18
16
|
from pyglove.core import geno
|
19
|
-
from pyglove.core import object_utils
|
20
17
|
from pyglove.core import symbolic
|
21
18
|
from pyglove.core import typing as pg_typing
|
19
|
+
from pyglove.core import utils
|
22
20
|
from pyglove.core.hyper.numerical import Float
|
23
21
|
from pyglove.core.hyper.numerical import floatv
|
24
22
|
|
@@ -44,12 +42,14 @@ class FloatTest(unittest.TestCase):
|
|
44
42
|
floatv(-1.0, 1.0, 'log')
|
45
43
|
|
46
44
|
def test_dna_spec(self):
|
47
|
-
self.assertTrue(
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
45
|
+
self.assertTrue(
|
46
|
+
symbolic.eq(
|
47
|
+
floatv(0.0, 1.0).dna_spec('a'),
|
48
|
+
geno.Float(
|
49
|
+
location=utils.KeyPath('a'), min_value=0.0, max_value=1.0
|
50
|
+
),
|
51
|
+
)
|
52
|
+
)
|
53
53
|
|
54
54
|
def test_decode(self):
|
55
55
|
v = floatv(0.0, 1.0)
|
@@ -16,14 +16,14 @@
|
|
16
16
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
17
17
|
|
18
18
|
from pyglove.core import geno
|
19
|
-
from pyglove.core import object_utils
|
20
19
|
from pyglove.core import symbolic
|
21
20
|
from pyglove.core import typing as pg_typing
|
21
|
+
from pyglove.core import utils
|
22
22
|
from pyglove.core.hyper import base
|
23
23
|
from pyglove.core.hyper import derived
|
24
24
|
|
25
25
|
|
26
|
-
class ObjectTemplate(base.HyperValue,
|
26
|
+
class ObjectTemplate(base.HyperValue, utils.Formattable):
|
27
27
|
"""Object template that encodes and decodes symbolic values.
|
28
28
|
|
29
29
|
An object template can be created from a hyper value, which is a symbolic
|
@@ -131,18 +131,18 @@ class ObjectTemplate(base.HyperValue, object_utils.Formattable):
|
|
131
131
|
"""
|
132
132
|
super().__init__()
|
133
133
|
self._value = value
|
134
|
-
self._root_path =
|
134
|
+
self._root_path = utils.KeyPath()
|
135
135
|
self._compute_derived = compute_derived
|
136
136
|
self._where = where
|
137
137
|
self._parse_generators()
|
138
138
|
|
139
139
|
@property
|
140
|
-
def root_path(self) ->
|
140
|
+
def root_path(self) -> utils.KeyPath:
|
141
141
|
"""Returns root path."""
|
142
142
|
return self._root_path
|
143
143
|
|
144
144
|
@root_path.setter
|
145
|
-
def root_path(self, path:
|
145
|
+
def root_path(self, path: utils.KeyPath):
|
146
146
|
"""Set root path."""
|
147
147
|
self._root_path = path
|
148
148
|
|
@@ -150,7 +150,8 @@ class ObjectTemplate(base.HyperValue, object_utils.Formattable):
|
|
150
150
|
"""Parse generators from its templated value."""
|
151
151
|
hyper_primitives = []
|
152
152
|
def _extract_immediate_child_hyper_primitives(
|
153
|
-
path:
|
153
|
+
path: utils.KeyPath, value: Any
|
154
|
+
) -> bool:
|
154
155
|
"""Extract top-level hyper primitives."""
|
155
156
|
if (isinstance(value, base.HyperValue)
|
156
157
|
and (not self._where or self._where(value))):
|
@@ -162,13 +163,14 @@ class ObjectTemplate(base.HyperValue, object_utils.Formattable):
|
|
162
163
|
hyper_primitives.append((path, value))
|
163
164
|
elif isinstance(value, symbolic.Object):
|
164
165
|
for k, v in value.sym_items():
|
165
|
-
|
166
|
-
v,
|
167
|
-
|
166
|
+
utils.traverse(
|
167
|
+
v,
|
168
|
+
_extract_immediate_child_hyper_primitives,
|
169
|
+
root_path=utils.KeyPath(k, path),
|
170
|
+
)
|
168
171
|
return True
|
169
172
|
|
170
|
-
|
171
|
-
self._value, _extract_immediate_child_hyper_primitives)
|
173
|
+
utils.traverse(self._value, _extract_immediate_child_hyper_primitives)
|
172
174
|
self._hyper_primitives = hyper_primitives
|
173
175
|
|
174
176
|
@property
|
@@ -186,15 +188,15 @@ class ObjectTemplate(base.HyperValue, object_utils.Formattable):
|
|
186
188
|
"""Returns whether current template is constant value."""
|
187
189
|
return not self._hyper_primitives
|
188
190
|
|
189
|
-
def dna_spec(
|
190
|
-
self, location: Optional[object_utils.KeyPath] = None) -> geno.Space:
|
191
|
+
def dna_spec(self, location: Optional[utils.KeyPath] = None) -> geno.Space:
|
191
192
|
"""Return DNA spec (geno.Space) from this template."""
|
192
193
|
return geno.Space(
|
193
194
|
elements=[
|
194
195
|
primitive.dna_spec(primitive_location)
|
195
196
|
for primitive_location, primitive in self._hyper_primitives
|
196
197
|
],
|
197
|
-
location=location or
|
198
|
+
location=location or utils.KeyPath(),
|
199
|
+
)
|
198
200
|
|
199
201
|
def _decode(self) -> Any:
|
200
202
|
"""Decode DNA into a value."""
|
@@ -202,9 +204,10 @@ class ObjectTemplate(base.HyperValue, object_utils.Formattable):
|
|
202
204
|
assert dna is not None
|
203
205
|
if not self._hyper_primitives and (dna.value is not None or dna.children):
|
204
206
|
raise ValueError(
|
205
|
-
|
206
|
-
f'Encountered extra DNA value to decode: {dna!r}',
|
207
|
-
|
207
|
+
utils.message_on_path(
|
208
|
+
f'Encountered extra DNA value to decode: {dna!r}', self._root_path
|
209
|
+
)
|
210
|
+
)
|
208
211
|
|
209
212
|
# Compute hyper primitive values first.
|
210
213
|
rebind_dict = {}
|
@@ -214,11 +217,14 @@ class ObjectTemplate(base.HyperValue, object_utils.Formattable):
|
|
214
217
|
else:
|
215
218
|
if len(dna.children) != len(self._hyper_primitives):
|
216
219
|
raise ValueError(
|
217
|
-
|
220
|
+
utils.message_on_path(
|
218
221
|
f'The length of child values ({len(dna.children)}) is '
|
219
|
-
|
222
|
+
'different from the number of hyper primitives '
|
220
223
|
f'({len(self._hyper_primitives)}) in ObjectTemplate. '
|
221
|
-
f'DNA={dna!r}, ObjectTemplate={self!r}.',
|
224
|
+
f'DNA={dna!r}, ObjectTemplate={self!r}.',
|
225
|
+
self._root_path,
|
226
|
+
)
|
227
|
+
)
|
222
228
|
for i, (primitive_location, primitive) in enumerate(
|
223
229
|
self._hyper_primitives):
|
224
230
|
rebind_dict[primitive_location.path] = (
|
@@ -247,18 +253,18 @@ class ObjectTemplate(base.HyperValue, object_utils.Formattable):
|
|
247
253
|
# TODO(daiyip): Currently derived value parsing is done at decode time,
|
248
254
|
# which can be optimized by moving to template creation time.
|
249
255
|
derived_values = []
|
250
|
-
def _extract_derived_values(
|
251
|
-
path: object_utils.KeyPath, value: Any) -> bool:
|
256
|
+
def _extract_derived_values(path: utils.KeyPath, value: Any) -> bool:
|
252
257
|
"""Extract top-level primitives."""
|
253
258
|
if isinstance(value, derived.DerivedValue):
|
254
259
|
derived_values.append((path, value))
|
255
260
|
elif isinstance(value, symbolic.Object):
|
256
261
|
for k, v in value.sym_items():
|
257
|
-
|
258
|
-
v, _extract_derived_values,
|
259
|
-
|
262
|
+
utils.traverse(
|
263
|
+
v, _extract_derived_values, root_path=utils.KeyPath(k, path)
|
264
|
+
)
|
260
265
|
return True
|
261
|
-
|
266
|
+
|
267
|
+
utils.traverse(value, _extract_derived_values)
|
262
268
|
|
263
269
|
if derived_values:
|
264
270
|
if not copied:
|
@@ -299,9 +305,9 @@ class ObjectTemplate(base.HyperValue, object_utils.Formattable):
|
|
299
305
|
ValueError if value cannot be encoded by this template.
|
300
306
|
"""
|
301
307
|
children = []
|
302
|
-
def _encode(
|
303
|
-
|
304
|
-
|
308
|
+
def _encode(
|
309
|
+
path: utils.KeyPath, template_value: Any, input_value: Any
|
310
|
+
) -> Any:
|
305
311
|
"""Encode input value according to template value."""
|
306
312
|
if (pg_typing.MISSING_VALUE == input_value
|
307
313
|
and pg_typing.MISSING_VALUE != template_value):
|
@@ -339,10 +345,12 @@ class ObjectTemplate(base.HyperValue, object_utils.Formattable):
|
|
339
345
|
f'TemplateOnlyKeys={template_keys - value_keys}, '
|
340
346
|
f'InputOnlyKeys={value_keys - template_keys})')
|
341
347
|
for key in template_value.sym_keys():
|
342
|
-
|
348
|
+
utils.merge_tree(
|
343
349
|
template_value.sym_getattr(key),
|
344
350
|
input_value.sym_getattr(key),
|
345
|
-
_encode,
|
351
|
+
_encode,
|
352
|
+
root_path=utils.KeyPath(key, path),
|
353
|
+
)
|
346
354
|
elif isinstance(template_value, symbolic.Dict):
|
347
355
|
# Do nothing since merge will iterate all elements in dict and list.
|
348
356
|
if not isinstance(input_value, dict):
|
@@ -358,19 +366,23 @@ class ObjectTemplate(base.HyperValue, object_utils.Formattable):
|
|
358
366
|
f'value. (Path=\'{path}\', Template={template_value!r}, '
|
359
367
|
f'Input={input_value!r})')
|
360
368
|
for i, template_item in enumerate(template_value):
|
361
|
-
|
362
|
-
template_item,
|
363
|
-
|
369
|
+
utils.merge_tree(
|
370
|
+
template_item,
|
371
|
+
input_value[i],
|
372
|
+
_encode,
|
373
|
+
root_path=utils.KeyPath(i, path),
|
374
|
+
)
|
364
375
|
else:
|
365
376
|
if template_value != input_value:
|
366
377
|
raise ValueError(
|
367
|
-
|
368
|
-
f
|
369
|
-
f'Template={
|
370
|
-
f'Input={
|
378
|
+
'Unmatched value between template and input. '
|
379
|
+
f"(Path='{path}', "
|
380
|
+
f'Template={utils.quote_if_str(template_value)}, '
|
381
|
+
f'Input={utils.quote_if_str(input_value)})'
|
382
|
+
)
|
371
383
|
return template_value
|
372
|
-
|
373
|
-
|
384
|
+
|
385
|
+
utils.merge_tree(self._value, value, _encode, root_path=self._root_path)
|
374
386
|
return geno.DNA(None, children)
|
375
387
|
|
376
388
|
def try_encode(self, value: Any) -> Tuple[bool, geno.DNA]:
|
@@ -399,18 +411,18 @@ class ObjectTemplate(base.HyperValue, object_utils.Formattable):
|
|
399
411
|
root_indent: int = 0,
|
400
412
|
**kwargs) -> str:
|
401
413
|
"""Format this object."""
|
402
|
-
details =
|
403
|
-
self._value, compact, verbose, root_indent, **kwargs)
|
414
|
+
details = utils.format(self._value, compact, verbose, root_indent, **kwargs)
|
404
415
|
return f'{self.__class__.__name__}(value={details})'
|
405
416
|
|
406
417
|
def custom_apply(
|
407
418
|
self,
|
408
|
-
path:
|
419
|
+
path: utils.KeyPath,
|
409
420
|
value_spec: pg_typing.ValueSpec,
|
410
421
|
allow_partial: bool,
|
411
|
-
child_transform: Optional[
|
412
|
-
[
|
413
|
-
|
422
|
+
child_transform: Optional[
|
423
|
+
Callable[[utils.KeyPath, pg_typing.Field, Any], Any]
|
424
|
+
] = None,
|
425
|
+
) -> Tuple[bool, 'ObjectTemplate']:
|
414
426
|
"""Validate candidates during value_spec binding time."""
|
415
427
|
# Check if value_spec directly accepts `self`.
|
416
428
|
if not value_spec.value_type or not isinstance(self, value_spec.value_type):
|
pyglove/core/logging_test.py
CHANGED
@@ -11,8 +11,6 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for pyglove.object_utils."""
|
15
|
-
|
16
14
|
import io
|
17
15
|
import logging
|
18
16
|
import unittest
|
@@ -14,8 +14,8 @@
|
|
14
14
|
"""Object factory based on patchers."""
|
15
15
|
|
16
16
|
from typing import Any, Callable, Dict, Optional, Type, Union
|
17
|
-
from pyglove.core import object_utils
|
18
17
|
from pyglove.core import symbolic
|
18
|
+
from pyglove.core import utils
|
19
19
|
from pyglove.core.patching import rule_based
|
20
20
|
|
21
21
|
|
@@ -88,7 +88,7 @@ def ObjectFactory( # pylint: disable=invalid-name
|
|
88
88
|
# Step 3: Patch with additional parameter override dict if available.
|
89
89
|
if params_override:
|
90
90
|
value = value.rebind(
|
91
|
-
|
92
|
-
raise_on_no_change=False
|
91
|
+
utils.flatten(from_maybe_serialized(params_override, dict)),
|
92
|
+
raise_on_no_change=False,
|
93
|
+
)
|
93
94
|
return value
|
94
|
-
|
@@ -15,8 +15,8 @@
|
|
15
15
|
|
16
16
|
import re
|
17
17
|
from typing import Any, Callable, Optional, Tuple, Type, Union
|
18
|
-
from pyglove.core import object_utils
|
19
18
|
from pyglove.core import symbolic
|
19
|
+
from pyglove.core import utils
|
20
20
|
|
21
21
|
|
22
22
|
def patch_on_key(
|
@@ -214,11 +214,11 @@ def patch_on_member(
|
|
214
214
|
|
215
215
|
def _conditional_patch(
|
216
216
|
src: symbolic.Symbolic,
|
217
|
-
condition: Callable[
|
218
|
-
[object_utils.KeyPath, Any, symbolic.Symbolic], bool],
|
217
|
+
condition: Callable[[utils.KeyPath, Any, symbolic.Symbolic], bool],
|
219
218
|
value: Any = None,
|
220
219
|
value_fn: Optional[Callable[[Any], Any]] = None,
|
221
|
-
skip_notification: Optional[bool] = None
|
220
|
+
skip_notification: Optional[bool] = None,
|
221
|
+
) -> Any:
|
222
222
|
"""Recursive patch values on condition.
|
223
223
|
|
224
224
|
Args:
|
@@ -16,9 +16,9 @@
|
|
16
16
|
import re
|
17
17
|
import typing
|
18
18
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
19
|
-
from pyglove.core import object_utils
|
20
19
|
from pyglove.core import symbolic
|
21
20
|
from pyglove.core import typing as pg_typing
|
21
|
+
from pyglove.core import utils
|
22
22
|
|
23
23
|
|
24
24
|
class Patcher(symbolic.Functor):
|
@@ -350,7 +350,7 @@ def from_uri(uri: str) -> Patcher:
|
|
350
350
|
name, args, kwargs = parse_uri(uri)
|
351
351
|
patcher_cls = typing.cast(Type[Any], _PATCHER_REGISTRY.get(name))
|
352
352
|
args, kwargs = parse_args(patcher_cls.__signature__, args, kwargs)
|
353
|
-
return patcher_cls(
|
353
|
+
return patcher_cls(utils.MISSING_VALUE, *args, **kwargs)
|
354
354
|
|
355
355
|
|
356
356
|
def parse_uri(uri: str) -> Tuple[str, List[str], Dict[str, str]]:
|
@@ -467,7 +467,8 @@ def parse_arg(patcher_id: str, arg_name: str,
|
|
467
467
|
f'{value_spec!r} cannot be used for Patcher argument.\n'
|
468
468
|
f'Consider to treat this argument as string and parse it yourself.')
|
469
469
|
return value_spec.apply(
|
470
|
-
arg, root_path=
|
470
|
+
arg, root_path=utils.KeyPath.parse(f'{patcher_id}.{arg_name}')
|
471
|
+
)
|
471
472
|
|
472
473
|
|
473
474
|
def parse_list(string: str,
|