pyglove 0.4.5.dev20240319__py3-none-any.whl → 0.4.5.dev202501140808__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145) hide show
  1. pyglove/core/__init__.py +54 -20
  2. pyglove/core/coding/__init__.py +42 -0
  3. pyglove/core/coding/errors.py +111 -0
  4. pyglove/core/coding/errors_test.py +98 -0
  5. pyglove/core/coding/execution.py +309 -0
  6. pyglove/core/coding/execution_test.py +333 -0
  7. pyglove/core/{object_utils/codegen.py → coding/function_generation.py} +10 -4
  8. pyglove/core/{object_utils/codegen_test.py → coding/function_generation_test.py} +5 -7
  9. pyglove/core/coding/parsing.py +153 -0
  10. pyglove/core/coding/parsing_test.py +150 -0
  11. pyglove/core/coding/permissions.py +100 -0
  12. pyglove/core/coding/permissions_test.py +93 -0
  13. pyglove/core/geno/base.py +54 -41
  14. pyglove/core/geno/base_test.py +2 -4
  15. pyglove/core/geno/categorical.py +37 -28
  16. pyglove/core/geno/custom.py +19 -16
  17. pyglove/core/geno/numerical.py +20 -17
  18. pyglove/core/geno/space.py +4 -5
  19. pyglove/core/hyper/base.py +6 -6
  20. pyglove/core/hyper/categorical.py +94 -55
  21. pyglove/core/hyper/custom.py +7 -7
  22. pyglove/core/hyper/custom_test.py +9 -10
  23. pyglove/core/hyper/derived.py +30 -22
  24. pyglove/core/hyper/derived_test.py +2 -4
  25. pyglove/core/hyper/dynamic_evaluation.py +5 -6
  26. pyglove/core/hyper/evolvable.py +57 -46
  27. pyglove/core/hyper/numerical.py +48 -24
  28. pyglove/core/hyper/numerical_test.py +9 -9
  29. pyglove/core/hyper/object_template.py +58 -46
  30. pyglove/core/io/__init__.py +1 -0
  31. pyglove/core/io/file_system.py +17 -7
  32. pyglove/core/io/file_system_test.py +2 -0
  33. pyglove/core/io/sequence.py +299 -0
  34. pyglove/core/io/sequence_test.py +124 -0
  35. pyglove/core/logging_test.py +0 -2
  36. pyglove/core/patching/object_factory.py +4 -4
  37. pyglove/core/patching/pattern_based.py +4 -4
  38. pyglove/core/patching/rule_based.py +17 -5
  39. pyglove/core/patching/rule_based_test.py +27 -4
  40. pyglove/core/symbolic/__init__.py +2 -7
  41. pyglove/core/symbolic/base.py +320 -183
  42. pyglove/core/symbolic/base_test.py +123 -19
  43. pyglove/core/symbolic/boilerplate.py +7 -13
  44. pyglove/core/symbolic/boilerplate_test.py +25 -23
  45. pyglove/core/symbolic/class_wrapper.py +48 -45
  46. pyglove/core/symbolic/class_wrapper_test.py +2 -2
  47. pyglove/core/symbolic/compounding.py +9 -15
  48. pyglove/core/symbolic/compounding_test.py +2 -4
  49. pyglove/core/symbolic/dict.py +154 -110
  50. pyglove/core/symbolic/dict_test.py +238 -130
  51. pyglove/core/symbolic/diff.py +199 -10
  52. pyglove/core/symbolic/diff_test.py +226 -0
  53. pyglove/core/symbolic/flags.py +1 -1
  54. pyglove/core/symbolic/functor.py +29 -26
  55. pyglove/core/symbolic/functor_test.py +102 -50
  56. pyglove/core/symbolic/inferred.py +2 -2
  57. pyglove/core/symbolic/list.py +81 -50
  58. pyglove/core/symbolic/list_test.py +119 -97
  59. pyglove/core/symbolic/object.py +225 -113
  60. pyglove/core/symbolic/object_test.py +320 -108
  61. pyglove/core/symbolic/origin.py +17 -14
  62. pyglove/core/symbolic/origin_test.py +4 -2
  63. pyglove/core/symbolic/pure_symbolic.py +4 -3
  64. pyglove/core/symbolic/ref.py +108 -21
  65. pyglove/core/symbolic/ref_test.py +93 -0
  66. pyglove/core/symbolic/symbolize_test.py +10 -2
  67. pyglove/core/tuning/local_backend.py +2 -2
  68. pyglove/core/tuning/protocols.py +3 -3
  69. pyglove/core/tuning/sample_test.py +3 -3
  70. pyglove/core/typing/__init__.py +14 -5
  71. pyglove/core/typing/annotation_conversion.py +43 -27
  72. pyglove/core/typing/annotation_conversion_test.py +23 -0
  73. pyglove/core/typing/callable_ext.py +241 -3
  74. pyglove/core/typing/callable_ext_test.py +255 -0
  75. pyglove/core/typing/callable_signature.py +510 -66
  76. pyglove/core/typing/callable_signature_test.py +619 -99
  77. pyglove/core/typing/class_schema.py +229 -154
  78. pyglove/core/typing/class_schema_test.py +149 -95
  79. pyglove/core/typing/custom_typing.py +5 -4
  80. pyglove/core/typing/inspect.py +63 -0
  81. pyglove/core/typing/inspect_test.py +39 -0
  82. pyglove/core/typing/key_specs.py +10 -11
  83. pyglove/core/typing/key_specs_test.py +7 -4
  84. pyglove/core/typing/type_conversion.py +4 -5
  85. pyglove/core/typing/type_conversion_test.py +12 -12
  86. pyglove/core/typing/typed_missing.py +6 -7
  87. pyglove/core/typing/typed_missing_test.py +7 -8
  88. pyglove/core/typing/value_specs.py +604 -362
  89. pyglove/core/typing/value_specs_test.py +328 -90
  90. pyglove/core/utils/__init__.py +164 -0
  91. pyglove/core/{object_utils → utils}/common_traits.py +3 -67
  92. pyglove/core/utils/common_traits_test.py +36 -0
  93. pyglove/core/{object_utils → utils}/docstr_utils.py +23 -0
  94. pyglove/core/{object_utils → utils}/docstr_utils_test.py +36 -4
  95. pyglove/core/{object_utils → utils}/error_utils.py +78 -9
  96. pyglove/core/{object_utils → utils}/error_utils_test.py +61 -5
  97. pyglove/core/utils/formatting.py +464 -0
  98. pyglove/core/utils/formatting_test.py +453 -0
  99. pyglove/core/{object_utils → utils}/hierarchical.py +23 -25
  100. pyglove/core/{object_utils → utils}/hierarchical_test.py +3 -5
  101. pyglove/core/{object_utils → utils}/json_conversion.py +177 -52
  102. pyglove/core/{object_utils → utils}/json_conversion_test.py +97 -16
  103. pyglove/core/{object_utils → utils}/missing.py +3 -3
  104. pyglove/core/{object_utils → utils}/missing_test.py +2 -4
  105. pyglove/core/utils/text_color.py +128 -0
  106. pyglove/core/utils/text_color_test.py +94 -0
  107. pyglove/core/{object_utils → utils}/thread_local_test.py +1 -3
  108. pyglove/core/utils/timing.py +236 -0
  109. pyglove/core/utils/timing_test.py +154 -0
  110. pyglove/core/{object_utils → utils}/value_location.py +275 -6
  111. pyglove/core/utils/value_location_test.py +707 -0
  112. pyglove/core/views/__init__.py +32 -0
  113. pyglove/core/views/base.py +804 -0
  114. pyglove/core/views/base_test.py +580 -0
  115. pyglove/core/views/html/__init__.py +27 -0
  116. pyglove/core/views/html/base.py +547 -0
  117. pyglove/core/views/html/base_test.py +830 -0
  118. pyglove/core/views/html/controls/__init__.py +35 -0
  119. pyglove/core/views/html/controls/base.py +275 -0
  120. pyglove/core/views/html/controls/label.py +207 -0
  121. pyglove/core/views/html/controls/label_test.py +157 -0
  122. pyglove/core/views/html/controls/progress_bar.py +183 -0
  123. pyglove/core/views/html/controls/progress_bar_test.py +97 -0
  124. pyglove/core/views/html/controls/tab.py +320 -0
  125. pyglove/core/views/html/controls/tab_test.py +87 -0
  126. pyglove/core/views/html/controls/tooltip.py +99 -0
  127. pyglove/core/views/html/controls/tooltip_test.py +99 -0
  128. pyglove/core/views/html/tree_view.py +1517 -0
  129. pyglove/core/views/html/tree_view_test.py +1461 -0
  130. {pyglove-0.4.5.dev20240319.dist-info → pyglove-0.4.5.dev202501140808.dist-info}/METADATA +18 -4
  131. pyglove-0.4.5.dev202501140808.dist-info/RECORD +214 -0
  132. {pyglove-0.4.5.dev20240319.dist-info → pyglove-0.4.5.dev202501140808.dist-info}/WHEEL +1 -1
  133. pyglove/core/object_utils/__init__.py +0 -154
  134. pyglove/core/object_utils/common_traits_test.py +0 -82
  135. pyglove/core/object_utils/formatting.py +0 -234
  136. pyglove/core/object_utils/formatting_test.py +0 -223
  137. pyglove/core/object_utils/value_location_test.py +0 -385
  138. pyglove/core/symbolic/schema_utils.py +0 -327
  139. pyglove/core/symbolic/schema_utils_test.py +0 -57
  140. pyglove/core/typing/class_schema_utils.py +0 -202
  141. pyglove/core/typing/class_schema_utils_test.py +0 -194
  142. pyglove-0.4.5.dev20240319.dist-info/RECORD +0 -185
  143. /pyglove/core/{object_utils → utils}/thread_local.py +0 -0
  144. {pyglove-0.4.5.dev20240319.dist-info → pyglove-0.4.5.dev202501140808.dist-info}/LICENSE +0 -0
  145. {pyglove-0.4.5.dev20240319.dist-info → pyglove-0.4.5.dev202501140808.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,7 @@ import inspect
18
18
  import types
19
19
  import typing
20
20
 
21
- from pyglove.core import object_utils
21
+ from pyglove.core import utils
22
22
  from pyglove.core.typing import annotated
23
23
  from pyglove.core.typing import class_schema
24
24
  from pyglove.core.typing import inspect as pg_inspect
@@ -40,7 +40,8 @@ def _field_from_annotation(
40
40
  annotation: typing.Any,
41
41
  description: typing.Optional[str] = None,
42
42
  metadata: typing.Optional[typing.Dict[str, typing.Any]] = None,
43
- auto_typing=True,
43
+ auto_typing: bool = True,
44
+ parent_module: typing.Optional[types.ModuleType] = None
44
45
  ) -> class_schema.Field:
45
46
  """Creates a field from Python annotation."""
46
47
  if isinstance(annotation, annotated.Annotated):
@@ -55,7 +56,9 @@ def _field_from_annotation(
55
56
  return class_schema.create_field(
56
57
  field_spec,
57
58
  auto_typing=auto_typing,
58
- accept_value_as_annotation=False)
59
+ accept_value_as_annotation=False,
60
+ parent_module=parent_module
61
+ )
59
62
 
60
63
 
61
64
  def _value_spec_from_default_value(
@@ -70,7 +73,7 @@ def _value_spec_from_default_value(
70
73
  elif isinstance(value, tuple):
71
74
  value_spec = vs.Tuple(
72
75
  [_value_spec_from_default_value(elem, False) for elem in value])
73
- elif inspect.isfunction(value) or isinstance(value, object_utils.Functor):
76
+ elif inspect.isfunction(value) or isinstance(value, utils.Functor):
74
77
  value_spec = vs.Callable()
75
78
  elif not isinstance(value, type):
76
79
  value_spec = vs.Object(type(value))
@@ -84,7 +87,9 @@ def _value_spec_from_default_value(
84
87
 
85
88
  def _value_spec_from_type_annotation(
86
89
  annotation: typing.Any,
87
- accept_value_as_annotation: bool) -> class_schema.ValueSpec:
90
+ accept_value_as_annotation: bool,
91
+ parent_module: typing.Optional[types.ModuleType] = None
92
+ ) -> class_schema.ValueSpec:
88
93
  """Creates a value spec from type annotation."""
89
94
  if annotation is bool:
90
95
  return vs.Bool()
@@ -100,11 +105,15 @@ def _value_spec_from_type_annotation(
100
105
  origin = typing.get_origin(annotation) or annotation
101
106
  args = list(typing.get_args(annotation))
102
107
 
108
+ def _sub_value_spec_from_annotation(
109
+ annotation: typing.Any) -> class_schema.ValueSpec:
110
+ return _value_spec_from_type_annotation(
111
+ annotation, accept_value_as_annotation, parent_module)
112
+
103
113
  # Handling list.
104
114
  if origin in (list, typing.List):
105
115
  return vs.List(
106
- _value_spec_from_annotation(
107
- args[0], True)) if args else vs.List(vs.Any())
116
+ _sub_value_spec_from_annotation(args[0])) if args else vs.List(vs.Any())
108
117
  # Handling tuple.
109
118
  elif origin in (tuple, typing.Tuple):
110
119
  if not args:
@@ -115,16 +124,15 @@ def _value_spec_from_type_annotation(
115
124
  raise TypeError(
116
125
  f'Tuple with ellipsis should have exact 2 type arguments. '
117
126
  f'Encountered: {annotation}.')
118
- return vs.Tuple(_value_spec_from_type_annotation(args[0], False))
119
- return vs.Tuple([_value_spec_from_type_annotation(arg, False)
120
- for arg in args])
127
+ return vs.Tuple(_sub_value_spec_from_annotation(args[0]))
128
+ return vs.Tuple([_sub_value_spec_from_annotation(arg) for arg in args])
121
129
  # Handling sequence.
122
130
  elif origin in (collections.abc.Sequence,):
123
- elem = _value_spec_from_annotation(args[0], True) if args else vs.Any()
131
+ elem = _sub_value_spec_from_annotation(args[0]) if args else vs.Any()
124
132
  return vs.Union([vs.List(elem), vs.Tuple(elem)])
125
133
  # Handling literals.
126
134
  elif origin is typing.Literal:
127
- return vs.Enum(object_utils.MISSING_VALUE, args)
135
+ return vs.Enum(utils.MISSING_VALUE, args)
128
136
  # Handling dict.
129
137
  elif origin in (dict, typing.Dict, collections.abc.Mapping):
130
138
  if not args:
@@ -133,8 +141,7 @@ def _value_spec_from_type_annotation(
133
141
  if args[0] not in (str, typing.Text):
134
142
  raise TypeError(
135
143
  'Dict type field with non-string key is not supported.')
136
- elem_value_spec = _value_spec_from_type_annotation(
137
- args[1], accept_value_as_annotation=False)
144
+ elem_value_spec = _sub_value_spec_from_annotation(args[1])
138
145
  return vs.Dict([(ks.StrKey(), elem_value_spec)])
139
146
  elif origin is collections.abc.Callable:
140
147
  arg_specs = []
@@ -149,13 +156,8 @@ def _value_spec_from_type_annotation(
149
156
  # Callable[int, Any] => Callable[[int], Any]
150
157
  # Callable[(int, int), Any] => Callable[[int, int], Any]
151
158
  if isinstance(args[0], list):
152
- arg_specs = [
153
- _value_spec_from_type_annotation(
154
- arg, accept_value_as_annotation=False)
155
- for arg in args[0]
156
- ]
157
- return_spec = _value_spec_from_type_annotation(
158
- args[1], accept_value_as_annotation=False)
159
+ arg_specs = [_sub_value_spec_from_annotation(arg) for arg in args[0]]
160
+ return_spec = _sub_value_spec_from_annotation(args[1])
159
161
  return vs.Callable(arg_specs, returns=return_spec)
160
162
  # Handling type
161
163
  elif origin is type or (annotation in (typing.Type, type)):
@@ -169,20 +171,32 @@ def _value_spec_from_type_annotation(
169
171
  if optional:
170
172
  args.remove(_NoneType)
171
173
  if len(args) == 1:
172
- spec = _value_spec_from_annotation(args[0], True)
174
+ spec = _sub_value_spec_from_annotation(args[0])
173
175
  else:
174
- spec = vs.Union([_value_spec_from_annotation(x, True) for x in set(args)])
176
+ spec = vs.Union([_sub_value_spec_from_annotation(x) for x in args])
175
177
  if optional:
176
178
  spec = spec.noneable()
177
179
  return spec
180
+ elif origin is typing.Final:
181
+ return _value_spec_from_type_annotation(
182
+ args[0],
183
+ accept_value_as_annotation=False
184
+ ).freeze(vs._FROZEN_VALUE_PLACEHOLDER) # pylint: disable=protected-access
178
185
  elif isinstance(annotation, typing.ForwardRef):
179
- return vs.Object(annotation.__forward_arg__)
186
+ annotation = annotation.__forward_arg__
187
+ if parent_module is not None:
188
+ annotation = class_schema.ForwardRef(parent_module, annotation)
189
+ return vs.Object(annotation)
190
+ elif isinstance(annotation, class_schema.ForwardRef):
191
+ return vs.Object(annotation)
180
192
  # Handling class.
181
193
  elif (
182
194
  inspect.isclass(annotation)
183
195
  or pg_inspect.is_generic(annotation)
184
196
  or (isinstance(annotation, str) and not accept_value_as_annotation)
185
197
  ):
198
+ if isinstance(annotation, str) and parent_module is not None:
199
+ annotation = class_schema.ForwardRef(parent_module, annotation)
186
200
  return vs.Object(annotation)
187
201
 
188
202
  if accept_value_as_annotation:
@@ -204,8 +218,9 @@ def _any_spec_with_annotation(annotation: typing.Any) -> vs.Any:
204
218
 
205
219
  def _value_spec_from_annotation(
206
220
  annotation: typing.Any,
207
- auto_typing=False,
208
- accept_value_as_annotation=False
221
+ auto_typing: bool = False,
222
+ accept_value_as_annotation: bool = False,
223
+ parent_module: typing.Optional[types.ModuleType] = None
209
224
  ) -> class_schema.ValueSpec:
210
225
  """Creates a value spec from annotation."""
211
226
  if isinstance(annotation, class_schema.ValueSpec):
@@ -220,7 +235,8 @@ def _value_spec_from_annotation(
220
235
 
221
236
  if auto_typing:
222
237
  return _value_spec_from_type_annotation(
223
- annotation, accept_value_as_annotation)
238
+ annotation, accept_value_as_annotation, parent_module
239
+ )
224
240
  else:
225
241
  value_spec = None
226
242
  if accept_value_as_annotation:
@@ -14,6 +14,7 @@
14
14
  """Tests for pyglove.core.typing.annotation_conversion."""
15
15
 
16
16
  import inspect
17
+ import sys
17
18
  import typing
18
19
  import unittest
19
20
 
@@ -22,6 +23,7 @@ from pyglove.core.typing import annotation_conversion
22
23
  from pyglove.core.typing import key_specs as ks
23
24
  from pyglove.core.typing import value_specs as vs
24
25
  from pyglove.core.typing.class_schema import Field
26
+ from pyglove.core.typing.class_schema import ForwardRef
25
27
  from pyglove.core.typing.class_schema import ValueSpec
26
28
 
27
29
 
@@ -74,6 +76,14 @@ class FieldFromAnnotationTest(unittest.TestCase):
74
76
  Field.from_annotation('x', str, 'A str', dict(x=1), auto_typing=True),
75
77
  Field('x', vs.Str(), 'A str', dict(x=1)))
76
78
 
79
+ def test_with_parent_module(self):
80
+ self.assertEqual(
81
+ Field.from_annotation(
82
+ 'x', 'ValueSpecBase', auto_typing=True, parent_module=vs
83
+ ),
84
+ Field('x', vs.Object(vs.ValueSpecBase))
85
+ )
86
+
77
87
 
78
88
  class ValueSpecFromAnnotationTest(unittest.TestCase):
79
89
  """Tests for ValueSpec.fromAnnotation."""
@@ -248,6 +258,11 @@ class ValueSpecFromAnnotationTest(unittest.TestCase):
248
258
  self.assertEqual(
249
259
  ValueSpec.from_annotation(
250
260
  typing.ForwardRef('Foo'), True), vs.Object(Foo))
261
+ self.assertEqual(
262
+ ValueSpec.from_annotation(
263
+ ForwardRef(sys.modules[__name__], 'Foo'),
264
+ True
265
+ ), vs.Object(Foo))
251
266
 
252
267
  def test_generic_class(self):
253
268
  X = typing.TypeVar('X')
@@ -291,6 +306,14 @@ class ValueSpecFromAnnotationTest(unittest.TestCase):
291
306
  ValueSpec.from_annotation(int | str, True),
292
307
  vs.Union([vs.Int(), vs.Str()]))
293
308
 
309
+ def test_final(self):
310
+ self.assertEqual(
311
+ ValueSpec.from_annotation(
312
+ typing.Final[int], True
313
+ ).set_default(1),
314
+ vs.Int().freeze(1)
315
+ )
316
+
294
317
 
295
318
  if __name__ == '__main__':
296
319
  unittest.main()
@@ -13,11 +13,248 @@
13
13
  # limitations under the License.
14
14
  """Callable extensions."""
15
15
 
16
- from typing import Any, Callable, List
16
+ import contextlib
17
+ import functools
18
+ import inspect
19
+ import types
20
+ from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
17
21
 
22
+ from pyglove.core import utils
18
23
  from pyglove.core.typing import callable_signature
19
24
 
20
25
 
26
+ _TLS_KEY_PRESET_KWARGS = '__preset_kwargs__'
27
+
28
+
29
+ class PresetArgValue(utils.Formattable):
30
+ """Value placeholder for arguments whose value will be provided by presets.
31
+
32
+ Example:
33
+
34
+ def foo(x, y=pg.PresetArgValue(default=1)):
35
+ return x + y
36
+
37
+ with pg.preset_args(dict(y=2)):
38
+ print(foo(x=1)) # 3: y=2
39
+ print(foo(x=1)) # 2: y=1
40
+ """
41
+
42
+ def __init__(self, default: Any = utils.MISSING_VALUE):
43
+ self.default = default
44
+
45
+ @property
46
+ def has_default(self) -> bool:
47
+ return self.default != utils.MISSING_VALUE
48
+
49
+ def __eq__(self, other: Any) -> bool:
50
+ return isinstance(other, PresetArgValue) and (
51
+ self.default == other.default
52
+ )
53
+
54
+ def __ne__(self, other: Any) -> bool:
55
+ return not self.__eq__(other)
56
+
57
+ def format(self, *args, **kwargs):
58
+ return utils.kvlist_str(
59
+ [
60
+ ('default', self.default, utils.MISSING_VALUE),
61
+ ],
62
+ label='PresetArgValue',
63
+ *args,
64
+ **kwargs,
65
+ )
66
+
67
+ @classmethod
68
+ def inspect(
69
+ cls, func: types.FunctionType) -> Dict[str, 'PresetArgValue']:
70
+ """Gets the PresetArgValue specified in a function's signature."""
71
+ assert inspect.isfunction(func), func
72
+ sig = inspect.signature(func)
73
+ preset_arg_markers = {}
74
+ for p in sig.parameters.values():
75
+ if isinstance(p.default, cls):
76
+ preset_arg_markers[p.name] = p.default
77
+ return preset_arg_markers
78
+
79
+ @classmethod
80
+ def resolve_args(
81
+ cls,
82
+ call_args: Tuple[Any, ...],
83
+ call_kwargs: Dict[str, Any],
84
+ positional_arg_names: Sequence[str],
85
+ arg_defaults: Dict[str, Any],
86
+ preset_kwargs: Dict[str, Any],
87
+ include_all_preset_kwargs: bool = False,
88
+ ) -> Tuple[Sequence[Any], Dict[str, Any]]:
89
+ """Resolves calling arguments passed to a method with presets."""
90
+ # Step 1: compute marked kwargs.
91
+ resolved_kwargs = {}
92
+ for arg_name, arg_default in arg_defaults.items():
93
+ if not isinstance(arg_default, PresetArgValue):
94
+ resolved_kwargs[arg_name] = arg_default
95
+ continue
96
+ if arg_name in preset_kwargs:
97
+ resolved_kwargs[arg_name] = preset_kwargs[arg_name]
98
+ elif arg_default.has_default:
99
+ resolved_kwargs[arg_name] = arg_default.default
100
+ else:
101
+ raise ValueError(
102
+ f'Argument {arg_name!r} is not present as a keyword argument '
103
+ 'from the caller.'
104
+ )
105
+
106
+ # Step 2: add preset kwargs
107
+ if include_all_preset_kwargs:
108
+ for k, v in preset_kwargs.items():
109
+ if k not in resolved_kwargs:
110
+ resolved_kwargs[k] = v
111
+
112
+ # Step 3: merge call kwargs with resolved preset kwargs.
113
+ resolved_kwargs.update(call_kwargs)
114
+
115
+ # Step 3: remove resolved kwargs items as it's present in call args.
116
+ for i in range(len(call_args)):
117
+ if i >= len(positional_arg_names):
118
+ break
119
+ resolved_kwargs.pop(positional_arg_names[i], None)
120
+
121
+ # Step 4: convert kwargs back to positional arguments if applicable
122
+ resolved_args = call_args
123
+ if len(positional_arg_names) > len(call_args):
124
+ resolved_args = list(resolved_args)
125
+ for arg_name in positional_arg_names[len(call_args):]:
126
+ if arg_name not in resolved_kwargs:
127
+ break
128
+ arg_value = resolved_kwargs.pop(arg_name)
129
+ resolved_args.append(arg_value)
130
+ return resolved_args, resolved_kwargs
131
+
132
+
133
+ class _ArgPresets:
134
+ """Preset argument collection."""
135
+
136
+ def __init__(self, presets: Optional[Dict[str, Dict[str, Any]]] = None):
137
+ self._presets: Dict[str, Dict[str, Any]] = presets or {}
138
+
139
+ def derive(
140
+ self,
141
+ kwargs: Dict[str, Any],
142
+ preset_name: str = 'global',
143
+ inherit_preset: Union[str, bool] = False
144
+ ) -> '_ArgPresets':
145
+ """Derives new presets from current presets."""
146
+ presets = self._presets.copy() # Just do a shallow copy.
147
+ if isinstance(inherit_preset, bool) and inherit_preset:
148
+ inherit_preset = preset_name
149
+
150
+ if inherit_preset and inherit_preset in presets:
151
+ current_preset = presets[inherit_preset].copy()
152
+ current_preset.update(kwargs)
153
+ else:
154
+ current_preset = kwargs
155
+ presets[preset_name] = current_preset
156
+ return _ArgPresets(presets)
157
+
158
+ def get_preset(self, preset_name: str) -> Dict[str, Any]:
159
+ return self._presets.get(preset_name, {})
160
+
161
+
162
+ @contextlib.contextmanager
163
+ def preset_args(
164
+ kwargs: Dict[str, Any],
165
+ *,
166
+ preset_name: str = 'global',
167
+ inherit_preset: Union[str, bool] = False
168
+ ) -> Iterator[Dict[str, Any]]:
169
+ """Context manager to enable calling with user kwargs.
170
+
171
+ Args:
172
+ kwargs: The preset kwargs to be used by preset-enabled functions within
173
+ the context.
174
+ preset_name: The name of the preset to specify kwargs.
175
+ `enable_preset_args` allows users to pass a preset name, which will be
176
+ used to identify the preset to be used.
177
+ inherit_preset: The name of the preset defined by the parent context to
178
+ inherit kwargs from. Or a boolean to indicate whether to inherit a
179
+ parent preset of the same name.
180
+
181
+ Yields:
182
+ Current preset kwargs.
183
+ """
184
+
185
+ parent_presets = utils.thread_local_peek(
186
+ _TLS_KEY_PRESET_KWARGS, _ArgPresets()
187
+ )
188
+ current_preset = parent_presets.derive(kwargs, preset_name, inherit_preset)
189
+ utils.thread_local_push(_TLS_KEY_PRESET_KWARGS, current_preset)
190
+ try:
191
+ yield current_preset
192
+ finally:
193
+ utils.thread_local_pop(_TLS_KEY_PRESET_KWARGS, None)
194
+
195
+
196
+ def enable_preset_args(
197
+ include_all_preset_kwargs: bool = False,
198
+ preset_name: str = 'global'
199
+ ) -> Callable[[types.FunctionType], types.FunctionType]:
200
+ """Decorator for functions that maybe use preset argument values.
201
+
202
+ Usage::
203
+
204
+ @pg.typing.enable_preset_args()
205
+ def foo(x, y=pg.typing.PresetArgValue(default=1)):
206
+ return x + y
207
+
208
+ with pg.typing.preset_args(dict(y=2)):
209
+ print(foo(x=1)) # 3: y=2
210
+ print(foo(x=1)) # 2: y=1
211
+
212
+ Args:
213
+ include_all_preset_kwargs: Whether to include all preset kwargs (even
214
+ not marked as `PresetArgValue`) when calling the function.
215
+ preset_name: The name of the preset to specify kwargs.
216
+
217
+ Returns:
218
+ A decorated function that could consume the preset argument values.
219
+ """
220
+ def decorator(func):
221
+ sig = inspect.signature(func)
222
+ positional_arg_names = [
223
+ p.name for p in sig.parameters.values()
224
+ if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
225
+ ]
226
+
227
+ arg_defaults = {}
228
+ has_preset_value = False
229
+ has_varkw = False
230
+ for p in sig.parameters.values():
231
+ if p.kind == inspect.Parameter.VAR_KEYWORD:
232
+ has_varkw = True
233
+ continue
234
+ if p.kind == inspect.Parameter.VAR_POSITIONAL:
235
+ continue
236
+ if p.default == inspect.Parameter.empty:
237
+ continue
238
+ if isinstance(p.default, PresetArgValue):
239
+ has_preset_value = True
240
+ arg_defaults[p.name] = p.default
241
+
242
+ if has_preset_value:
243
+ @functools.wraps(func)
244
+ def _func(*args, **kwargs):
245
+ # Map positional arguments to keyword arguments.
246
+ presets = utils.thread_local_peek(_TLS_KEY_PRESET_KWARGS, None)
247
+ preset_kwargs = presets.get_preset(preset_name) if presets else {}
248
+ args, kwargs = PresetArgValue.resolve_args(
249
+ args, kwargs, positional_arg_names, arg_defaults, preset_kwargs,
250
+ include_all_preset_kwargs=include_all_preset_kwargs and has_varkw
251
+ )
252
+ return func(*args, **kwargs)
253
+ return _func
254
+ return func
255
+ return decorator
256
+
257
+
21
258
  class CallableWithOptionalKeywordArgs:
22
259
  """Helper class for invoking callable objects with optional keyword args.
23
260
 
@@ -31,8 +268,9 @@ class CallableWithOptionalKeywordArgs:
31
268
  def __init__(self,
32
269
  func: Callable[..., Any],
33
270
  optional_keywords: List[str]):
34
- sig = callable_signature.get_signature(func)
35
- absent_keywords = None
271
+ sig = callable_signature.signature(
272
+ func, auto_typing=False, auto_doc=False
273
+ )
36
274
 
37
275
  # Check for variable keyword arguments.
38
276
  if sig.has_varkw: