pyglove 0.4.5.dev202501050808__py3-none-any.whl → 0.4.5.dev202501060809__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyglove/core/__init__.py +24 -21
- pyglove/core/geno/base.py +53 -38
- pyglove/core/geno/base_test.py +2 -4
- pyglove/core/geno/categorical.py +36 -27
- pyglove/core/geno/custom.py +18 -15
- pyglove/core/geno/numerical.py +19 -16
- pyglove/core/geno/space.py +3 -4
- pyglove/core/hyper/base.py +6 -6
- pyglove/core/hyper/categorical.py +91 -52
- pyglove/core/hyper/custom.py +7 -7
- pyglove/core/hyper/custom_test.py +9 -10
- pyglove/core/hyper/derived.py +30 -22
- pyglove/core/hyper/derived_test.py +2 -4
- pyglove/core/hyper/dynamic_evaluation.py +3 -4
- pyglove/core/hyper/evolvable.py +57 -46
- pyglove/core/hyper/numerical.py +48 -24
- pyglove/core/hyper/numerical_test.py +9 -9
- pyglove/core/hyper/object_template.py +58 -46
- pyglove/core/logging_test.py +0 -2
- pyglove/core/patching/object_factory.py +4 -4
- pyglove/core/patching/pattern_based.py +4 -4
- pyglove/core/patching/rule_based.py +4 -3
- pyglove/core/symbolic/base.py +167 -131
- pyglove/core/symbolic/base_test.py +17 -19
- pyglove/core/symbolic/boilerplate.py +4 -5
- pyglove/core/symbolic/class_wrapper.py +9 -9
- pyglove/core/symbolic/compounding.py +2 -2
- pyglove/core/symbolic/compounding_test.py +2 -4
- pyglove/core/symbolic/dict.py +70 -54
- pyglove/core/symbolic/dict_test.py +117 -100
- pyglove/core/symbolic/diff.py +12 -12
- pyglove/core/symbolic/flags.py +1 -1
- pyglove/core/symbolic/functor.py +16 -15
- pyglove/core/symbolic/functor_test.py +2 -4
- pyglove/core/symbolic/inferred.py +2 -2
- pyglove/core/symbolic/list.py +70 -47
- pyglove/core/symbolic/list_test.py +117 -98
- pyglove/core/symbolic/object.py +42 -40
- pyglove/core/symbolic/object_test.py +95 -88
- pyglove/core/symbolic/origin.py +5 -7
- pyglove/core/symbolic/pure_symbolic.py +4 -3
- pyglove/core/symbolic/ref.py +12 -8
- pyglove/core/tuning/local_backend.py +2 -2
- pyglove/core/tuning/protocols.py +3 -3
- pyglove/core/typing/annotation_conversion.py +3 -3
- pyglove/core/typing/callable_ext.py +11 -13
- pyglove/core/typing/callable_signature.py +19 -18
- pyglove/core/typing/callable_signature_test.py +3 -5
- pyglove/core/typing/class_schema.py +48 -44
- pyglove/core/typing/class_schema_test.py +3 -5
- pyglove/core/typing/custom_typing.py +5 -4
- pyglove/core/typing/key_specs.py +5 -7
- pyglove/core/typing/key_specs_test.py +4 -4
- pyglove/core/typing/type_conversion.py +4 -5
- pyglove/core/typing/type_conversion_test.py +12 -12
- pyglove/core/typing/typed_missing.py +6 -7
- pyglove/core/typing/typed_missing_test.py +7 -8
- pyglove/core/typing/value_specs.py +210 -141
- pyglove/core/typing/value_specs_test.py +12 -13
- pyglove/core/utils/__init__.py +159 -0
- pyglove/core/{object_utils → utils}/common_traits_test.py +1 -3
- pyglove/core/{object_utils → utils}/docstr_utils_test.py +1 -3
- pyglove/core/{object_utils → utils}/error_utils.py +3 -3
- pyglove/core/{object_utils → utils}/error_utils_test.py +1 -1
- pyglove/core/{object_utils → utils}/formatting.py +1 -1
- pyglove/core/{object_utils → utils}/formatting_test.py +1 -2
- pyglove/core/{object_utils → utils}/hierarchical.py +23 -25
- pyglove/core/{object_utils → utils}/hierarchical_test.py +3 -5
- pyglove/core/{object_utils → utils}/json_conversion_test.py +1 -3
- pyglove/core/{object_utils → utils}/missing.py +2 -2
- pyglove/core/{object_utils → utils}/missing_test.py +2 -4
- pyglove/core/{object_utils → utils}/thread_local_test.py +1 -3
- pyglove/core/{object_utils → utils}/timing.py +3 -3
- pyglove/core/{object_utils → utils}/timing_test.py +2 -3
- pyglove/core/{object_utils → utils}/value_location.py +2 -2
- pyglove/core/{object_utils → utils}/value_location_test.py +2 -4
- pyglove/core/views/base.py +25 -29
- pyglove/core/views/html/base.py +14 -15
- pyglove/core/views/html/controls/base.py +5 -5
- pyglove/core/views/html/controls/progress_bar.py +3 -5
- pyglove/core/views/html/tree_view.py +37 -35
- {pyglove-0.4.5.dev202501050808.dist-info → pyglove-0.4.5.dev202501060809.dist-info}/METADATA +1 -1
- {pyglove-0.4.5.dev202501050808.dist-info → pyglove-0.4.5.dev202501060809.dist-info}/RECORD +90 -90
- {pyglove-0.4.5.dev202501050808.dist-info → pyglove-0.4.5.dev202501060809.dist-info}/WHEEL +1 -1
- pyglove/core/object_utils/__init__.py +0 -161
- /pyglove/core/{object_utils → utils}/common_traits.py +0 -0
- /pyglove/core/{object_utils → utils}/docstr_utils.py +0 -0
- /pyglove/core/{object_utils → utils}/json_conversion.py +0 -0
- /pyglove/core/{object_utils → utils}/thread_local.py +0 -0
- {pyglove-0.4.5.dev202501050808.dist-info → pyglove-0.4.5.dev202501060809.dist-info}/LICENSE +0 -0
- {pyglove-0.4.5.dev202501050808.dist-info → pyglove-0.4.5.dev202501060809.dist-info}/top_level.txt +0 -0
@@ -18,9 +18,9 @@ import typing
|
|
18
18
|
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
|
19
19
|
|
20
20
|
from pyglove.core import geno
|
21
|
-
from pyglove.core import object_utils
|
22
21
|
from pyglove.core import symbolic
|
23
22
|
from pyglove.core import typing as pg_typing
|
23
|
+
from pyglove.core import utils
|
24
24
|
from pyglove.core.hyper import base
|
25
25
|
from pyglove.core.hyper import object_template
|
26
26
|
|
@@ -85,7 +85,8 @@ class Choices(base.HyperPrimitive):
|
|
85
85
|
self._value_spec = None
|
86
86
|
|
87
87
|
def _update_children_paths(
|
88
|
-
self, old_path:
|
88
|
+
self, old_path: utils.KeyPath, new_path: utils.KeyPath
|
89
|
+
):
|
89
90
|
"""Customized logic to update children paths."""
|
90
91
|
super()._update_children_paths(old_path, new_path)
|
91
92
|
for t in self._candidate_templates:
|
@@ -104,19 +105,20 @@ class Choices(base.HyperPrimitive):
|
|
104
105
|
return False
|
105
106
|
return True
|
106
107
|
|
107
|
-
def dna_spec(self,
|
108
|
-
location: Optional[object_utils.KeyPath] = None) -> geno.Choices:
|
108
|
+
def dna_spec(self, location: Optional[utils.KeyPath] = None) -> geno.Choices:
|
109
109
|
"""Returns corresponding DNASpec."""
|
110
110
|
return geno.Choices(
|
111
111
|
num_choices=self.num_choices,
|
112
112
|
candidates=[ct.dna_spec() for ct in self._candidate_templates],
|
113
|
-
literal_values=[
|
114
|
-
|
113
|
+
literal_values=[
|
114
|
+
self._literal_value(c) for i, c in enumerate(self.candidates)
|
115
|
+
],
|
115
116
|
distinct=self.choices_distinct,
|
116
117
|
sorted=self.choices_sorted,
|
117
118
|
hints=self.hints,
|
118
119
|
name=self.name,
|
119
|
-
location=location or
|
120
|
+
location=location or utils.KeyPath(),
|
121
|
+
)
|
120
122
|
|
121
123
|
def _literal_value(
|
122
124
|
self, candidate: Any, max_len: int = 120) -> Union[int, float, str]:
|
@@ -124,10 +126,13 @@ class Choices(base.HyperPrimitive):
|
|
124
126
|
if isinstance(candidate, numbers.Number):
|
125
127
|
return candidate
|
126
128
|
|
127
|
-
literal =
|
128
|
-
|
129
|
-
|
130
|
-
|
129
|
+
literal = utils.format(
|
130
|
+
candidate,
|
131
|
+
compact=True,
|
132
|
+
hide_default_values=True,
|
133
|
+
hide_missing_values=True,
|
134
|
+
strip_object_id=True,
|
135
|
+
)
|
131
136
|
if len(literal) > max_len:
|
132
137
|
literal = literal[:max_len - 3] + '...'
|
133
138
|
return literal
|
@@ -139,52 +144,70 @@ class Choices(base.HyperPrimitive):
|
|
139
144
|
# Single choice.
|
140
145
|
if not isinstance(dna.value, int):
|
141
146
|
raise ValueError(
|
142
|
-
|
143
|
-
|
147
|
+
utils.message_on_path(
|
148
|
+
'Did you forget to specify values for conditional choices?\n'
|
144
149
|
f'Expect integer for {self.__class__.__name__}. '
|
145
|
-
f'Encountered: {dna!r}.',
|
150
|
+
f'Encountered: {dna!r}.',
|
151
|
+
self.sym_path,
|
152
|
+
)
|
153
|
+
)
|
146
154
|
if dna.value >= len(self.candidates):
|
147
155
|
raise ValueError(
|
148
|
-
|
156
|
+
utils.message_on_path(
|
149
157
|
f'Choice out of range. Value: {dna.value!r}, '
|
150
|
-
f'Candidates: {len(self.candidates)}.',
|
158
|
+
f'Candidates: {len(self.candidates)}.',
|
159
|
+
self.sym_path,
|
160
|
+
)
|
161
|
+
)
|
151
162
|
choices = [self._candidate_templates[dna.value].decode(
|
152
163
|
geno.DNA(None, dna.children))]
|
153
164
|
else:
|
154
165
|
# Multi choices.
|
155
166
|
if len(dna.children) != self.num_choices:
|
156
167
|
raise ValueError(
|
157
|
-
|
158
|
-
|
168
|
+
utils.message_on_path(
|
169
|
+
'Number of DNA child values does not match the number of '
|
159
170
|
f'choices. Child values: {dna.children!r}, '
|
160
|
-
f'Choices: {self.num_choices}.',
|
171
|
+
f'Choices: {self.num_choices}.',
|
172
|
+
self.sym_path,
|
173
|
+
)
|
174
|
+
)
|
161
175
|
if self.choices_distinct or self.choices_sorted:
|
162
176
|
sub_dna_values = [s.value for s in dna]
|
163
177
|
if (self.choices_distinct
|
164
178
|
and len(set(sub_dna_values)) != len(dna.children)):
|
165
179
|
raise ValueError(
|
166
|
-
|
167
|
-
|
168
|
-
f'Encountered: {sub_dna_values}.',
|
180
|
+
utils.message_on_path(
|
181
|
+
'DNA child values should be distinct. '
|
182
|
+
f'Encountered: {sub_dna_values}.',
|
183
|
+
self.sym_path,
|
184
|
+
)
|
185
|
+
)
|
169
186
|
if self.choices_sorted and sorted(sub_dna_values) != sub_dna_values:
|
170
187
|
raise ValueError(
|
171
|
-
|
172
|
-
|
173
|
-
f'Encountered: {sub_dna_values}.',
|
188
|
+
utils.message_on_path(
|
189
|
+
'DNA child values should be sorted. '
|
190
|
+
f'Encountered: {sub_dna_values}.',
|
191
|
+
self.sym_path,
|
192
|
+
)
|
193
|
+
)
|
174
194
|
choices = []
|
175
195
|
for i, sub_dna in enumerate(dna):
|
176
196
|
if not isinstance(sub_dna.value, int):
|
177
197
|
raise ValueError(
|
178
|
-
|
179
|
-
f'Choice value should be int. '
|
180
|
-
|
181
|
-
|
198
|
+
utils.message_on_path(
|
199
|
+
f'Choice value should be int. Encountered: {sub_dna.value}.',
|
200
|
+
utils.KeyPath(i, self.sym_path),
|
201
|
+
)
|
202
|
+
)
|
182
203
|
if sub_dna.value >= len(self.candidates):
|
183
204
|
raise ValueError(
|
184
|
-
|
205
|
+
utils.message_on_path(
|
185
206
|
f'Choice out of range. Value: {sub_dna.value}, '
|
186
207
|
f'Candidates: {len(self.candidates)}.',
|
187
|
-
|
208
|
+
utils.KeyPath(i, self.sym_path),
|
209
|
+
)
|
210
|
+
)
|
188
211
|
choices.append(self._candidate_templates[sub_dna.value].decode(
|
189
212
|
geno.DNA(None, sub_dna.children)))
|
190
213
|
return choices
|
@@ -240,15 +263,21 @@ class Choices(base.HyperPrimitive):
|
|
240
263
|
"""
|
241
264
|
if not isinstance(value, list):
|
242
265
|
raise ValueError(
|
243
|
-
|
244
|
-
|
245
|
-
f'Encountered: {value!r}.',
|
266
|
+
utils.message_on_path(
|
267
|
+
'Cannot encode value: value should be a list type. '
|
268
|
+
f'Encountered: {value!r}.',
|
269
|
+
self.sym_path,
|
270
|
+
)
|
271
|
+
)
|
246
272
|
choices = []
|
247
273
|
if self.num_choices is not None and len(value) != self.num_choices:
|
248
274
|
raise ValueError(
|
249
|
-
|
250
|
-
|
251
|
-
f'({self.num_choices}). Encountered: {value}.',
|
275
|
+
utils.message_on_path(
|
276
|
+
'Length of input list is different from the number of choices '
|
277
|
+
f'({self.num_choices}). Encountered: {value}.',
|
278
|
+
self.sym_path,
|
279
|
+
)
|
280
|
+
)
|
252
281
|
for v in value:
|
253
282
|
choice_id = None
|
254
283
|
child_dna = None
|
@@ -259,10 +288,12 @@ class Choices(base.HyperPrimitive):
|
|
259
288
|
break
|
260
289
|
if child_dna is None:
|
261
290
|
raise ValueError(
|
262
|
-
|
263
|
-
|
291
|
+
utils.message_on_path(
|
292
|
+
'Cannot encode value: no candidates matches with '
|
264
293
|
f'the value. Value: {v!r}, Candidates: {self.candidates}.',
|
265
|
-
self.sym_path
|
294
|
+
self.sym_path,
|
295
|
+
)
|
296
|
+
)
|
266
297
|
choices.append(geno.DNA(choice_id, [child_dna]))
|
267
298
|
return geno.DNA(None, choices)
|
268
299
|
|
@@ -313,12 +344,13 @@ class ManyOf(Choices):
|
|
313
344
|
|
314
345
|
def custom_apply(
|
315
346
|
self,
|
316
|
-
path:
|
347
|
+
path: utils.KeyPath,
|
317
348
|
value_spec: pg_typing.ValueSpec,
|
318
349
|
allow_partial: bool,
|
319
|
-
child_transform: Optional[
|
320
|
-
[
|
321
|
-
|
350
|
+
child_transform: Optional[
|
351
|
+
Callable[[utils.KeyPath, pg_typing.Field, Any], Any]
|
352
|
+
] = None,
|
353
|
+
) -> Tuple[bool, 'Choices']:
|
322
354
|
"""Validate candidates during value_spec binding time."""
|
323
355
|
# Check if value_spec directly accepts `self`.
|
324
356
|
if value_spec.value_type and isinstance(self, value_spec.value_type):
|
@@ -329,10 +361,12 @@ class ManyOf(Choices):
|
|
329
361
|
dest_spec = value_spec
|
330
362
|
if not dest_spec.is_compatible(src_spec):
|
331
363
|
raise TypeError(
|
332
|
-
|
364
|
+
utils.message_on_path(
|
333
365
|
f'Cannot bind an incompatible value spec {dest_spec!r} '
|
334
366
|
f'to {self.__class__.__name__} with bound spec {src_spec!r}.',
|
335
|
-
path
|
367
|
+
path,
|
368
|
+
)
|
369
|
+
)
|
336
370
|
return (False, self)
|
337
371
|
|
338
372
|
list_spec = typing.cast(
|
@@ -399,12 +433,13 @@ class OneOf(Choices):
|
|
399
433
|
|
400
434
|
def custom_apply(
|
401
435
|
self,
|
402
|
-
path:
|
436
|
+
path: utils.KeyPath,
|
403
437
|
value_spec: pg_typing.ValueSpec,
|
404
438
|
allow_partial: bool,
|
405
|
-
child_transform: Optional[
|
406
|
-
[
|
407
|
-
|
439
|
+
child_transform: Optional[
|
440
|
+
Callable[[utils.KeyPath, pg_typing.Field, Any], Any]
|
441
|
+
] = None,
|
442
|
+
) -> Tuple[bool, 'OneOf']:
|
408
443
|
"""Validate candidates during value_spec binding time."""
|
409
444
|
# Check if value_spec directly accepts `self`.
|
410
445
|
if value_spec.value_type and isinstance(self, value_spec.value_type):
|
@@ -413,10 +448,13 @@ class OneOf(Choices):
|
|
413
448
|
if self._value_spec:
|
414
449
|
if not value_spec.is_compatible(self._value_spec):
|
415
450
|
raise TypeError(
|
416
|
-
|
451
|
+
utils.message_on_path(
|
417
452
|
f'Cannot bind an incompatible value spec {value_spec!r} '
|
418
453
|
f'to {self.__class__.__name__} with bound '
|
419
|
-
f'spec {self._value_spec!r}.',
|
454
|
+
f'spec {self._value_spec!r}.',
|
455
|
+
path,
|
456
|
+
)
|
457
|
+
)
|
420
458
|
return (False, self)
|
421
459
|
|
422
460
|
for i, c in enumerate(self.candidates):
|
@@ -427,6 +465,7 @@ class OneOf(Choices):
|
|
427
465
|
self._value_spec = value_spec
|
428
466
|
return (False, self)
|
429
467
|
|
468
|
+
|
430
469
|
#
|
431
470
|
# Helper methods for creating hyper values.
|
432
471
|
#
|
pyglove/core/hyper/custom.py
CHANGED
@@ -19,8 +19,8 @@ import types
|
|
19
19
|
from typing import Any, Callable, Optional, Tuple, Union
|
20
20
|
|
21
21
|
from pyglove.core import geno
|
22
|
-
from pyglove.core import object_utils
|
23
22
|
from pyglove.core import typing as pg_typing
|
23
|
+
from pyglove.core import utils
|
24
24
|
from pyglove.core.hyper import base
|
25
25
|
|
26
26
|
|
@@ -111,8 +111,7 @@ class CustomHyper(base.HyperPrimitive):
|
|
111
111
|
raise NotImplementedError(
|
112
112
|
f'\'custom_encode\' is not supported by {self.__class__.__name__!r}.')
|
113
113
|
|
114
|
-
def dna_spec(
|
115
|
-
self, location: Optional[object_utils.KeyPath] = None) -> geno.DNASpec:
|
114
|
+
def dna_spec(self, location: Optional[utils.KeyPath] = None) -> geno.DNASpec:
|
116
115
|
"""Always returns CustomDecisionPoint for CustomHyper."""
|
117
116
|
return geno.CustomDecisionPoint(
|
118
117
|
hyper_type=self.__class__.__name__,
|
@@ -147,12 +146,13 @@ class CustomHyper(base.HyperPrimitive):
|
|
147
146
|
|
148
147
|
def custom_apply(
|
149
148
|
self,
|
150
|
-
path:
|
149
|
+
path: utils.KeyPath,
|
151
150
|
value_spec: pg_typing.ValueSpec,
|
152
151
|
allow_partial: bool,
|
153
|
-
child_transform: Optional[
|
154
|
-
[
|
155
|
-
|
152
|
+
child_transform: Optional[
|
153
|
+
Callable[[utils.KeyPath, pg_typing.Field, Any], Any]
|
154
|
+
] = None,
|
155
|
+
) -> Tuple[bool, 'CustomHyper']:
|
156
156
|
"""Validate candidates during value_spec binding time."""
|
157
157
|
del path, value_spec, allow_partial, child_transform
|
158
158
|
# Allow custom hyper to be assigned to any type.
|
@@ -11,15 +11,12 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for pyglove.hyper.CustomHyper."""
|
15
|
-
|
16
14
|
import random
|
17
15
|
import unittest
|
18
16
|
|
19
17
|
from pyglove.core import geno
|
20
|
-
from pyglove.core import object_utils
|
21
18
|
from pyglove.core import symbolic
|
22
|
-
|
19
|
+
from pyglove.core import utils
|
23
20
|
from pyglove.core.hyper.categorical import oneof
|
24
21
|
from pyglove.core.hyper.custom import CustomHyper
|
25
22
|
from pyglove.core.hyper.iter import iterate
|
@@ -58,12 +55,14 @@ class CustomHyperTest(unittest.TestCase):
|
|
58
55
|
"""Test for CustomHyper."""
|
59
56
|
|
60
57
|
def test_dna_spec(self):
|
61
|
-
self.assertTrue(
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
58
|
+
self.assertTrue(
|
59
|
+
symbolic.eq(
|
60
|
+
IntSequence(hints='x').dna_spec('a'),
|
61
|
+
geno.CustomDecisionPoint(
|
62
|
+
hyper_type='IntSequence', location=utils.KeyPath('a'), hints='x'
|
63
|
+
),
|
64
|
+
)
|
65
|
+
)
|
67
66
|
|
68
67
|
def test_decode(self):
|
69
68
|
self.assertEqual(IntSequence().decode(geno.DNA('0,1,2')), [0, 1, 2])
|
pyglove/core/hyper/derived.py
CHANGED
@@ -17,16 +17,19 @@ import abc
|
|
17
17
|
import copy
|
18
18
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
19
19
|
|
20
|
-
from pyglove.core import object_utils
|
21
20
|
from pyglove.core import symbolic
|
22
21
|
from pyglove.core import typing as pg_typing
|
22
|
+
from pyglove.core import utils
|
23
23
|
|
24
24
|
|
25
|
-
@symbolic.members([
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
25
|
+
@symbolic.members([(
|
26
|
+
'reference_paths',
|
27
|
+
pg_typing.List(pg_typing.Object(utils.KeyPath)),
|
28
|
+
(
|
29
|
+
'Paths of referenced values, which are relative paths searched from '
|
30
|
+
'current node to root.'
|
31
|
+
),
|
32
|
+
)])
|
30
33
|
class DerivedValue(symbolic.Object, pg_typing.CustomTyping):
|
31
34
|
"""Base class of value that references to other values in object tree."""
|
32
35
|
|
@@ -36,8 +39,10 @@ class DerivedValue(symbolic.Object, pg_typing.CustomTyping):
|
|
36
39
|
|
37
40
|
def resolve(
|
38
41
|
self, reference_path_or_paths: Optional[Union[str, List[str]]] = None
|
39
|
-
|
40
|
-
|
42
|
+
) -> Union[
|
43
|
+
Tuple[symbolic.Symbolic, utils.KeyPath],
|
44
|
+
List[Tuple[symbolic.Symbolic, utils.KeyPath]],
|
45
|
+
]:
|
41
46
|
"""Resolve reference paths based on the location of this node.
|
42
47
|
|
43
48
|
Args:
|
@@ -54,17 +59,17 @@ class DerivedValue(symbolic.Object, pg_typing.CustomTyping):
|
|
54
59
|
if reference_path_or_paths is None:
|
55
60
|
reference_paths = self.reference_paths
|
56
61
|
elif isinstance(reference_path_or_paths, str):
|
57
|
-
reference_paths = [
|
62
|
+
reference_paths = [utils.KeyPath.parse(reference_path_or_paths)]
|
58
63
|
single_input = True
|
59
|
-
elif isinstance(reference_path_or_paths,
|
64
|
+
elif isinstance(reference_path_or_paths, utils.KeyPath):
|
60
65
|
reference_paths = [reference_path_or_paths]
|
61
66
|
single_input = True
|
62
67
|
elif isinstance(reference_path_or_paths, list):
|
63
68
|
paths = []
|
64
69
|
for path in reference_path_or_paths:
|
65
70
|
if isinstance(path, str):
|
66
|
-
path =
|
67
|
-
elif not isinstance(path,
|
71
|
+
path = utils.KeyPath.parse(path)
|
72
|
+
elif not isinstance(path, utils.KeyPath):
|
68
73
|
raise ValueError('Argument \'reference_path_or_paths\' must be None, '
|
69
74
|
'a string, KeyPath object, a list of strings, or a '
|
70
75
|
'list of KeyPath objects.')
|
@@ -96,8 +101,7 @@ class DerivedValue(symbolic.Object, pg_typing.CustomTyping):
|
|
96
101
|
# Make sure referenced value does not have referenced value.
|
97
102
|
# NOTE(daiyip): We can support dependencies between derived values
|
98
103
|
# in future if needed.
|
99
|
-
if not
|
100
|
-
referenced_value, self._contains_not_derived_value):
|
104
|
+
if not utils.traverse(referenced_value, self._contains_not_derived_value):
|
101
105
|
raise ValueError(
|
102
106
|
f'Derived value (path={referenced_value.sym_path}) should not '
|
103
107
|
f'reference derived values. '
|
@@ -107,15 +111,18 @@ class DerivedValue(symbolic.Object, pg_typing.CustomTyping):
|
|
107
111
|
return self.derive(*referenced_values)
|
108
112
|
|
109
113
|
def _contains_not_derived_value(
|
110
|
-
self, path:
|
114
|
+
self, path: utils.KeyPath, value: Any
|
115
|
+
) -> bool:
|
111
116
|
"""Returns whether a value contains derived value."""
|
112
117
|
if isinstance(value, DerivedValue):
|
113
118
|
return False
|
114
119
|
elif isinstance(value, symbolic.Object):
|
115
120
|
for k, v in value.sym_items():
|
116
|
-
if not
|
117
|
-
v,
|
118
|
-
|
121
|
+
if not utils.traverse(
|
122
|
+
v,
|
123
|
+
self._contains_not_derived_value,
|
124
|
+
root_path=utils.KeyPath(k, path),
|
125
|
+
):
|
119
126
|
return False
|
120
127
|
return True
|
121
128
|
|
@@ -137,12 +144,13 @@ class ValueReference(DerivedValue):
|
|
137
144
|
|
138
145
|
def custom_apply(
|
139
146
|
self,
|
140
|
-
path:
|
147
|
+
path: utils.KeyPath,
|
141
148
|
value_spec: pg_typing.ValueSpec,
|
142
149
|
allow_partial: bool,
|
143
|
-
child_transform: Optional[
|
144
|
-
[
|
145
|
-
|
150
|
+
child_transform: Optional[
|
151
|
+
Callable[[utils.KeyPath, pg_typing.Field, Any], Any]
|
152
|
+
] = None,
|
153
|
+
) -> Tuple[bool, 'DerivedValue']:
|
146
154
|
"""Implement pg_typing.CustomTyping interface."""
|
147
155
|
# TODO(daiyip): perform possible static analysis on referenced paths.
|
148
156
|
del path, value_spec, allow_partial, child_transform
|
@@ -11,13 +11,11 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for pyglove.hyper.ValueReference."""
|
15
|
-
|
16
14
|
import unittest
|
17
15
|
|
18
|
-
from pyglove.core import object_utils
|
19
16
|
from pyglove.core import symbolic
|
20
17
|
from pyglove.core import typing as pg_typing
|
18
|
+
from pyglove.core import utils
|
21
19
|
from pyglove.core.hyper.derived import ValueReference
|
22
20
|
|
23
21
|
|
@@ -47,7 +45,7 @@ class ValueReferenceTest(unittest.TestCase):
|
|
47
45
|
self.assertEqual(sd.c[0].y.resolve(), [(sd.c[0], 'c[0].x[0].z')])
|
48
46
|
self.assertEqual(sd.c[1].y.resolve(), [(sd.c[1], 'c[1].x[0].z')])
|
49
47
|
# Resolve references from this point.
|
50
|
-
self.assertEqual(sd.c[0].y.resolve(
|
48
|
+
self.assertEqual(sd.c[0].y.resolve(utils.KeyPath(0)), (sd.c, 'c[0]'))
|
51
49
|
self.assertEqual(sd.c[0].y.resolve('[0]'), (sd.c, 'c[0]'))
|
52
50
|
self.assertEqual(
|
53
51
|
sd.c[0].y.resolve(['[0]', '[1]']), [(sd.c, 'c[0]'), (sd.c, 'c[1]')])
|
@@ -18,9 +18,9 @@ import types
|
|
18
18
|
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
|
19
19
|
|
20
20
|
from pyglove.core import geno
|
21
|
-
from pyglove.core import object_utils
|
22
21
|
from pyglove.core import symbolic
|
23
22
|
from pyglove.core import typing as pg_typing
|
23
|
+
from pyglove.core import utils
|
24
24
|
from pyglove.core.hyper import base
|
25
25
|
from pyglove.core.hyper import categorical
|
26
26
|
from pyglove.core.hyper import custom
|
@@ -520,10 +520,10 @@ class _DynamicEvaluationStack:
|
|
520
520
|
@property
|
521
521
|
def _local_stack(self):
|
522
522
|
"""Returns thread-local stack."""
|
523
|
-
stack =
|
523
|
+
stack = utils.thread_local_get(self._TLS_KEY, None)
|
524
524
|
if stack is None:
|
525
525
|
stack = []
|
526
|
-
|
526
|
+
utils.thread_local_set(self._TLS_KEY, stack)
|
527
527
|
return stack
|
528
528
|
|
529
529
|
def push(self, context: DynamicEvaluationContext):
|
@@ -585,4 +585,3 @@ def trace(
|
|
585
585
|
with context.collect():
|
586
586
|
fun()
|
587
587
|
return context
|
588
|
-
|