pyglove 0.4.5.dev20240319__py3-none-any.whl → 0.4.5.dev202501140808__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyglove/core/__init__.py +54 -20
- pyglove/core/coding/__init__.py +42 -0
- pyglove/core/coding/errors.py +111 -0
- pyglove/core/coding/errors_test.py +98 -0
- pyglove/core/coding/execution.py +309 -0
- pyglove/core/coding/execution_test.py +333 -0
- pyglove/core/{object_utils/codegen.py → coding/function_generation.py} +10 -4
- pyglove/core/{object_utils/codegen_test.py → coding/function_generation_test.py} +5 -7
- pyglove/core/coding/parsing.py +153 -0
- pyglove/core/coding/parsing_test.py +150 -0
- pyglove/core/coding/permissions.py +100 -0
- pyglove/core/coding/permissions_test.py +93 -0
- pyglove/core/geno/base.py +54 -41
- pyglove/core/geno/base_test.py +2 -4
- pyglove/core/geno/categorical.py +37 -28
- pyglove/core/geno/custom.py +19 -16
- pyglove/core/geno/numerical.py +20 -17
- pyglove/core/geno/space.py +4 -5
- pyglove/core/hyper/base.py +6 -6
- pyglove/core/hyper/categorical.py +94 -55
- pyglove/core/hyper/custom.py +7 -7
- pyglove/core/hyper/custom_test.py +9 -10
- pyglove/core/hyper/derived.py +30 -22
- pyglove/core/hyper/derived_test.py +2 -4
- pyglove/core/hyper/dynamic_evaluation.py +5 -6
- pyglove/core/hyper/evolvable.py +57 -46
- pyglove/core/hyper/numerical.py +48 -24
- pyglove/core/hyper/numerical_test.py +9 -9
- pyglove/core/hyper/object_template.py +58 -46
- pyglove/core/io/__init__.py +1 -0
- pyglove/core/io/file_system.py +17 -7
- pyglove/core/io/file_system_test.py +2 -0
- pyglove/core/io/sequence.py +299 -0
- pyglove/core/io/sequence_test.py +124 -0
- pyglove/core/logging_test.py +0 -2
- pyglove/core/patching/object_factory.py +4 -4
- pyglove/core/patching/pattern_based.py +4 -4
- pyglove/core/patching/rule_based.py +17 -5
- pyglove/core/patching/rule_based_test.py +27 -4
- pyglove/core/symbolic/__init__.py +2 -7
- pyglove/core/symbolic/base.py +320 -183
- pyglove/core/symbolic/base_test.py +123 -19
- pyglove/core/symbolic/boilerplate.py +7 -13
- pyglove/core/symbolic/boilerplate_test.py +25 -23
- pyglove/core/symbolic/class_wrapper.py +48 -45
- pyglove/core/symbolic/class_wrapper_test.py +2 -2
- pyglove/core/symbolic/compounding.py +9 -15
- pyglove/core/symbolic/compounding_test.py +2 -4
- pyglove/core/symbolic/dict.py +154 -110
- pyglove/core/symbolic/dict_test.py +238 -130
- pyglove/core/symbolic/diff.py +199 -10
- pyglove/core/symbolic/diff_test.py +226 -0
- pyglove/core/symbolic/flags.py +1 -1
- pyglove/core/symbolic/functor.py +29 -26
- pyglove/core/symbolic/functor_test.py +102 -50
- pyglove/core/symbolic/inferred.py +2 -2
- pyglove/core/symbolic/list.py +81 -50
- pyglove/core/symbolic/list_test.py +119 -97
- pyglove/core/symbolic/object.py +225 -113
- pyglove/core/symbolic/object_test.py +320 -108
- pyglove/core/symbolic/origin.py +17 -14
- pyglove/core/symbolic/origin_test.py +4 -2
- pyglove/core/symbolic/pure_symbolic.py +4 -3
- pyglove/core/symbolic/ref.py +108 -21
- pyglove/core/symbolic/ref_test.py +93 -0
- pyglove/core/symbolic/symbolize_test.py +10 -2
- pyglove/core/tuning/local_backend.py +2 -2
- pyglove/core/tuning/protocols.py +3 -3
- pyglove/core/tuning/sample_test.py +3 -3
- pyglove/core/typing/__init__.py +14 -5
- pyglove/core/typing/annotation_conversion.py +43 -27
- pyglove/core/typing/annotation_conversion_test.py +23 -0
- pyglove/core/typing/callable_ext.py +241 -3
- pyglove/core/typing/callable_ext_test.py +255 -0
- pyglove/core/typing/callable_signature.py +510 -66
- pyglove/core/typing/callable_signature_test.py +619 -99
- pyglove/core/typing/class_schema.py +229 -154
- pyglove/core/typing/class_schema_test.py +149 -95
- pyglove/core/typing/custom_typing.py +5 -4
- pyglove/core/typing/inspect.py +63 -0
- pyglove/core/typing/inspect_test.py +39 -0
- pyglove/core/typing/key_specs.py +10 -11
- pyglove/core/typing/key_specs_test.py +7 -4
- pyglove/core/typing/type_conversion.py +4 -5
- pyglove/core/typing/type_conversion_test.py +12 -12
- pyglove/core/typing/typed_missing.py +6 -7
- pyglove/core/typing/typed_missing_test.py +7 -8
- pyglove/core/typing/value_specs.py +604 -362
- pyglove/core/typing/value_specs_test.py +328 -90
- pyglove/core/utils/__init__.py +164 -0
- pyglove/core/{object_utils → utils}/common_traits.py +3 -67
- pyglove/core/utils/common_traits_test.py +36 -0
- pyglove/core/{object_utils → utils}/docstr_utils.py +23 -0
- pyglove/core/{object_utils → utils}/docstr_utils_test.py +36 -4
- pyglove/core/{object_utils → utils}/error_utils.py +78 -9
- pyglove/core/{object_utils → utils}/error_utils_test.py +61 -5
- pyglove/core/utils/formatting.py +464 -0
- pyglove/core/utils/formatting_test.py +453 -0
- pyglove/core/{object_utils → utils}/hierarchical.py +23 -25
- pyglove/core/{object_utils → utils}/hierarchical_test.py +3 -5
- pyglove/core/{object_utils → utils}/json_conversion.py +177 -52
- pyglove/core/{object_utils → utils}/json_conversion_test.py +97 -16
- pyglove/core/{object_utils → utils}/missing.py +3 -3
- pyglove/core/{object_utils → utils}/missing_test.py +2 -4
- pyglove/core/utils/text_color.py +128 -0
- pyglove/core/utils/text_color_test.py +94 -0
- pyglove/core/{object_utils → utils}/thread_local_test.py +1 -3
- pyglove/core/utils/timing.py +236 -0
- pyglove/core/utils/timing_test.py +154 -0
- pyglove/core/{object_utils → utils}/value_location.py +275 -6
- pyglove/core/utils/value_location_test.py +707 -0
- pyglove/core/views/__init__.py +32 -0
- pyglove/core/views/base.py +804 -0
- pyglove/core/views/base_test.py +580 -0
- pyglove/core/views/html/__init__.py +27 -0
- pyglove/core/views/html/base.py +547 -0
- pyglove/core/views/html/base_test.py +830 -0
- pyglove/core/views/html/controls/__init__.py +35 -0
- pyglove/core/views/html/controls/base.py +275 -0
- pyglove/core/views/html/controls/label.py +207 -0
- pyglove/core/views/html/controls/label_test.py +157 -0
- pyglove/core/views/html/controls/progress_bar.py +183 -0
- pyglove/core/views/html/controls/progress_bar_test.py +97 -0
- pyglove/core/views/html/controls/tab.py +320 -0
- pyglove/core/views/html/controls/tab_test.py +87 -0
- pyglove/core/views/html/controls/tooltip.py +99 -0
- pyglove/core/views/html/controls/tooltip_test.py +99 -0
- pyglove/core/views/html/tree_view.py +1517 -0
- pyglove/core/views/html/tree_view_test.py +1461 -0
- {pyglove-0.4.5.dev20240319.dist-info → pyglove-0.4.5.dev202501140808.dist-info}/METADATA +18 -4
- pyglove-0.4.5.dev202501140808.dist-info/RECORD +214 -0
- {pyglove-0.4.5.dev20240319.dist-info → pyglove-0.4.5.dev202501140808.dist-info}/WHEEL +1 -1
- pyglove/core/object_utils/__init__.py +0 -154
- pyglove/core/object_utils/common_traits_test.py +0 -82
- pyglove/core/object_utils/formatting.py +0 -234
- pyglove/core/object_utils/formatting_test.py +0 -223
- pyglove/core/object_utils/value_location_test.py +0 -385
- pyglove/core/symbolic/schema_utils.py +0 -327
- pyglove/core/symbolic/schema_utils_test.py +0 -57
- pyglove/core/typing/class_schema_utils.py +0 -202
- pyglove/core/typing/class_schema_utils_test.py +0 -194
- pyglove-0.4.5.dev20240319.dist-info/RECORD +0 -185
- /pyglove/core/{object_utils → utils}/thread_local.py +0 -0
- {pyglove-0.4.5.dev20240319.dist-info → pyglove-0.4.5.dev202501140808.dist-info}/LICENSE +0 -0
- {pyglove-0.4.5.dev20240319.dist-info → pyglove-0.4.5.dev202501140808.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,7 @@ import inspect
|
|
18
18
|
import types
|
19
19
|
import typing
|
20
20
|
|
21
|
-
from pyglove.core import
|
21
|
+
from pyglove.core import utils
|
22
22
|
from pyglove.core.typing import annotated
|
23
23
|
from pyglove.core.typing import class_schema
|
24
24
|
from pyglove.core.typing import inspect as pg_inspect
|
@@ -40,7 +40,8 @@ def _field_from_annotation(
|
|
40
40
|
annotation: typing.Any,
|
41
41
|
description: typing.Optional[str] = None,
|
42
42
|
metadata: typing.Optional[typing.Dict[str, typing.Any]] = None,
|
43
|
-
auto_typing=True,
|
43
|
+
auto_typing: bool = True,
|
44
|
+
parent_module: typing.Optional[types.ModuleType] = None
|
44
45
|
) -> class_schema.Field:
|
45
46
|
"""Creates a field from Python annotation."""
|
46
47
|
if isinstance(annotation, annotated.Annotated):
|
@@ -55,7 +56,9 @@ def _field_from_annotation(
|
|
55
56
|
return class_schema.create_field(
|
56
57
|
field_spec,
|
57
58
|
auto_typing=auto_typing,
|
58
|
-
accept_value_as_annotation=False
|
59
|
+
accept_value_as_annotation=False,
|
60
|
+
parent_module=parent_module
|
61
|
+
)
|
59
62
|
|
60
63
|
|
61
64
|
def _value_spec_from_default_value(
|
@@ -70,7 +73,7 @@ def _value_spec_from_default_value(
|
|
70
73
|
elif isinstance(value, tuple):
|
71
74
|
value_spec = vs.Tuple(
|
72
75
|
[_value_spec_from_default_value(elem, False) for elem in value])
|
73
|
-
elif inspect.isfunction(value) or isinstance(value,
|
76
|
+
elif inspect.isfunction(value) or isinstance(value, utils.Functor):
|
74
77
|
value_spec = vs.Callable()
|
75
78
|
elif not isinstance(value, type):
|
76
79
|
value_spec = vs.Object(type(value))
|
@@ -84,7 +87,9 @@ def _value_spec_from_default_value(
|
|
84
87
|
|
85
88
|
def _value_spec_from_type_annotation(
|
86
89
|
annotation: typing.Any,
|
87
|
-
accept_value_as_annotation: bool
|
90
|
+
accept_value_as_annotation: bool,
|
91
|
+
parent_module: typing.Optional[types.ModuleType] = None
|
92
|
+
) -> class_schema.ValueSpec:
|
88
93
|
"""Creates a value spec from type annotation."""
|
89
94
|
if annotation is bool:
|
90
95
|
return vs.Bool()
|
@@ -100,11 +105,15 @@ def _value_spec_from_type_annotation(
|
|
100
105
|
origin = typing.get_origin(annotation) or annotation
|
101
106
|
args = list(typing.get_args(annotation))
|
102
107
|
|
108
|
+
def _sub_value_spec_from_annotation(
|
109
|
+
annotation: typing.Any) -> class_schema.ValueSpec:
|
110
|
+
return _value_spec_from_type_annotation(
|
111
|
+
annotation, accept_value_as_annotation, parent_module)
|
112
|
+
|
103
113
|
# Handling list.
|
104
114
|
if origin in (list, typing.List):
|
105
115
|
return vs.List(
|
106
|
-
|
107
|
-
args[0], True)) if args else vs.List(vs.Any())
|
116
|
+
_sub_value_spec_from_annotation(args[0])) if args else vs.List(vs.Any())
|
108
117
|
# Handling tuple.
|
109
118
|
elif origin in (tuple, typing.Tuple):
|
110
119
|
if not args:
|
@@ -115,16 +124,15 @@ def _value_spec_from_type_annotation(
|
|
115
124
|
raise TypeError(
|
116
125
|
f'Tuple with ellipsis should have exact 2 type arguments. '
|
117
126
|
f'Encountered: {annotation}.')
|
118
|
-
return vs.Tuple(
|
119
|
-
return vs.Tuple([
|
120
|
-
for arg in args])
|
127
|
+
return vs.Tuple(_sub_value_spec_from_annotation(args[0]))
|
128
|
+
return vs.Tuple([_sub_value_spec_from_annotation(arg) for arg in args])
|
121
129
|
# Handling sequence.
|
122
130
|
elif origin in (collections.abc.Sequence,):
|
123
|
-
elem =
|
131
|
+
elem = _sub_value_spec_from_annotation(args[0]) if args else vs.Any()
|
124
132
|
return vs.Union([vs.List(elem), vs.Tuple(elem)])
|
125
133
|
# Handling literals.
|
126
134
|
elif origin is typing.Literal:
|
127
|
-
return vs.Enum(
|
135
|
+
return vs.Enum(utils.MISSING_VALUE, args)
|
128
136
|
# Handling dict.
|
129
137
|
elif origin in (dict, typing.Dict, collections.abc.Mapping):
|
130
138
|
if not args:
|
@@ -133,8 +141,7 @@ def _value_spec_from_type_annotation(
|
|
133
141
|
if args[0] not in (str, typing.Text):
|
134
142
|
raise TypeError(
|
135
143
|
'Dict type field with non-string key is not supported.')
|
136
|
-
elem_value_spec =
|
137
|
-
args[1], accept_value_as_annotation=False)
|
144
|
+
elem_value_spec = _sub_value_spec_from_annotation(args[1])
|
138
145
|
return vs.Dict([(ks.StrKey(), elem_value_spec)])
|
139
146
|
elif origin is collections.abc.Callable:
|
140
147
|
arg_specs = []
|
@@ -149,13 +156,8 @@ def _value_spec_from_type_annotation(
|
|
149
156
|
# Callable[int, Any] => Callable[[int], Any]
|
150
157
|
# Callable[(int, int), Any] => Callable[[int, int], Any]
|
151
158
|
if isinstance(args[0], list):
|
152
|
-
arg_specs = [
|
153
|
-
|
154
|
-
arg, accept_value_as_annotation=False)
|
155
|
-
for arg in args[0]
|
156
|
-
]
|
157
|
-
return_spec = _value_spec_from_type_annotation(
|
158
|
-
args[1], accept_value_as_annotation=False)
|
159
|
+
arg_specs = [_sub_value_spec_from_annotation(arg) for arg in args[0]]
|
160
|
+
return_spec = _sub_value_spec_from_annotation(args[1])
|
159
161
|
return vs.Callable(arg_specs, returns=return_spec)
|
160
162
|
# Handling type
|
161
163
|
elif origin is type or (annotation in (typing.Type, type)):
|
@@ -169,20 +171,32 @@ def _value_spec_from_type_annotation(
|
|
169
171
|
if optional:
|
170
172
|
args.remove(_NoneType)
|
171
173
|
if len(args) == 1:
|
172
|
-
spec =
|
174
|
+
spec = _sub_value_spec_from_annotation(args[0])
|
173
175
|
else:
|
174
|
-
spec = vs.Union([
|
176
|
+
spec = vs.Union([_sub_value_spec_from_annotation(x) for x in args])
|
175
177
|
if optional:
|
176
178
|
spec = spec.noneable()
|
177
179
|
return spec
|
180
|
+
elif origin is typing.Final:
|
181
|
+
return _value_spec_from_type_annotation(
|
182
|
+
args[0],
|
183
|
+
accept_value_as_annotation=False
|
184
|
+
).freeze(vs._FROZEN_VALUE_PLACEHOLDER) # pylint: disable=protected-access
|
178
185
|
elif isinstance(annotation, typing.ForwardRef):
|
179
|
-
|
186
|
+
annotation = annotation.__forward_arg__
|
187
|
+
if parent_module is not None:
|
188
|
+
annotation = class_schema.ForwardRef(parent_module, annotation)
|
189
|
+
return vs.Object(annotation)
|
190
|
+
elif isinstance(annotation, class_schema.ForwardRef):
|
191
|
+
return vs.Object(annotation)
|
180
192
|
# Handling class.
|
181
193
|
elif (
|
182
194
|
inspect.isclass(annotation)
|
183
195
|
or pg_inspect.is_generic(annotation)
|
184
196
|
or (isinstance(annotation, str) and not accept_value_as_annotation)
|
185
197
|
):
|
198
|
+
if isinstance(annotation, str) and parent_module is not None:
|
199
|
+
annotation = class_schema.ForwardRef(parent_module, annotation)
|
186
200
|
return vs.Object(annotation)
|
187
201
|
|
188
202
|
if accept_value_as_annotation:
|
@@ -204,8 +218,9 @@ def _any_spec_with_annotation(annotation: typing.Any) -> vs.Any:
|
|
204
218
|
|
205
219
|
def _value_spec_from_annotation(
|
206
220
|
annotation: typing.Any,
|
207
|
-
auto_typing=False,
|
208
|
-
accept_value_as_annotation=False
|
221
|
+
auto_typing: bool = False,
|
222
|
+
accept_value_as_annotation: bool = False,
|
223
|
+
parent_module: typing.Optional[types.ModuleType] = None
|
209
224
|
) -> class_schema.ValueSpec:
|
210
225
|
"""Creates a value spec from annotation."""
|
211
226
|
if isinstance(annotation, class_schema.ValueSpec):
|
@@ -220,7 +235,8 @@ def _value_spec_from_annotation(
|
|
220
235
|
|
221
236
|
if auto_typing:
|
222
237
|
return _value_spec_from_type_annotation(
|
223
|
-
annotation, accept_value_as_annotation
|
238
|
+
annotation, accept_value_as_annotation, parent_module
|
239
|
+
)
|
224
240
|
else:
|
225
241
|
value_spec = None
|
226
242
|
if accept_value_as_annotation:
|
@@ -14,6 +14,7 @@
|
|
14
14
|
"""Tests for pyglove.core.typing.annotation_conversion."""
|
15
15
|
|
16
16
|
import inspect
|
17
|
+
import sys
|
17
18
|
import typing
|
18
19
|
import unittest
|
19
20
|
|
@@ -22,6 +23,7 @@ from pyglove.core.typing import annotation_conversion
|
|
22
23
|
from pyglove.core.typing import key_specs as ks
|
23
24
|
from pyglove.core.typing import value_specs as vs
|
24
25
|
from pyglove.core.typing.class_schema import Field
|
26
|
+
from pyglove.core.typing.class_schema import ForwardRef
|
25
27
|
from pyglove.core.typing.class_schema import ValueSpec
|
26
28
|
|
27
29
|
|
@@ -74,6 +76,14 @@ class FieldFromAnnotationTest(unittest.TestCase):
|
|
74
76
|
Field.from_annotation('x', str, 'A str', dict(x=1), auto_typing=True),
|
75
77
|
Field('x', vs.Str(), 'A str', dict(x=1)))
|
76
78
|
|
79
|
+
def test_with_parent_module(self):
|
80
|
+
self.assertEqual(
|
81
|
+
Field.from_annotation(
|
82
|
+
'x', 'ValueSpecBase', auto_typing=True, parent_module=vs
|
83
|
+
),
|
84
|
+
Field('x', vs.Object(vs.ValueSpecBase))
|
85
|
+
)
|
86
|
+
|
77
87
|
|
78
88
|
class ValueSpecFromAnnotationTest(unittest.TestCase):
|
79
89
|
"""Tests for ValueSpec.from_annotation."""
|
@@ -248,6 +258,11 @@ class ValueSpecFromAnnotationTest(unittest.TestCase):
|
|
248
258
|
self.assertEqual(
|
249
259
|
ValueSpec.from_annotation(
|
250
260
|
typing.ForwardRef('Foo'), True), vs.Object(Foo))
|
261
|
+
self.assertEqual(
|
262
|
+
ValueSpec.from_annotation(
|
263
|
+
ForwardRef(sys.modules[__name__], 'Foo'),
|
264
|
+
True
|
265
|
+
), vs.Object(Foo))
|
251
266
|
|
252
267
|
def test_generic_class(self):
|
253
268
|
X = typing.TypeVar('X')
|
@@ -291,6 +306,14 @@ class ValueSpecFromAnnotationTest(unittest.TestCase):
|
|
291
306
|
ValueSpec.from_annotation(int | str, True),
|
292
307
|
vs.Union([vs.Int(), vs.Str()]))
|
293
308
|
|
309
|
+
def test_final(self):
|
310
|
+
self.assertEqual(
|
311
|
+
ValueSpec.from_annotation(
|
312
|
+
typing.Final[int], True
|
313
|
+
).set_default(1),
|
314
|
+
vs.Int().freeze(1)
|
315
|
+
)
|
316
|
+
|
294
317
|
|
295
318
|
if __name__ == '__main__':
|
296
319
|
unittest.main()
|
@@ -13,11 +13,248 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Callable extensions."""
|
15
15
|
|
16
|
-
|
16
|
+
import contextlib
|
17
|
+
import functools
|
18
|
+
import inspect
|
19
|
+
import types
|
20
|
+
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
|
17
21
|
|
22
|
+
from pyglove.core import utils
|
18
23
|
from pyglove.core.typing import callable_signature
|
19
24
|
|
20
25
|
|
26
|
+
_TLS_KEY_PRESET_KWARGS = '__preset_kwargs__'
|
27
|
+
|
28
|
+
|
29
|
+
class PresetArgValue(utils.Formattable):
|
30
|
+
"""Value placeholder for arguments whose value will be provided by presets.
|
31
|
+
|
32
|
+
Example:
|
33
|
+
|
34
|
+
def foo(x, y=pg.PresetArgValue(default=1)):
|
35
|
+
return x + y
|
36
|
+
|
37
|
+
with pg.preset_args(dict(y=2)):
|
38
|
+
print(foo(x=1)) # 3: y=2
|
39
|
+
print(foo(x=1)) # 2: y=1
|
40
|
+
"""
|
41
|
+
|
42
|
+
def __init__(self, default: Any = utils.MISSING_VALUE):
|
43
|
+
self.default = default
|
44
|
+
|
45
|
+
@property
|
46
|
+
def has_default(self) -> bool:
|
47
|
+
return self.default != utils.MISSING_VALUE
|
48
|
+
|
49
|
+
def __eq__(self, other: Any) -> bool:
|
50
|
+
return isinstance(other, PresetArgValue) and (
|
51
|
+
self.default == other.default
|
52
|
+
)
|
53
|
+
|
54
|
+
def __ne__(self, other: Any) -> bool:
|
55
|
+
return not self.__eq__(other)
|
56
|
+
|
57
|
+
def format(self, *args, **kwargs):
|
58
|
+
return utils.kvlist_str(
|
59
|
+
[
|
60
|
+
('default', self.default, utils.MISSING_VALUE),
|
61
|
+
],
|
62
|
+
label='PresetArgValue',
|
63
|
+
*args,
|
64
|
+
**kwargs,
|
65
|
+
)
|
66
|
+
|
67
|
+
@classmethod
|
68
|
+
def inspect(
|
69
|
+
cls, func: types.FunctionType) -> Dict[str, 'PresetArgValue']:
|
70
|
+
"""Gets the PresetArgValue specified in a function's signature."""
|
71
|
+
assert inspect.isfunction(func), func
|
72
|
+
sig = inspect.signature(func)
|
73
|
+
preset_arg_markers = {}
|
74
|
+
for p in sig.parameters.values():
|
75
|
+
if isinstance(p.default, cls):
|
76
|
+
preset_arg_markers[p.name] = p.default
|
77
|
+
return preset_arg_markers
|
78
|
+
|
79
|
+
@classmethod
|
80
|
+
def resolve_args(
|
81
|
+
cls,
|
82
|
+
call_args: Tuple[Any, ...],
|
83
|
+
call_kwargs: Dict[str, Any],
|
84
|
+
positional_arg_names: Sequence[str],
|
85
|
+
arg_defaults: Dict[str, Any],
|
86
|
+
preset_kwargs: Dict[str, Any],
|
87
|
+
include_all_preset_kwargs: bool = False,
|
88
|
+
) -> Tuple[Sequence[Any], Dict[str, Any]]:
|
89
|
+
"""Resolves calling arguments passed to a method with presets."""
|
90
|
+
# Step 1: compute marked kwargs.
|
91
|
+
resolved_kwargs = {}
|
92
|
+
for arg_name, arg_default in arg_defaults.items():
|
93
|
+
if not isinstance(arg_default, PresetArgValue):
|
94
|
+
resolved_kwargs[arg_name] = arg_default
|
95
|
+
continue
|
96
|
+
if arg_name in preset_kwargs:
|
97
|
+
resolved_kwargs[arg_name] = preset_kwargs[arg_name]
|
98
|
+
elif arg_default.has_default:
|
99
|
+
resolved_kwargs[arg_name] = arg_default.default
|
100
|
+
else:
|
101
|
+
raise ValueError(
|
102
|
+
f'Argument {arg_name!r} is not present as a keyword argument '
|
103
|
+
'from the caller.'
|
104
|
+
)
|
105
|
+
|
106
|
+
# Step 2: add preset kwargs
|
107
|
+
if include_all_preset_kwargs:
|
108
|
+
for k, v in preset_kwargs.items():
|
109
|
+
if k not in resolved_kwargs:
|
110
|
+
resolved_kwargs[k] = v
|
111
|
+
|
112
|
+
# Step 3: merge call kwargs with resolved preset kwargs.
|
113
|
+
resolved_kwargs.update(call_kwargs)
|
114
|
+
|
115
|
+
# Step 3: remove resolved kwargs items as it's present in call args.
|
116
|
+
for i in range(len(call_args)):
|
117
|
+
if i >= len(positional_arg_names):
|
118
|
+
break
|
119
|
+
resolved_kwargs.pop(positional_arg_names[i], None)
|
120
|
+
|
121
|
+
# Step 4: convert kwargs back to positional arguments if applicable
|
122
|
+
resolved_args = call_args
|
123
|
+
if len(positional_arg_names) > len(call_args):
|
124
|
+
resolved_args = list(resolved_args)
|
125
|
+
for arg_name in positional_arg_names[len(call_args):]:
|
126
|
+
if arg_name not in resolved_kwargs:
|
127
|
+
break
|
128
|
+
arg_value = resolved_kwargs.pop(arg_name)
|
129
|
+
resolved_args.append(arg_value)
|
130
|
+
return resolved_args, resolved_kwargs
|
131
|
+
|
132
|
+
|
133
|
+
class _ArgPresets:
|
134
|
+
"""Preset argument collection."""
|
135
|
+
|
136
|
+
def __init__(self, presets: Optional[Dict[str, Dict[str, Any]]] = None):
|
137
|
+
self._presets: Dict[str, Dict[str, Any]] = presets or {}
|
138
|
+
|
139
|
+
def derive(
|
140
|
+
self,
|
141
|
+
kwargs: Dict[str, Any],
|
142
|
+
preset_name: str = 'global',
|
143
|
+
inherit_preset: Union[str, bool] = False
|
144
|
+
) -> '_ArgPresets':
|
145
|
+
"""Derives new presets from current presets."""
|
146
|
+
presets = self._presets.copy() # Just do a shallow copy.
|
147
|
+
if isinstance(inherit_preset, bool) and inherit_preset:
|
148
|
+
inherit_preset = preset_name
|
149
|
+
|
150
|
+
if inherit_preset and inherit_preset in presets:
|
151
|
+
current_preset = presets[inherit_preset].copy()
|
152
|
+
current_preset.update(kwargs)
|
153
|
+
else:
|
154
|
+
current_preset = kwargs
|
155
|
+
presets[preset_name] = current_preset
|
156
|
+
return _ArgPresets(presets)
|
157
|
+
|
158
|
+
def get_preset(self, preset_name: str) -> Dict[str, Any]:
|
159
|
+
return self._presets.get(preset_name, {})
|
160
|
+
|
161
|
+
|
162
|
+
@contextlib.contextmanager
|
163
|
+
def preset_args(
|
164
|
+
kwargs: Dict[str, Any],
|
165
|
+
*,
|
166
|
+
preset_name: str = 'global',
|
167
|
+
inherit_preset: Union[str, bool] = False
|
168
|
+
) -> Iterator[Dict[str, Any]]:
|
169
|
+
"""Context manager to enable calling with user kwargs.
|
170
|
+
|
171
|
+
Args:
|
172
|
+
kwargs: The preset kwargs to be used by preset-enabled functions within
|
173
|
+
the context.
|
174
|
+
preset_name: The name of the preset to specify kwargs.
|
175
|
+
`enable_preset_args` allows users to pass a preset name, which will be
|
176
|
+
used to identify the preset to be used.
|
177
|
+
inherit_preset: The name of the preset defined by the parent context to
|
178
|
+
inherit kwargs from. Or a boolean to indicate whether to inherit a
|
179
|
+
parent preset of the same name.
|
180
|
+
|
181
|
+
Yields:
|
182
|
+
Current preset kwargs.
|
183
|
+
"""
|
184
|
+
|
185
|
+
parent_presets = utils.thread_local_peek(
|
186
|
+
_TLS_KEY_PRESET_KWARGS, _ArgPresets()
|
187
|
+
)
|
188
|
+
current_preset = parent_presets.derive(kwargs, preset_name, inherit_preset)
|
189
|
+
utils.thread_local_push(_TLS_KEY_PRESET_KWARGS, current_preset)
|
190
|
+
try:
|
191
|
+
yield current_preset
|
192
|
+
finally:
|
193
|
+
utils.thread_local_pop(_TLS_KEY_PRESET_KWARGS, None)
|
194
|
+
|
195
|
+
|
196
|
+
def enable_preset_args(
|
197
|
+
include_all_preset_kwargs: bool = False,
|
198
|
+
preset_name: str = 'global'
|
199
|
+
) -> Callable[[types.FunctionType], types.FunctionType]:
|
200
|
+
"""Decorator for functions that may use preset argument values.
|
201
|
+
|
202
|
+
Usage::
|
203
|
+
|
204
|
+
@pg.typing.enable_preset_args()
|
205
|
+
def foo(x, y=pg.typing.PresetArgValue(default=1)):
|
206
|
+
return x + y
|
207
|
+
|
208
|
+
with pg.typing.preset_args(dict(y=2)):
|
209
|
+
print(foo(x=1)) # 3: y=2
|
210
|
+
print(foo(x=1)) # 2: y=1
|
211
|
+
|
212
|
+
Args:
|
213
|
+
include_all_preset_kwargs: Whether to include all preset kwargs (even
|
214
|
+
not marked as `PresetArgValue`) when calling the function.
|
215
|
+
preset_name: The name of the preset to specify kwargs.
|
216
|
+
|
217
|
+
Returns:
|
218
|
+
A decorated function that could consume the preset argument values.
|
219
|
+
"""
|
220
|
+
def decorator(func):
|
221
|
+
sig = inspect.signature(func)
|
222
|
+
positional_arg_names = [
|
223
|
+
p.name for p in sig.parameters.values()
|
224
|
+
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
225
|
+
]
|
226
|
+
|
227
|
+
arg_defaults = {}
|
228
|
+
has_preset_value = False
|
229
|
+
has_varkw = False
|
230
|
+
for p in sig.parameters.values():
|
231
|
+
if p.kind == inspect.Parameter.VAR_KEYWORD:
|
232
|
+
has_varkw = True
|
233
|
+
continue
|
234
|
+
if p.kind == inspect.Parameter.VAR_POSITIONAL:
|
235
|
+
continue
|
236
|
+
if p.default == inspect.Parameter.empty:
|
237
|
+
continue
|
238
|
+
if isinstance(p.default, PresetArgValue):
|
239
|
+
has_preset_value = True
|
240
|
+
arg_defaults[p.name] = p.default
|
241
|
+
|
242
|
+
if has_preset_value:
|
243
|
+
@functools.wraps(func)
|
244
|
+
def _func(*args, **kwargs):
|
245
|
+
# Map positional arguments to keyword arguments.
|
246
|
+
presets = utils.thread_local_peek(_TLS_KEY_PRESET_KWARGS, None)
|
247
|
+
preset_kwargs = presets.get_preset(preset_name) if presets else {}
|
248
|
+
args, kwargs = PresetArgValue.resolve_args(
|
249
|
+
args, kwargs, positional_arg_names, arg_defaults, preset_kwargs,
|
250
|
+
include_all_preset_kwargs=include_all_preset_kwargs and has_varkw
|
251
|
+
)
|
252
|
+
return func(*args, **kwargs)
|
253
|
+
return _func
|
254
|
+
return func
|
255
|
+
return decorator
|
256
|
+
|
257
|
+
|
21
258
|
class CallableWithOptionalKeywordArgs:
|
22
259
|
"""Helper class for invoking callable objects with optional keyword args.
|
23
260
|
|
@@ -31,8 +268,9 @@ class CallableWithOptionalKeywordArgs:
|
|
31
268
|
def __init__(self,
|
32
269
|
func: Callable[..., Any],
|
33
270
|
optional_keywords: List[str]):
|
34
|
-
sig = callable_signature.
|
35
|
-
|
271
|
+
sig = callable_signature.signature(
|
272
|
+
func, auto_typing=False, auto_doc=False
|
273
|
+
)
|
36
274
|
|
37
275
|
# Check for variable keyword arguments.
|
38
276
|
if sig.has_varkw:
|