langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -13,205 +13,120 @@
13
13
  # limitations under the License.
14
14
  """Python code parsing."""
15
15
 
16
- import ast
17
16
  import inspect
18
17
  import io
19
18
  import re
20
19
 
21
- import langfun.core as lf
22
- from langfun.core.coding.python import errors
23
- from langfun.core.coding.python import permissions
24
-
25
-
26
- class PythonCodeParser(lf.Component):
27
- """Python code parser with permission control."""
28
-
29
- _ID_REGEX = re.compile('^[a-zA-Z_\\-]*$')
30
-
31
- class _CodeValidator(ast.NodeVisitor):
32
- """Python AST node visitor for ensuring code are permitted."""
33
-
34
- def __init__(self, code: str, permission: permissions.CodePermission):
35
- super().__init__()
36
- self.code = code
37
- self.permission = permission
38
-
39
- def verify(
40
- self,
41
- node,
42
- flag: permissions.CodePermission,
43
- node_type,
44
- error_message: str,
45
- ) -> None:
46
- if isinstance(node, node_type) and not (self.permission & flag):
47
- raise SyntaxError(
48
- error_message, (
49
- '<generated-code>',
50
- node.lineno,
51
- node.col_offset,
52
- self._code_line(node.lineno),
53
- node.end_lineno,
54
- node.end_col_offset,
55
- ))
56
-
57
- def _code_line(self, lineno):
58
- return self.code.split('\n')[lineno - 1]
59
-
60
- def generic_visit(self, node):
61
- self.verify(
62
- node,
63
- permissions.CodePermission.CONDITION,
64
- (ast.If, ast.Match),
65
- 'Condition is not allowed.',
66
- )
67
-
68
- self.verify(
69
- node,
70
- permissions.CodePermission.LOOP,
71
- (ast.For, ast.While, ast.AsyncFor, ast.AsyncWith),
72
- 'Loop is not allowed.',
73
- )
74
-
75
- self.verify(
76
- node,
77
- permissions.CodePermission.EXCEPTION,
78
- (ast.Try, ast.Raise, ast.Assert),
79
- 'Exception is not allowed.',
80
- )
81
-
82
- self.verify(
83
- node,
84
- permissions.CodePermission.CLASS_DEFINITION,
85
- ast.ClassDef,
86
- 'Class definition is not allowed.',
87
- )
88
-
89
- self.verify(
90
- node,
91
- permissions.CodePermission.FUNCTION_DEFINITION,
92
- (
93
- ast.FunctionDef,
94
- ast.AsyncFunctionDef,
95
- ast.Return,
96
- ast.Yield,
97
- ast.YieldFrom,
98
- ),
99
- 'Function definition is not allowed.',
100
- )
101
-
102
- self.verify(
103
- node,
104
- permissions.CodePermission.IMPORT,
105
- (ast.Import, ast.ImportFrom),
106
- '`import` is not allowed.',
107
- )
108
-
109
- super().generic_visit(node)
110
-
111
- def parse(
112
- self, code: str, permission: permissions.CodePermission
113
- ) -> tuple[str, ast.AST]:
114
- code = self.clean(code)
115
- try:
116
- parsed_code = ast.parse(code, mode='exec')
117
- PythonCodeParser._CodeValidator(code, permission).visit(parsed_code)
118
- except SyntaxError as e:
119
- raise errors.CodeError(code, e) from e
120
- return code, parsed_code
121
-
122
- def clean(self, code_text: str) -> str:
123
- # TODO(daiyip): Deal with markdown in docstrings.
124
- code = io.StringIO()
125
- quote_char = None
126
- in_code = False
127
- i = 0
128
- in_comment = False
129
- while i < len(code_text):
130
- c = code_text[i]
131
- # Detect code block separator (```).
132
- if (not in_comment
133
- and quote_char is None
134
- and c == '`'
135
- and code_text[i:i + 3] == '```'):
136
- in_code = not in_code
137
- if in_code:
138
- i += 3
139
- continue
140
- else:
141
- break
142
-
143
- # Detect string literal boundary.
144
- if (in_code
145
- and not in_comment
146
- and c in ('\'', '"')
147
- and i > 0
148
- and code_text[i - 1] != '\\'):
149
- # Handle ''' and """.
150
- if code_text[i: i + 3] == c * 3:
151
- c = c * 3
152
- i += 2
153
-
154
- if quote_char is None:
155
- quote_char = c
156
- elif quote_char == c:
157
- # NOTE(daiyip): at times, LM forgets to escape quotes inside a string.
158
- # Thus we do some smart checking here to automatically correct such
159
- # case. This logic here is pretty involved in handling special cases.
160
- # We might want to revisit them later.
161
-
162
- # Peek forward to see if it could be a valid string.
163
- nt, nnt_start = _next_token(code_text, i + 1)
164
- if nt in (',', '[', ']', '}', ')', '+', '*', '%', '\n', ':'):
165
- end_quote = True
166
- elif nt == ' ':
167
- # Detect if . could be a method invocation.
168
- # NOTE(daiyip): 'in' and 'not in' might have false positives. But
169
- # given the chance is low, we do not complicate the reasoning logic
170
- # for now.
171
- nnt, _ = _next_token(code_text, nnt_start, skip_whitespace=True)
172
- end_quote = nnt in ('+', '*', '%', '#', '[', 'in', 'not', ':')
173
- elif nt == '.':
174
- # Detect if . could be method invocation on string.
175
- nnt, nnnt_start = _next_token(code_text, nnt_start)
176
- nnnt, _ = _next_token(code_text, nnnt_start)
177
- end_quote = nnt.isidentifier() and nnnt == '('
178
- else:
179
- end_quote = False
180
-
181
- if end_quote:
182
- quote_char = None
183
- else:
184
- c = f'\\{c}'
185
- # Detect comment.
186
- elif c == '#' and quote_char is None:
187
- in_comment = True
188
- # Detect end-of-comment.
189
- elif c == '\n':
190
- # NOTE(daiyip): deal with cases that LM forgot to escape linebreaks
191
- # within strings.
192
- if quote_char is not None:
193
- # Only add \\ for ' and " (other than ''' and """).
194
- if len(quote_char) == 1:
195
- c = '\\n'
196
- else:
197
- in_comment = False
198
20
 
21
+ _ID_REGEX = re.compile('^[a-zA-Z_\\-]*$')
22
+
23
+
24
+ def clean(code_text: str) -> str:
25
+ """Cleans up Python code.
26
+
27
+ LLM may generate code with markdown annotations, as well as minor syntax
28
+ errors. This function removes such annotations and fixes minor syntax errors
29
+ without extra LLM calls.
30
+
31
+ Args:
32
+ code_text: The code text to clean up.
33
+
34
+ Returns:
35
+ The cleaned up code text.
36
+ """
37
+ # TODO(daiyip): Deal with markdown in docstrings.
38
+ code = io.StringIO()
39
+ quote_char = None
40
+ in_code = False
41
+ i = 0
42
+ in_comment = False
43
+ while i < len(code_text):
44
+ c = code_text[i]
45
+ # Detect code block separator (```).
46
+ if (not in_comment
47
+ and quote_char is None
48
+ and c == '`'
49
+ and code_text[i:i + 3] == '```'):
50
+ in_code = not in_code
199
51
  if in_code:
200
- code.write(c)
201
-
202
- i += 1
203
-
204
- code = code.getvalue()
205
- if code:
206
- pos = code.find('\n')
207
- # Strip markdown code type. E.g. ```python
208
- if pos > 0 and self._ID_REGEX.match(code[:pos]):
209
- code = code[pos:]
210
- else:
211
- # Maybe-code that resides not within a code markdown block.
212
- # Adding '\n' makes inspect.cleandoc to make right adjustment.
213
- code = '\n' + code_text
214
- return inspect.cleandoc(code).strip()
52
+ i += 3
53
+ continue
54
+ else:
55
+ break
56
+
57
+ # Detect string literal boundary.
58
+ if (in_code
59
+ and not in_comment
60
+ and c in ('\'', '"')
61
+ and i > 0
62
+ and code_text[i - 1] != '\\'):
63
+ # Handle ''' and """.
64
+ if code_text[i: i + 3] == c * 3:
65
+ c = c * 3
66
+ i += 2
67
+
68
+ if quote_char is None:
69
+ quote_char = c
70
+ elif quote_char == c:
71
+ # NOTE(daiyip): at times, LM forgets to escape quotes inside a string.
72
+ # Thus we do some smart checking here to automatically correct such
73
+ # case. This logic here is pretty involved in handling special cases.
74
+ # We might want to revisit them later.
75
+
76
+ # Peek forward to see if it could be a valid string.
77
+ nt, nnt_start = _next_token(code_text, i + 1)
78
+ if (len(c) == 3
79
+ or nt in (',', '[', ']', '}', ')', '+', '*', '%', '\n', ':')):
80
+ end_quote = True
81
+ elif nt == ' ':
82
+ # Detect if . could be a method invocation.
83
+ # NOTE(daiyip): 'in' and 'not in' might have false positives. But
84
+ # given the chance is low, we do not complicate the reasoning logic
85
+ # for now.
86
+ nnt, _ = _next_token(code_text, nnt_start, skip_whitespace=True)
87
+ end_quote = nnt in ('+', '*', '%', '#', '[', 'in', 'not', ':')
88
+ elif nt == '.':
89
+ # Detect if . could be method invocation on string.
90
+ nnt, nnnt_start = _next_token(code_text, nnt_start)
91
+ nnnt, _ = _next_token(code_text, nnnt_start)
92
+ end_quote = nnt.isidentifier() and nnnt == '('
93
+ else:
94
+ end_quote = False
95
+
96
+ if end_quote:
97
+ quote_char = None
98
+ else:
99
+ c = f'\\{c}'
100
+ # Detect comment.
101
+ elif c == '#' and quote_char is None:
102
+ in_comment = True
103
+ # Detect end-of-comment.
104
+ elif c == '\n':
105
+ # NOTE(daiyip): deal with cases that LM forgot to escape linebreaks
106
+ # within strings.
107
+ if quote_char is not None:
108
+ # Only add \\ for ' and " (other than ''' and """).
109
+ if len(quote_char) == 1:
110
+ c = '\\n'
111
+ else:
112
+ in_comment = False
113
+
114
+ if in_code:
115
+ code.write(c)
116
+
117
+ i += 1
118
+
119
+ code = code.getvalue()
120
+ if code:
121
+ pos = code.find('\n')
122
+ # Strip markdown code type. E.g. ```python
123
+ if pos > 0 and _ID_REGEX.match(code[:pos]):
124
+ code = code[pos:]
125
+ else:
126
+ # Maybe-code that resides not within a code markdown block.
127
+ # Adding '\n' makes inspect.cleandoc to make right adjustment.
128
+ code = '\n' + code_text
129
+ return inspect.cleandoc(code).strip()
215
130
 
216
131
 
217
132
  def _next_token(
@@ -15,18 +15,16 @@
15
15
 
16
16
  import inspect
17
17
  import unittest
18
- from langfun.core.coding.python import errors
19
18
  from langfun.core.coding.python import parsing
20
- from langfun.core.coding.python import permissions
21
19
 
22
20
 
23
- class PythonCodeParserTest(unittest.TestCase):
21
+ class CleanTest(unittest.TestCase):
24
22
 
25
23
  def assert_clean(self, code: str, cleaned_code: str, clean: bool = True):
26
24
  if clean:
27
25
  cleaned_code = inspect.cleandoc(cleaned_code)
28
26
  self.assertEqual(
29
- parsing.PythonCodeParser().clean(code), cleaned_code
27
+ parsing.clean(code), cleaned_code
30
28
  )
31
29
 
32
30
  def test_clean(self):
@@ -272,107 +270,6 @@ class PythonCodeParserTest(unittest.TestCase):
272
270
  """
273
271
  )
274
272
 
275
- def assert_allowed(self, code: str, permission: permissions.CodePermission):
276
- _, ast = parsing.PythonCodeParser().parse(code, permission)
277
- self.assertIsNotNone(ast)
278
-
279
- def assert_not_allowed(
280
- self, code: str, permission: permissions.CodePermission
281
- ):
282
- with self.assertRaisesRegex(errors.CodeError, '.* is not allowed'):
283
- parsing.PythonCodeParser().parse(code, permission)
284
-
285
- def test_parse_with_allowed_code(self):
286
- self.assert_allowed(
287
- """
288
- x = y + 1
289
- z = x + y
290
- """,
291
- permissions.CodePermission.BASIC,
292
- )
293
- self.assert_allowed(
294
- """
295
- if x > 0:
296
- print(x)
297
- """,
298
- permissions.CodePermission.CONDITION,
299
- )
300
- self.assert_allowed(
301
- """
302
- for i in range(5):
303
- print(i)
304
- """,
305
- permissions.CodePermission.LOOP,
306
- )
307
- self.assert_allowed(
308
- """
309
- assert x > 1
310
- """,
311
- permissions.CodePermission.EXCEPTION,
312
- )
313
- self.assert_allowed(
314
- """
315
- class A:
316
- pass
317
- """,
318
- permissions.CodePermission.CLASS_DEFINITION,
319
- )
320
- self.assert_allowed(
321
- """
322
- def foo(x, y):
323
- return x + y
324
- """,
325
- permissions.CodePermission.FUNCTION_DEFINITION,
326
- )
327
- self.assert_allowed(
328
- """
329
- import re
330
- """,
331
- permissions.CodePermission.IMPORT,
332
- )
333
-
334
- def test_parse_with_not_allowed_code(self):
335
- self.assert_not_allowed(
336
- """
337
- if x > 0:
338
- print(x)
339
- """,
340
- permissions.CodePermission.BASIC,
341
- )
342
- self.assert_not_allowed(
343
- """
344
- for i in range(5):
345
- print(i)
346
- """,
347
- permissions.CodePermission.BASIC,
348
- )
349
- self.assert_not_allowed(
350
- """
351
- assert x > 1
352
- """,
353
- permissions.CodePermission.BASIC,
354
- )
355
- self.assert_not_allowed(
356
- """
357
- class A:
358
- pass
359
- """,
360
- permissions.CodePermission.BASIC,
361
- )
362
- self.assert_not_allowed(
363
- """
364
- def foo(x, y):
365
- return x + y
366
- """,
367
- permissions.CodePermission.BASIC,
368
- )
369
- self.assert_not_allowed(
370
- """
371
- import re
372
- """,
373
- permissions.CodePermission.BASIC,
374
- )
375
-
376
273
 
377
274
  if __name__ == '__main__':
378
275
  unittest.main()
langfun/core/component.py CHANGED
@@ -73,7 +73,7 @@ class Component(pg.Object):
73
73
  field.value.set_default(attr_value)
74
74
  additional_fields.append(field)
75
75
  if additional_fields:
76
- pg.symbolic.update_schema(cls, additional_fields)
76
+ cls.update_schema(additional_fields)
77
77
 
78
78
  def _on_bound(self):
79
79
  super()._on_bound()
@@ -210,6 +210,16 @@ def get_contextual_override(var_name: str) -> ContextualOverride | None:
210
210
  return _get_scoped_value(_global_tls, _CONTEXT_OVERRIDES, var_name)
211
211
 
212
212
 
213
+ def context_value(var_name: str, default: Any = RAISE_IF_HAS_ERROR) -> Any:
214
+ """Returns the value of a variable defined in `lf.context`."""
215
+ override = get_contextual_override(var_name)
216
+ if override is None:
217
+ if default == RAISE_IF_HAS_ERROR:
218
+ raise KeyError(f'{var_name!r} does not exist in current context.')
219
+ return default
220
+ return override.value
221
+
222
+
213
223
  def all_contextual_values() -> dict[str, Any]:
214
224
  """Returns all contextual values provided from `lf.context` in scope."""
215
225
  overrides = getattr(_global_tls, _CONTEXT_OVERRIDES, {})
@@ -243,7 +253,9 @@ def _get_scoped_value(
243
253
  return scoped_values.get(var_name, default)
244
254
 
245
255
 
246
- class ContextualAttribute(pg.symbolic.ValueFromParentChain):
256
+ class ContextualAttribute(
257
+ pg.symbolic.ValueFromParentChain, pg.views.HtmlTreeView.Extension
258
+ ):
247
259
  """Attributes whose values are inferred from the context of the component.
248
260
 
249
261
  Please see go/langfun-component#attribute-value-retrieval for details.
@@ -276,6 +288,55 @@ class ContextualAttribute(pg.symbolic.ValueFromParentChain):
276
288
  else:
277
289
  return pg.MISSING_VALUE
278
290
 
291
+ def _html_tree_view_content(
292
+ self,
293
+ *,
294
+ view: pg.views.HtmlTreeView,
295
+ parent: Any = None,
296
+ root_path: pg.KeyPath | None = None,
297
+ **kwargs,
298
+ ) -> pg.Html:
299
+ inferred_value = pg.MISSING_VALUE
300
+ if isinstance(parent, pg.Symbolic) and root_path:
301
+ inferred_value = parent.sym_inferred(root_path.key, pg.MISSING_VALUE)
302
+
303
+ if inferred_value is not pg.MISSING_VALUE:
304
+ kwargs.pop('name', None)
305
+ return view.render(
306
+ inferred_value, parent=self,
307
+ root_path=pg.KeyPath('<inferred>', root_path),
308
+ **view.get_passthrough_kwargs(**kwargs)
309
+ )
310
+ return pg.Html.element(
311
+ 'div',
312
+ [
313
+ '(not available)',
314
+ ],
315
+ css_classes=['unavailable-contextual'],
316
+ )
317
+
318
+ def _html_tree_view_config(self) -> dict[str, Any]:
319
+ return pg.views.HtmlTreeView.get_kwargs(
320
+ super()._html_tree_view_config(),
321
+ dict(
322
+ collapse_level=1,
323
+ )
324
+ )
325
+
326
+ @classmethod
327
+ def _html_tree_view_css_styles(cls) -> list[str]:
328
+ return super()._html_tree_view_css_styles() + [
329
+ """
330
+ .contextual-attribute {
331
+ color: purple;
332
+ }
333
+ .unavailable-contextual {
334
+ color: gray;
335
+ font-style: italic;
336
+ }
337
+ """
338
+ ]
339
+
279
340
 
280
341
  # NOTE(daiyip): Returning Any instead of `lf.ContextualAttribute` to avoid
281
342
  # pytype check error as `contextual()` can be assigned to any type.
@@ -13,6 +13,8 @@
13
13
  # limitations under the License.
14
14
  """Contextual component and app test."""
15
15
 
16
+ import inspect
17
+ from typing import Any
16
18
  import unittest
17
19
  import weakref
18
20
 
@@ -84,6 +86,11 @@ class ComponentContextTest(unittest.TestCase):
84
86
  lf.get_contextual_override('y'),
85
87
  lf.ContextualOverride(3, cascade=False, override_attrs=False),
86
88
  )
89
+ self.assertEqual(lf.context_value('x'), 3)
90
+ self.assertIsNone(lf.context_value('f', None))
91
+ with self.assertRaisesRegex(KeyError, '.* does not exist'):
92
+ lf.context_value('f')
93
+
87
94
  self.assertEqual(lf.all_contextual_values(), dict(x=3, y=3, z=3))
88
95
 
89
96
  # Member attributes take precedence over `lf.context`.
@@ -292,6 +299,52 @@ class ContextualAttributeTest(unittest.TestCase):
292
299
  self.assertEqual(c.z, 3)
293
300
  self.assertEqual(b.z, 3)
294
301
 
302
+ def test_to_html(self):
303
+ class A(lf.Component):
304
+ x: int = 1
305
+ y: int = lf.contextual()
306
+
307
+ def assert_content(html, expected):
308
+ expected = inspect.cleandoc(expected).strip()
309
+ actual = html.content.strip()
310
+ if actual != expected:
311
+ print(actual)
312
+ self.assertEqual(actual.strip(), expected)
313
+
314
+ self.assertIn(
315
+ inspect.cleandoc(
316
+ """
317
+ .contextual-attribute {
318
+ color: purple;
319
+ }
320
+ .unavailable-contextual {
321
+ color: gray;
322
+ font-style: italic;
323
+ }
324
+ """
325
+ ),
326
+ A().to_html().style_section,
327
+ )
328
+
329
+ assert_content(
330
+ A().to_html(enable_summary_tooltip=False),
331
+ """
332
+ <details open class="pyglove a"><summary><div class="summary-title">A(...)</div></summary><div class="complex-value a"><details open class="pyglove int"><summary><div class="summary-name">x<span class="tooltip">x</span></div><div class="summary-title">int</div></summary><span class="simple-value int">1</span></details><details open class="pyglove contextual-attribute"><summary><div class="summary-name">y<span class="tooltip">y</span></div><div class="summary-title">ContextualAttribute(...)</div></summary><div class="unavailable-contextual">(not available)</div></details></div></details>
333
+ """
334
+ )
335
+
336
+ class B(lf.Component):
337
+ z: Any
338
+ y: int = 2
339
+
340
+ b = B(A())
341
+ assert_content(
342
+ b.z.to_html(enable_summary_tooltip=False),
343
+ """
344
+ <details open class="pyglove a"><summary><div class="summary-title">A(...)</div></summary><div class="complex-value a"><details open class="pyglove int"><summary><div class="summary-name">x<span class="tooltip">x</span></div><div class="summary-title">int</div></summary><span class="simple-value int">1</span></details><details open class="pyglove contextual-attribute"><summary><div class="summary-name">y<span class="tooltip">y</span></div><div class="summary-title">ContextualAttribute(...)</div></summary><span class="simple-value int">2</span></details></div></details>
345
+ """
346
+ )
347
+
295
348
 
296
349
  if __name__ == '__main__':
297
350
  unittest.main()