langfun 0.1.2.dev202501150804__py3-none-any.whl → 0.1.2.dev202501160804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langfun/core/__init__.py CHANGED
@@ -75,7 +75,6 @@ from langfun.core.sampling import random_sample
75
75
  from langfun.core.concurrent import RetryEntry
76
76
  from langfun.core.concurrent import concurrent_execute
77
77
  from langfun.core.concurrent import concurrent_map
78
- from langfun.core.concurrent import with_context_access
79
78
  from langfun.core.concurrent import with_retry
80
79
 
81
80
  # Interface for natural language formattable.
langfun/core/component.py CHANGED
@@ -13,10 +13,7 @@
13
13
  # limitations under the License.
14
14
  """langfun Component."""
15
15
 
16
- import contextlib
17
- import dataclasses
18
- import threading
19
- from typing import Annotated, Any, ContextManager, Iterator, Type
16
+ from typing import ContextManager
20
17
  import pyglove as pg
21
18
 
22
19
 
@@ -24,22 +21,9 @@ import pyglove as pg
24
21
  RAISE_IF_HAS_ERROR = (pg.MISSING_VALUE,)
25
22
 
26
23
 
27
- class Component(pg.Object):
24
+ class Component(pg.ContextualObject):
28
25
  """Base class for langfun components."""
29
26
 
30
- # Override __repr__ format to use inferred values when available.
31
- __repr_format_kwargs__ = dict(
32
- compact=True,
33
- use_inferred=True,
34
- )
35
-
36
- # Override __str__ format to use inferred values when available.
37
- __str_format_kwargs__ = dict(
38
- compact=False,
39
- verbose=False,
40
- use_inferred=True,
41
- )
42
-
43
27
  # Allow symbolic assignment, which invalidates the object and recomputes
44
28
  # states upon update.
45
29
  allow_symbolic_assignment = True
@@ -75,122 +59,23 @@ class Component(pg.Object):
75
59
  if additional_fields:
76
60
  cls.update_schema(additional_fields)
77
61
 
78
- def _on_bound(self):
79
- super()._on_bound()
80
- self._tls = threading.local()
81
-
82
- def _sym_inferred(self, key: str, **kwargs):
83
- """Override to allow attribute to access scoped value.
84
-
85
- Args:
86
- key: attribute name.
87
- **kwargs: Optional keyword arguments for value inference.
88
-
89
- Returns:
90
- The value of the symbolic attribute. If not available, returns the
91
- default value.
92
-
93
- Raises:
94
- AttributeError: If the attribute does not exist or contextual attribute
95
- is not ready.
96
- """
97
- if key not in self._sym_attributes:
98
- raise AttributeError(key)
99
-
100
- # Step 1: Try use value from `self.override`.
101
- # The reason is that `self.override` is short-lived and explicitly specified
102
- # by the user in scenarios like `LangFunc.render`, which should not be
103
- # affected by `lf.context`.
104
- v = _get_scoped_value(self._tls, _CONTEXT_OVERRIDES, key)
105
- if v is not None:
106
- return v.value
107
-
108
- # Step 2: Try use value from `lf.context` with `override_attrs`.
109
- # This gives users a chance to override the bound attributes of components
110
- # from the top, allowing change of bindings without modifying the code
111
- # that produces the components.
112
- override = get_contextual_override(key)
113
- if override and override.override_attrs:
114
- return override.value
115
-
116
- # Step 3: Try use value from the symbolic tree, starting from self to
117
- # the root of the tree.
118
- # Step 4: If the value is not present, use the value from `context()` (
119
- # override_attrs=False).
120
- # Step 5: Otherwise use the default value from `ContextualAttribute`.
121
- return super()._sym_inferred(key, context_override=override, **kwargs)
122
-
123
- def override(
124
- self, **kwargs) -> ContextManager[dict[str, 'ContextualOverride']]:
125
- """Context manager to override the attributes of this component."""
126
- vs = {k: ContextualOverride(v) for k, v in kwargs.items()}
127
- return _contextual_scope(self._tls, _CONTEXT_OVERRIDES, **vs)
128
-
129
- def __getattribute__(self, name: str) -> Any:
130
- """Override __getattribute__ to deal with class attribute override."""
131
- if not name.startswith('_') and hasattr(self.__class__, name):
132
- tls = self.__dict__.get('_tls', None)
133
- if tls is not None:
134
- v = _get_scoped_value(tls, _CONTEXT_OVERRIDES, name)
135
- if v is not None:
136
- return v.value
137
- return super().__getattribute__(name)
138
-
139
-
140
- _global_tls = threading.local()
141
-
142
- _CONTEXT_OVERRIDES = 'context_overrides'
143
-
144
62
 
145
- @dataclasses.dataclass(frozen=True)
146
- class ContextualOverride:
147
- """Value marker for contextual override for an attribute."""
63
+ # Aliases from PyGlove for ease of access.
64
+ context = pg.contextual_override
65
+ get_contextual_override = pg.utils.get_contextual_override
66
+ context_value = pg.utils.contextual_value
67
+ all_contextual_values = pg.utils.all_contextual_values
68
+ contextual = pg.contextual_attribute
148
69
 
149
- # Overridden value.
150
- value: Any
151
-
152
- # If True, this override will apply to both current scope and nested scope,
153
- # meaning current `lf.context` will take precedence over all nested
154
- # `lf.context` on this attribute.
155
- cascade: bool = False
156
-
157
- # If True, this override will apply to attributes that already have values.
158
- override_attrs: bool = False
159
-
160
-
161
- def context(
162
- *,
163
- cascade: bool = False,
164
- override_attrs: bool = False,
165
- **variables,
166
- ) -> ContextManager[dict[str, ContextualOverride]]:
167
- """Context manager to provide overriden values for contextual attributes.
168
-
169
- Args:
170
- cascade: If True, this override will apply to both current scope and nested
171
- scope, meaning that this `lf.context` will take precedence over all
172
- nested `lf.context` on the overriden variables.
173
- override_attrs: If True, this override will apply to attributes that already
174
- have values. Otherwise overridden variables will only be used for
175
- contextual attributes whose values are not present.
176
- **variables: Key/values as override for contextual attributes.
177
-
178
- Returns:
179
- A dict of attribute names to their contextual overrides.
180
- """
181
- vs = {}
182
- for k, v in variables.items():
183
- if not isinstance(v, ContextualOverride):
184
- v = ContextualOverride(v, cascade, override_attrs)
185
- vs[k] = v
186
- return _contextual_scope(_global_tls, _CONTEXT_OVERRIDES, **vs)
70
+ # Decorator for setting the positional arguments for Component.
71
+ use_init_args = pg.use_init_args
187
72
 
188
73
 
189
74
  def use_settings(
190
75
  *,
191
76
  cascade: bool = False,
192
77
  **settings,
193
- ) -> ContextManager[dict[str, ContextualOverride]]:
78
+ ) -> ContextManager[dict[str, pg.utils.ContextualOverride]]:
194
79
  """Shortcut method for overriding component attributes.
195
80
 
196
81
  Args:
@@ -203,150 +88,3 @@ def use_settings(
203
88
  A dict of attribute names to their contextual overrides.
204
89
  """
205
90
  return context(cascade=cascade, override_attrs=True, **settings)
206
-
207
-
208
- def get_contextual_override(var_name: str) -> ContextualOverride | None:
209
- """Returns the overriden contextual value in current scope."""
210
- return _get_scoped_value(_global_tls, _CONTEXT_OVERRIDES, var_name)
211
-
212
-
213
- def context_value(var_name: str, default: Any = RAISE_IF_HAS_ERROR) -> Any:
214
- """Returns the value of a variable defined in `lf.context`."""
215
- override = get_contextual_override(var_name)
216
- if override is None:
217
- if default == RAISE_IF_HAS_ERROR:
218
- raise KeyError(f'{var_name!r} does not exist in current context.')
219
- return default
220
- return override.value
221
-
222
-
223
- def all_contextual_values() -> dict[str, Any]:
224
- """Returns all contextual values provided from `lf.context` in scope."""
225
- overrides = getattr(_global_tls, _CONTEXT_OVERRIDES, {})
226
- return {k: v.value for k, v in overrides.items()}
227
-
228
-
229
- @contextlib.contextmanager
230
- def _contextual_scope(
231
- tls: threading.local, tls_key, **variables
232
- ) -> Iterator[dict[str, ContextualOverride]]:
233
- """Context manager to set variables within a scope."""
234
- previous_values = getattr(tls, tls_key, {})
235
- current_values = dict(previous_values)
236
- for k, v in variables.items():
237
- old_v = current_values.get(k, None)
238
- if old_v and old_v.cascade:
239
- v = old_v
240
- current_values[k] = v
241
- try:
242
- setattr(tls, tls_key, current_values)
243
- yield current_values
244
- finally:
245
- setattr(tls, tls_key, previous_values)
246
-
247
-
248
- def _get_scoped_value(
249
- tls: threading.local, tls_key: str, var_name: str, default: Any = None
250
- ) -> ContextualOverride:
251
- """Gets the value for requested variable from current scope."""
252
- scoped_values = getattr(tls, tls_key, {})
253
- return scoped_values.get(var_name, default)
254
-
255
-
256
- class ContextualAttribute(
257
- pg.symbolic.ValueFromParentChain, pg.views.HtmlTreeView.Extension
258
- ):
259
- """Attributes whose values are inferred from the context of the component.
260
-
261
- Please see go/langfun-component#attribute-value-retrieval for details.
262
- """
263
-
264
- NO_DEFAULT = (pg.MISSING_VALUE,)
265
-
266
- type: Annotated[Type[Any] | None, 'An optional type constraint.'] = None
267
-
268
- default: Any = NO_DEFAULT
269
-
270
- def value_from(
271
- self,
272
- parent,
273
- *,
274
- context_override: ContextualOverride | None = None,
275
- **kwargs,
276
- ):
277
- if parent not in (None, self.sym_parent) and isinstance(parent, Component):
278
- # Apply original search logic along the component containing chain.
279
- return super().value_from(parent, **kwargs)
280
- elif parent is None:
281
- # When there is no value inferred from the symbolic tree.
282
- # Search context override, and then attribute-level default.
283
- if context_override:
284
- return context_override.value
285
- if self.default == ContextualAttribute.NO_DEFAULT:
286
- return pg.MISSING_VALUE
287
- return self.default
288
- else:
289
- return pg.MISSING_VALUE
290
-
291
- def _html_tree_view_content(
292
- self,
293
- *,
294
- view: pg.views.HtmlTreeView,
295
- parent: Any = None,
296
- root_path: pg.KeyPath | None = None,
297
- **kwargs,
298
- ) -> pg.Html:
299
- inferred_value = pg.MISSING_VALUE
300
- if isinstance(parent, pg.Symbolic) and root_path:
301
- inferred_value = parent.sym_inferred(root_path.key, pg.MISSING_VALUE)
302
-
303
- if inferred_value is not pg.MISSING_VALUE:
304
- kwargs.pop('name', None)
305
- return view.render(
306
- inferred_value, parent=self,
307
- root_path=pg.KeyPath('<inferred>', root_path),
308
- **view.get_passthrough_kwargs(**kwargs)
309
- )
310
- return pg.Html.element(
311
- 'div',
312
- [
313
- '(not available)',
314
- ],
315
- css_classes=['unavailable-contextual'],
316
- )
317
-
318
- def _html_tree_view_config(self) -> dict[str, Any]:
319
- return pg.views.HtmlTreeView.get_kwargs(
320
- super()._html_tree_view_config(),
321
- dict(
322
- collapse_level=1,
323
- )
324
- )
325
-
326
- @classmethod
327
- def _html_tree_view_css_styles(cls) -> list[str]:
328
- return super()._html_tree_view_css_styles() + [
329
- """
330
- .contextual-attribute {
331
- color: purple;
332
- }
333
- .unavailable-contextual {
334
- color: gray;
335
- font-style: italic;
336
- }
337
- """
338
- ]
339
-
340
-
341
- # NOTE(daiyip): Returning Any instead of `lf.ContextualAttribute` to avoid
342
- # pytype check error as `contextual()` can be assigned to any type.
343
- def contextual(
344
- type: Type[Any] | None = None, # pylint: disable=redefined-builtin
345
- default: Any = ContextualAttribute.NO_DEFAULT,
346
- ) -> Any:
347
- """Value marker for a contextual attribute."""
348
- return ContextualAttribute(type=type, default=default, allow_partial=True)
349
-
350
-
351
- # Decorator for setting the positional arguments for Component.
352
- use_init_args = pg.use_init_args
langfun/core/component_test.py CHANGED
@@ -73,26 +73,7 @@ class ComponentContextTest(unittest.TestCase):
73
73
  self.assertEqual(a1.y, 2)
74
74
  self.assertEqual(a1.z, -1)
75
75
 
76
- with lf.context(x=3, y=3, z=3) as parent_override:
77
- self.assertEqual(
78
- parent_override,
79
- dict(
80
- x=lf.ContextualOverride(3, cascade=False, override_attrs=False),
81
- y=lf.ContextualOverride(3, cascade=False, override_attrs=False),
82
- z=lf.ContextualOverride(3, cascade=False, override_attrs=False),
83
- ),
84
- )
85
- self.assertEqual(
86
- lf.get_contextual_override('y'),
87
- lf.ContextualOverride(3, cascade=False, override_attrs=False),
88
- )
89
- self.assertEqual(lf.context_value('x'), 3)
90
- self.assertIsNone(lf.context_value('f', None))
91
- with self.assertRaisesRegex(KeyError, '.* does not exist'):
92
- lf.context_value('f')
93
-
94
- self.assertEqual(lf.all_contextual_values(), dict(x=3, y=3, z=3))
95
-
76
+ with lf.context(x=3, y=3, z=3):
96
77
  # Member attributes take precedence over `lf.context`.
97
78
  self.assertEqual(a1.x, 1)
98
79
  self.assertEqual(a1.y, 2)
@@ -109,15 +90,7 @@ class ComponentContextTest(unittest.TestCase):
109
90
  self.assertEqual(a1.z, 3)
110
91
 
111
92
  # Test nested contextual override with override_attrs=True (default).
112
- with lf.context(y=4, z=4, override_attrs=True) as nested_override:
113
- self.assertEqual(
114
- nested_override,
115
- dict(
116
- x=lf.ContextualOverride(3, cascade=False, override_attrs=False),
117
- y=lf.ContextualOverride(4, cascade=False, override_attrs=True),
118
- z=lf.ContextualOverride(4, cascade=False, override_attrs=True),
119
- ),
120
- )
93
+ with lf.context(y=4, z=4, override_attrs=True):
121
94
 
122
95
  # Member attribute is not overriden as current scope does not override
123
96
  # `x``.
langfun/core/concurrent.py CHANGED
@@ -25,7 +25,6 @@ import threading
25
25
  import time
26
26
  from typing import Annotated, Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union
27
27
 
28
- from langfun.core import component
29
28
  import pyglove as pg
30
29
 
31
30
 
@@ -39,18 +38,6 @@ except ImportError:
39
38
  tqdm = None
40
39
 
41
40
 
42
- def with_context_access(func: Callable[..., Any]) -> Callable[..., Any]:
43
- """Derives a user function with the access to the current context."""
44
- with component.context() as current_context:
45
- pass
46
-
47
- def _func(*args, **kwargs) -> Any:
48
- with component.context(**current_context):
49
- return func(*args, **kwargs)
50
-
51
- return _func
52
-
53
-
54
41
  class RetryError(RuntimeError):
55
42
  """Retry error."""
56
43
 
@@ -249,7 +236,8 @@ def concurrent_execute(
249
236
  try:
250
237
  executed_jobs = list(
251
238
  executor.map(
252
- lambda job: job(), [with_context_access(job) for job in jobs]
239
+ lambda job: job(),
240
+ [pg.with_contextual_override(job) for job in jobs]
253
241
  )
254
242
  )
255
243
  for job in executed_jobs:
@@ -736,9 +724,7 @@ def concurrent_map(
736
724
  retry_interval=retry_interval,
737
725
  exponential_backoff=exponential_backoff,
738
726
  )
739
- future = executor.submit(
740
- with_context_access(job),
741
- )
727
+ future = executor.submit(pg.with_contextual_override(job))
742
728
  pending_futures.append(future)
743
729
  future_to_job[future] = job
744
730
  total += 1
langfun/core/concurrent_test.py CHANGED
@@ -29,22 +29,6 @@ class A(component.Component):
29
29
  y: int = component.contextual()
30
30
 
31
31
 
32
- class WithContextAccessTest(unittest.TestCase):
33
-
34
- def test_context_access(self):
35
- inputs = [A(1), A(2)]
36
- with futures.ThreadPoolExecutor() as executor:
37
- with component.context(y=3):
38
- self.assertEqual(
39
- list(
40
- executor.map(
41
- concurrent.with_context_access(lambda x: x.y), inputs
42
- )
43
- ),
44
- [3, 3],
45
- )
46
-
47
-
48
32
  class RetryErrorTest(unittest.TestCase):
49
33
 
50
34
  def test_basics(self):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: langfun
3
- Version: 0.1.2.dev202501150804
3
+ Version: 0.1.2.dev202501160804
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -1,9 +1,9 @@
1
1
  langfun/__init__.py,sha256=fhfPXpHN7GoGqixpFfqhQkYxFs_siP_LhbjZhd3lhio,2497
2
- langfun/core/__init__.py,sha256=oo81OUA4Xy2q7icxcxAMicVzbla5gMDf9Nzw_iIAkis,4627
3
- langfun/core/component.py,sha256=HVrEoTL1Y01iqOHC3FYdbAOnffqfHHtGJXoK1vkdEwo,11583
4
- langfun/core/component_test.py,sha256=sG-T2wpvBfHqWGZE7sc4NayJj2aj5QFBzSwFiwrGEIc,10376
5
- langfun/core/concurrent.py,sha256=4CqOlH9YrtrK-Jijp09hLX7AYxRYL8slpy-tL2aqnIM,33121
6
- langfun/core/concurrent_test.py,sha256=0ucEYxPgTErOOK1TIN-TSqZphHS1Vl1VssyQP49cpOI,17997
2
+ langfun/core/__init__.py,sha256=Cy3BU8R9VHnJmbxiN9QXXBb1K09ZoESYIJGf_uwJrDs,4571
3
+ langfun/core/component.py,sha256=g1kQM0bryYYYWVDrSMnHfc74wIBbpfe5_B3s-UIP5GE,3028
4
+ langfun/core/component_test.py,sha256=0CxTgjAud3aj8wBauFhG2FHDqrxCTl4OI4gzQTad-40,9254
5
+ langfun/core/concurrent.py,sha256=zY-pXqlGqss_GI20tM1gXvyW8QepVPUuFNmutcIdhbI,32760
6
+ langfun/core/concurrent_test.py,sha256=rc5T-2giWgtbwNuN6gmei7Uwo66HsJeeRtXZCpya_QU,17590
7
7
  langfun/core/console.py,sha256=V_mOiFi9oGh8gLsUeR56pdFDkuvYOpvQt7DY1KUTWTA,2535
8
8
  langfun/core/console_test.py,sha256=pBOcuNMJdVELywvroptfcRtJMsegMm3wSlHAL2TdxVk,1679
9
9
  langfun/core/langfunc.py,sha256=G50YgoVZ0y1GFw2ev41MlOqr6qa8YakbvNC0h_E0PiA,11140
@@ -146,8 +146,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
146
146
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
147
147
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
148
148
  langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
149
- langfun-0.1.2.dev202501150804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
150
- langfun-0.1.2.dev202501150804.dist-info/METADATA,sha256=HAqG3RU6S4kwv8Knx7vDVp01x8Kuc7It1QnYg0oeol8,8172
151
- langfun-0.1.2.dev202501150804.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
152
- langfun-0.1.2.dev202501150804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
153
- langfun-0.1.2.dev202501150804.dist-info/RECORD,,
149
+ langfun-0.1.2.dev202501160804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
150
+ langfun-0.1.2.dev202501160804.dist-info/METADATA,sha256=_XM3ancZIb8-33gpRxLKmdJOBZsMfd1_2-4otzha19Q,8172
151
+ langfun-0.1.2.dev202501160804.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
152
+ langfun-0.1.2.dev202501160804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
153
+ langfun-0.1.2.dev202501160804.dist-info/RECORD,,