langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +22 -2
- langfun/core/__init__.py +17 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -28
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +69 -2
- langfun/core/component_test.py +54 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +17 -0
- langfun/core/eval/base.py +767 -140
- langfun/core/eval/base_test.py +238 -53
- langfun/core/eval/matching.py +80 -76
- langfun/core/eval/matching_test.py +19 -9
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +37 -28
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +3 -21
- langfun/core/langfunc_test.py +26 -8
- langfun/core/language_model.py +686 -48
- langfun/core/language_model_test.py +681 -44
- langfun/core/llms/__init__.py +100 -12
- langfun/core/llms/anthropic.py +488 -0
- langfun/core/llms/anthropic_test.py +235 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +88 -28
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +39 -26
- langfun/core/llms/fake_test.py +136 -11
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -197
- langfun/core/llms/groq.py +276 -0
- langfun/core/llms/groq_test.py +64 -0
- langfun/core/llms/llama_cpp.py +15 -40
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +436 -226
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +35 -174
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -23
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +15 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +9 -8
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +278 -0
- langfun/core/structured/function_generation_test.py +399 -0
- langfun/core/structured/mapping.py +150 -46
- langfun/core/structured/mapping_test.py +105 -0
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +71 -22
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
- langfun/core/structured/schema.py +208 -99
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +133 -34
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +240 -11
- langfun/core/template_test.py +146 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +14 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -217
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -13,205 +13,120 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Python code parsing."""
|
15
15
|
|
16
|
-
import ast
|
17
16
|
import inspect
|
18
17
|
import io
|
19
18
|
import re
|
20
19
|
|
21
|
-
import langfun.core as lf
|
22
|
-
from langfun.core.coding.python import errors
|
23
|
-
from langfun.core.coding.python import permissions
|
24
|
-
|
25
|
-
|
26
|
-
class PythonCodeParser(lf.Component):
|
27
|
-
"""Python code parser with permission control."""
|
28
|
-
|
29
|
-
_ID_REGEX = re.compile('^[a-zA-Z_\\-]*$')
|
30
|
-
|
31
|
-
class _CodeValidator(ast.NodeVisitor):
|
32
|
-
"""Python AST node visitor for ensuring code are permitted."""
|
33
|
-
|
34
|
-
def __init__(self, code: str, permission: permissions.CodePermission):
|
35
|
-
super().__init__()
|
36
|
-
self.code = code
|
37
|
-
self.permission = permission
|
38
|
-
|
39
|
-
def verify(
|
40
|
-
self,
|
41
|
-
node,
|
42
|
-
flag: permissions.CodePermission,
|
43
|
-
node_type,
|
44
|
-
error_message: str,
|
45
|
-
) -> None:
|
46
|
-
if isinstance(node, node_type) and not (self.permission & flag):
|
47
|
-
raise SyntaxError(
|
48
|
-
error_message, (
|
49
|
-
'<generated-code>',
|
50
|
-
node.lineno,
|
51
|
-
node.col_offset,
|
52
|
-
self._code_line(node.lineno),
|
53
|
-
node.end_lineno,
|
54
|
-
node.end_col_offset,
|
55
|
-
))
|
56
|
-
|
57
|
-
def _code_line(self, lineno):
|
58
|
-
return self.code.split('\n')[lineno - 1]
|
59
|
-
|
60
|
-
def generic_visit(self, node):
|
61
|
-
self.verify(
|
62
|
-
node,
|
63
|
-
permissions.CodePermission.CONDITION,
|
64
|
-
(ast.If, ast.Match),
|
65
|
-
'Condition is not allowed.',
|
66
|
-
)
|
67
|
-
|
68
|
-
self.verify(
|
69
|
-
node,
|
70
|
-
permissions.CodePermission.LOOP,
|
71
|
-
(ast.For, ast.While, ast.AsyncFor, ast.AsyncWith),
|
72
|
-
'Loop is not allowed.',
|
73
|
-
)
|
74
|
-
|
75
|
-
self.verify(
|
76
|
-
node,
|
77
|
-
permissions.CodePermission.EXCEPTION,
|
78
|
-
(ast.Try, ast.Raise, ast.Assert),
|
79
|
-
'Exception is not allowed.',
|
80
|
-
)
|
81
|
-
|
82
|
-
self.verify(
|
83
|
-
node,
|
84
|
-
permissions.CodePermission.CLASS_DEFINITION,
|
85
|
-
ast.ClassDef,
|
86
|
-
'Class definition is not allowed.',
|
87
|
-
)
|
88
|
-
|
89
|
-
self.verify(
|
90
|
-
node,
|
91
|
-
permissions.CodePermission.FUNCTION_DEFINITION,
|
92
|
-
(
|
93
|
-
ast.FunctionDef,
|
94
|
-
ast.AsyncFunctionDef,
|
95
|
-
ast.Return,
|
96
|
-
ast.Yield,
|
97
|
-
ast.YieldFrom,
|
98
|
-
),
|
99
|
-
'Function definition is not allowed.',
|
100
|
-
)
|
101
|
-
|
102
|
-
self.verify(
|
103
|
-
node,
|
104
|
-
permissions.CodePermission.IMPORT,
|
105
|
-
(ast.Import, ast.ImportFrom),
|
106
|
-
'`import` is not allowed.',
|
107
|
-
)
|
108
|
-
|
109
|
-
super().generic_visit(node)
|
110
|
-
|
111
|
-
def parse(
|
112
|
-
self, code: str, permission: permissions.CodePermission
|
113
|
-
) -> tuple[str, ast.AST]:
|
114
|
-
code = self.clean(code)
|
115
|
-
try:
|
116
|
-
parsed_code = ast.parse(code, mode='exec')
|
117
|
-
PythonCodeParser._CodeValidator(code, permission).visit(parsed_code)
|
118
|
-
except SyntaxError as e:
|
119
|
-
raise errors.CodeError(code, e) from e
|
120
|
-
return code, parsed_code
|
121
|
-
|
122
|
-
def clean(self, code_text: str) -> str:
|
123
|
-
# TODO(daiyip): Deal with markdown in docstrings.
|
124
|
-
code = io.StringIO()
|
125
|
-
quote_char = None
|
126
|
-
in_code = False
|
127
|
-
i = 0
|
128
|
-
in_comment = False
|
129
|
-
while i < len(code_text):
|
130
|
-
c = code_text[i]
|
131
|
-
# Detect code block separator (```).
|
132
|
-
if (not in_comment
|
133
|
-
and quote_char is None
|
134
|
-
and c == '`'
|
135
|
-
and code_text[i:i + 3] == '```'):
|
136
|
-
in_code = not in_code
|
137
|
-
if in_code:
|
138
|
-
i += 3
|
139
|
-
continue
|
140
|
-
else:
|
141
|
-
break
|
142
|
-
|
143
|
-
# Detect string literal boundary.
|
144
|
-
if (in_code
|
145
|
-
and not in_comment
|
146
|
-
and c in ('\'', '"')
|
147
|
-
and i > 0
|
148
|
-
and code_text[i - 1] != '\\'):
|
149
|
-
# Handle ''' and """.
|
150
|
-
if code_text[i: i + 3] == c * 3:
|
151
|
-
c = c * 3
|
152
|
-
i += 2
|
153
|
-
|
154
|
-
if quote_char is None:
|
155
|
-
quote_char = c
|
156
|
-
elif quote_char == c:
|
157
|
-
# NOTE(daiyip): at times, LM forgets to escape quotes inside a string.
|
158
|
-
# Thus we do some smart checking here to automatically correct such
|
159
|
-
# case. This logic here is pretty involved in handling special cases.
|
160
|
-
# We might want to revisit them later.
|
161
|
-
|
162
|
-
# Peek forward to see if it could be a valid string.
|
163
|
-
nt, nnt_start = _next_token(code_text, i + 1)
|
164
|
-
if nt in (',', '[', ']', '}', ')', '+', '*', '%', '\n', ':'):
|
165
|
-
end_quote = True
|
166
|
-
elif nt == ' ':
|
167
|
-
# Detect if . could be a method invocation.
|
168
|
-
# NOTE(daiyip): 'in' and 'not in' might have false positives. But
|
169
|
-
# given the chance is low, we do not complicate the reasoning logic
|
170
|
-
# for now.
|
171
|
-
nnt, _ = _next_token(code_text, nnt_start, skip_whitespace=True)
|
172
|
-
end_quote = nnt in ('+', '*', '%', '#', '[', 'in', 'not', ':')
|
173
|
-
elif nt == '.':
|
174
|
-
# Detect if . could be method invocation on string.
|
175
|
-
nnt, nnnt_start = _next_token(code_text, nnt_start)
|
176
|
-
nnnt, _ = _next_token(code_text, nnnt_start)
|
177
|
-
end_quote = nnt.isidentifier() and nnnt == '('
|
178
|
-
else:
|
179
|
-
end_quote = False
|
180
|
-
|
181
|
-
if end_quote:
|
182
|
-
quote_char = None
|
183
|
-
else:
|
184
|
-
c = f'\\{c}'
|
185
|
-
# Detect comment.
|
186
|
-
elif c == '#' and quote_char is None:
|
187
|
-
in_comment = True
|
188
|
-
# Detect end-of-comment.
|
189
|
-
elif c == '\n':
|
190
|
-
# NOTE(daiyip): deal with cases that LM forgot to escape linebreaks
|
191
|
-
# within strings.
|
192
|
-
if quote_char is not None:
|
193
|
-
# Only add \\ for ' and " (other than ''' and """).
|
194
|
-
if len(quote_char) == 1:
|
195
|
-
c = '\\n'
|
196
|
-
else:
|
197
|
-
in_comment = False
|
198
20
|
|
21
|
+
_ID_REGEX = re.compile('^[a-zA-Z_\\-]*$')
|
22
|
+
|
23
|
+
|
24
|
+
def clean(code_text: str) -> str:
|
25
|
+
"""Cleans up Python code.
|
26
|
+
|
27
|
+
LLM may generate code with markdown annotations, as well as minor syntax
|
28
|
+
errors. This function removes such annotations and fixes minor syntax errors
|
29
|
+
without extra LLM calls.
|
30
|
+
|
31
|
+
Args:
|
32
|
+
code_text: The code text to clean up.
|
33
|
+
|
34
|
+
Returns:
|
35
|
+
The cleaned up code text.
|
36
|
+
"""
|
37
|
+
# TODO(daiyip): Deal with markdown in docstrings.
|
38
|
+
code = io.StringIO()
|
39
|
+
quote_char = None
|
40
|
+
in_code = False
|
41
|
+
i = 0
|
42
|
+
in_comment = False
|
43
|
+
while i < len(code_text):
|
44
|
+
c = code_text[i]
|
45
|
+
# Detect code block separator (```).
|
46
|
+
if (not in_comment
|
47
|
+
and quote_char is None
|
48
|
+
and c == '`'
|
49
|
+
and code_text[i:i + 3] == '```'):
|
50
|
+
in_code = not in_code
|
199
51
|
if in_code:
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
#
|
212
|
-
|
213
|
-
|
214
|
-
|
52
|
+
i += 3
|
53
|
+
continue
|
54
|
+
else:
|
55
|
+
break
|
56
|
+
|
57
|
+
# Detect string literal boundary.
|
58
|
+
if (in_code
|
59
|
+
and not in_comment
|
60
|
+
and c in ('\'', '"')
|
61
|
+
and i > 0
|
62
|
+
and code_text[i - 1] != '\\'):
|
63
|
+
# Handle ''' and """.
|
64
|
+
if code_text[i: i + 3] == c * 3:
|
65
|
+
c = c * 3
|
66
|
+
i += 2
|
67
|
+
|
68
|
+
if quote_char is None:
|
69
|
+
quote_char = c
|
70
|
+
elif quote_char == c:
|
71
|
+
# NOTE(daiyip): at times, LM forgets to escape quotes inside a string.
|
72
|
+
# Thus we do some smart checking here to automatically correct such
|
73
|
+
# case. This logic here is pretty involved in handling special cases.
|
74
|
+
# We might want to revisit them later.
|
75
|
+
|
76
|
+
# Peek forward to see if it could be a valid string.
|
77
|
+
nt, nnt_start = _next_token(code_text, i + 1)
|
78
|
+
if (len(c) == 3
|
79
|
+
or nt in (',', '[', ']', '}', ')', '+', '*', '%', '\n', ':')):
|
80
|
+
end_quote = True
|
81
|
+
elif nt == ' ':
|
82
|
+
# Detect if . could be a method invocation.
|
83
|
+
# NOTE(daiyip): 'in' and 'not in' might have false positives. But
|
84
|
+
# given the chance is low, we do not complicate the reasoning logic
|
85
|
+
# for now.
|
86
|
+
nnt, _ = _next_token(code_text, nnt_start, skip_whitespace=True)
|
87
|
+
end_quote = nnt in ('+', '*', '%', '#', '[', 'in', 'not', ':')
|
88
|
+
elif nt == '.':
|
89
|
+
# Detect if . could be method invocation on string.
|
90
|
+
nnt, nnnt_start = _next_token(code_text, nnt_start)
|
91
|
+
nnnt, _ = _next_token(code_text, nnnt_start)
|
92
|
+
end_quote = nnt.isidentifier() and nnnt == '('
|
93
|
+
else:
|
94
|
+
end_quote = False
|
95
|
+
|
96
|
+
if end_quote:
|
97
|
+
quote_char = None
|
98
|
+
else:
|
99
|
+
c = f'\\{c}'
|
100
|
+
# Detect comment.
|
101
|
+
elif c == '#' and quote_char is None:
|
102
|
+
in_comment = True
|
103
|
+
# Detect end-of-comment.
|
104
|
+
elif c == '\n':
|
105
|
+
# NOTE(daiyip): deal with cases that LM forgot to escape linebreaks
|
106
|
+
# within strings.
|
107
|
+
if quote_char is not None:
|
108
|
+
# Only add \\ for ' and " (other than ''' and """).
|
109
|
+
if len(quote_char) == 1:
|
110
|
+
c = '\\n'
|
111
|
+
else:
|
112
|
+
in_comment = False
|
113
|
+
|
114
|
+
if in_code:
|
115
|
+
code.write(c)
|
116
|
+
|
117
|
+
i += 1
|
118
|
+
|
119
|
+
code = code.getvalue()
|
120
|
+
if code:
|
121
|
+
pos = code.find('\n')
|
122
|
+
# Strip markdown code type. E.g. ```python
|
123
|
+
if pos > 0 and _ID_REGEX.match(code[:pos]):
|
124
|
+
code = code[pos:]
|
125
|
+
else:
|
126
|
+
# Maybe-code that resides not within a code markdown block.
|
127
|
+
# Adding '\n' makes inspect.cleandoc to make right adjustment.
|
128
|
+
code = '\n' + code_text
|
129
|
+
return inspect.cleandoc(code).strip()
|
215
130
|
|
216
131
|
|
217
132
|
def _next_token(
|
@@ -15,18 +15,16 @@
|
|
15
15
|
|
16
16
|
import inspect
|
17
17
|
import unittest
|
18
|
-
from langfun.core.coding.python import errors
|
19
18
|
from langfun.core.coding.python import parsing
|
20
|
-
from langfun.core.coding.python import permissions
|
21
19
|
|
22
20
|
|
23
|
-
class
|
21
|
+
class CleanTest(unittest.TestCase):
|
24
22
|
|
25
23
|
def assert_clean(self, code: str, cleaned_code: str, clean: bool = True):
|
26
24
|
if clean:
|
27
25
|
cleaned_code = inspect.cleandoc(cleaned_code)
|
28
26
|
self.assertEqual(
|
29
|
-
parsing.
|
27
|
+
parsing.clean(code), cleaned_code
|
30
28
|
)
|
31
29
|
|
32
30
|
def test_clean(self):
|
@@ -272,107 +270,6 @@ class PythonCodeParserTest(unittest.TestCase):
|
|
272
270
|
"""
|
273
271
|
)
|
274
272
|
|
275
|
-
def assert_allowed(self, code: str, permission: permissions.CodePermission):
|
276
|
-
_, ast = parsing.PythonCodeParser().parse(code, permission)
|
277
|
-
self.assertIsNotNone(ast)
|
278
|
-
|
279
|
-
def assert_not_allowed(
|
280
|
-
self, code: str, permission: permissions.CodePermission
|
281
|
-
):
|
282
|
-
with self.assertRaisesRegex(errors.CodeError, '.* is not allowed'):
|
283
|
-
parsing.PythonCodeParser().parse(code, permission)
|
284
|
-
|
285
|
-
def test_parse_with_allowed_code(self):
|
286
|
-
self.assert_allowed(
|
287
|
-
"""
|
288
|
-
x = y + 1
|
289
|
-
z = x + y
|
290
|
-
""",
|
291
|
-
permissions.CodePermission.BASIC,
|
292
|
-
)
|
293
|
-
self.assert_allowed(
|
294
|
-
"""
|
295
|
-
if x > 0:
|
296
|
-
print(x)
|
297
|
-
""",
|
298
|
-
permissions.CodePermission.CONDITION,
|
299
|
-
)
|
300
|
-
self.assert_allowed(
|
301
|
-
"""
|
302
|
-
for i in range(5):
|
303
|
-
print(i)
|
304
|
-
""",
|
305
|
-
permissions.CodePermission.LOOP,
|
306
|
-
)
|
307
|
-
self.assert_allowed(
|
308
|
-
"""
|
309
|
-
assert x > 1
|
310
|
-
""",
|
311
|
-
permissions.CodePermission.EXCEPTION,
|
312
|
-
)
|
313
|
-
self.assert_allowed(
|
314
|
-
"""
|
315
|
-
class A:
|
316
|
-
pass
|
317
|
-
""",
|
318
|
-
permissions.CodePermission.CLASS_DEFINITION,
|
319
|
-
)
|
320
|
-
self.assert_allowed(
|
321
|
-
"""
|
322
|
-
def foo(x, y):
|
323
|
-
return x + y
|
324
|
-
""",
|
325
|
-
permissions.CodePermission.FUNCTION_DEFINITION,
|
326
|
-
)
|
327
|
-
self.assert_allowed(
|
328
|
-
"""
|
329
|
-
import re
|
330
|
-
""",
|
331
|
-
permissions.CodePermission.IMPORT,
|
332
|
-
)
|
333
|
-
|
334
|
-
def test_parse_with_not_allowed_code(self):
|
335
|
-
self.assert_not_allowed(
|
336
|
-
"""
|
337
|
-
if x > 0:
|
338
|
-
print(x)
|
339
|
-
""",
|
340
|
-
permissions.CodePermission.BASIC,
|
341
|
-
)
|
342
|
-
self.assert_not_allowed(
|
343
|
-
"""
|
344
|
-
for i in range(5):
|
345
|
-
print(i)
|
346
|
-
""",
|
347
|
-
permissions.CodePermission.BASIC,
|
348
|
-
)
|
349
|
-
self.assert_not_allowed(
|
350
|
-
"""
|
351
|
-
assert x > 1
|
352
|
-
""",
|
353
|
-
permissions.CodePermission.BASIC,
|
354
|
-
)
|
355
|
-
self.assert_not_allowed(
|
356
|
-
"""
|
357
|
-
class A:
|
358
|
-
pass
|
359
|
-
""",
|
360
|
-
permissions.CodePermission.BASIC,
|
361
|
-
)
|
362
|
-
self.assert_not_allowed(
|
363
|
-
"""
|
364
|
-
def foo(x, y):
|
365
|
-
return x + y
|
366
|
-
""",
|
367
|
-
permissions.CodePermission.BASIC,
|
368
|
-
)
|
369
|
-
self.assert_not_allowed(
|
370
|
-
"""
|
371
|
-
import re
|
372
|
-
""",
|
373
|
-
permissions.CodePermission.BASIC,
|
374
|
-
)
|
375
|
-
|
376
273
|
|
377
274
|
if __name__ == '__main__':
|
378
275
|
unittest.main()
|
langfun/core/component.py
CHANGED
@@ -73,7 +73,7 @@ class Component(pg.Object):
|
|
73
73
|
field.value.set_default(attr_value)
|
74
74
|
additional_fields.append(field)
|
75
75
|
if additional_fields:
|
76
|
-
|
76
|
+
cls.update_schema(additional_fields)
|
77
77
|
|
78
78
|
def _on_bound(self):
|
79
79
|
super()._on_bound()
|
@@ -210,6 +210,22 @@ def get_contextual_override(var_name: str) -> ContextualOverride | None:
|
|
210
210
|
return _get_scoped_value(_global_tls, _CONTEXT_OVERRIDES, var_name)
|
211
211
|
|
212
212
|
|
213
|
+
def context_value(var_name: str, default: Any = RAISE_IF_HAS_ERROR) -> Any:
|
214
|
+
"""Returns the value of a variable defined in `lf.context`."""
|
215
|
+
override = get_contextual_override(var_name)
|
216
|
+
if override is None:
|
217
|
+
if default == RAISE_IF_HAS_ERROR:
|
218
|
+
raise KeyError(f'{var_name!r} does not exist in current context.')
|
219
|
+
return default
|
220
|
+
return override.value
|
221
|
+
|
222
|
+
|
223
|
+
def all_contextual_values() -> dict[str, Any]:
|
224
|
+
"""Returns all contextual values provided from `lf.context` in scope."""
|
225
|
+
overrides = getattr(_global_tls, _CONTEXT_OVERRIDES, {})
|
226
|
+
return {k: v.value for k, v in overrides.items()}
|
227
|
+
|
228
|
+
|
213
229
|
@contextlib.contextmanager
|
214
230
|
def _contextual_scope(
|
215
231
|
tls: threading.local, tls_key, **variables
|
@@ -237,7 +253,9 @@ def _get_scoped_value(
|
|
237
253
|
return scoped_values.get(var_name, default)
|
238
254
|
|
239
255
|
|
240
|
-
class ContextualAttribute(
|
256
|
+
class ContextualAttribute(
|
257
|
+
pg.symbolic.ValueFromParentChain, pg.views.HtmlTreeView.Extension
|
258
|
+
):
|
241
259
|
"""Attributes whose values are inferred from the context of the component.
|
242
260
|
|
243
261
|
Please see go/langfun-component#attribute-value-retrieval for details.
|
@@ -270,6 +288,55 @@ class ContextualAttribute(pg.symbolic.ValueFromParentChain):
|
|
270
288
|
else:
|
271
289
|
return pg.MISSING_VALUE
|
272
290
|
|
291
|
+
def _html_tree_view_content(
|
292
|
+
self,
|
293
|
+
*,
|
294
|
+
view: pg.views.HtmlTreeView,
|
295
|
+
parent: Any = None,
|
296
|
+
root_path: pg.KeyPath | None = None,
|
297
|
+
**kwargs,
|
298
|
+
) -> pg.Html:
|
299
|
+
inferred_value = pg.MISSING_VALUE
|
300
|
+
if isinstance(parent, pg.Symbolic) and root_path:
|
301
|
+
inferred_value = parent.sym_inferred(root_path.key, pg.MISSING_VALUE)
|
302
|
+
|
303
|
+
if inferred_value is not pg.MISSING_VALUE:
|
304
|
+
kwargs.pop('name', None)
|
305
|
+
return view.render(
|
306
|
+
inferred_value, parent=self,
|
307
|
+
root_path=pg.KeyPath('<inferred>', root_path),
|
308
|
+
**view.get_passthrough_kwargs(**kwargs)
|
309
|
+
)
|
310
|
+
return pg.Html.element(
|
311
|
+
'div',
|
312
|
+
[
|
313
|
+
'(not available)',
|
314
|
+
],
|
315
|
+
css_classes=['unavailable-contextual'],
|
316
|
+
)
|
317
|
+
|
318
|
+
def _html_tree_view_config(self) -> dict[str, Any]:
|
319
|
+
return pg.views.HtmlTreeView.get_kwargs(
|
320
|
+
super()._html_tree_view_config(),
|
321
|
+
dict(
|
322
|
+
collapse_level=1,
|
323
|
+
)
|
324
|
+
)
|
325
|
+
|
326
|
+
@classmethod
|
327
|
+
def _html_tree_view_css_styles(cls) -> list[str]:
|
328
|
+
return super()._html_tree_view_css_styles() + [
|
329
|
+
"""
|
330
|
+
.contextual-attribute {
|
331
|
+
color: purple;
|
332
|
+
}
|
333
|
+
.unavailable-contextual {
|
334
|
+
color: gray;
|
335
|
+
font-style: italic;
|
336
|
+
}
|
337
|
+
"""
|
338
|
+
]
|
339
|
+
|
273
340
|
|
274
341
|
# NOTE(daiyip): Returning Any instead of `lf.ContextualAttribute` to avoid
|
275
342
|
# pytype check error as `contextual()` can be assigned to any type.
|
langfun/core/component_test.py
CHANGED
@@ -13,6 +13,8 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Contextual component and app test."""
|
15
15
|
|
16
|
+
import inspect
|
17
|
+
from typing import Any
|
16
18
|
import unittest
|
17
19
|
import weakref
|
18
20
|
|
@@ -84,6 +86,12 @@ class ComponentContextTest(unittest.TestCase):
|
|
84
86
|
lf.get_contextual_override('y'),
|
85
87
|
lf.ContextualOverride(3, cascade=False, override_attrs=False),
|
86
88
|
)
|
89
|
+
self.assertEqual(lf.context_value('x'), 3)
|
90
|
+
self.assertIsNone(lf.context_value('f', None))
|
91
|
+
with self.assertRaisesRegex(KeyError, '.* does not exist'):
|
92
|
+
lf.context_value('f')
|
93
|
+
|
94
|
+
self.assertEqual(lf.all_contextual_values(), dict(x=3, y=3, z=3))
|
87
95
|
|
88
96
|
# Member attributes take precedence over `lf.context`.
|
89
97
|
self.assertEqual(a1.x, 1)
|
@@ -291,6 +299,52 @@ class ContextualAttributeTest(unittest.TestCase):
|
|
291
299
|
self.assertEqual(c.z, 3)
|
292
300
|
self.assertEqual(b.z, 3)
|
293
301
|
|
302
|
+
def test_to_html(self):
|
303
|
+
class A(lf.Component):
|
304
|
+
x: int = 1
|
305
|
+
y: int = lf.contextual()
|
306
|
+
|
307
|
+
def assert_content(html, expected):
|
308
|
+
expected = inspect.cleandoc(expected).strip()
|
309
|
+
actual = html.content.strip()
|
310
|
+
if actual != expected:
|
311
|
+
print(actual)
|
312
|
+
self.assertEqual(actual.strip(), expected)
|
313
|
+
|
314
|
+
self.assertIn(
|
315
|
+
inspect.cleandoc(
|
316
|
+
"""
|
317
|
+
.contextual-attribute {
|
318
|
+
color: purple;
|
319
|
+
}
|
320
|
+
.unavailable-contextual {
|
321
|
+
color: gray;
|
322
|
+
font-style: italic;
|
323
|
+
}
|
324
|
+
"""
|
325
|
+
),
|
326
|
+
A().to_html().style_section,
|
327
|
+
)
|
328
|
+
|
329
|
+
assert_content(
|
330
|
+
A().to_html(enable_summary_tooltip=False),
|
331
|
+
"""
|
332
|
+
<details open class="pyglove a"><summary><div class="summary-title">A(...)</div></summary><div class="complex-value a"><details open class="pyglove int"><summary><div class="summary-name">x<span class="tooltip">x</span></div><div class="summary-title">int</div></summary><span class="simple-value int">1</span></details><details open class="pyglove contextual-attribute"><summary><div class="summary-name">y<span class="tooltip">y</span></div><div class="summary-title">ContextualAttribute(...)</div></summary><div class="unavailable-contextual">(not available)</div></details></div></details>
|
333
|
+
"""
|
334
|
+
)
|
335
|
+
|
336
|
+
class B(lf.Component):
|
337
|
+
z: Any
|
338
|
+
y: int = 2
|
339
|
+
|
340
|
+
b = B(A())
|
341
|
+
assert_content(
|
342
|
+
b.z.to_html(enable_summary_tooltip=False),
|
343
|
+
"""
|
344
|
+
<details open class="pyglove a"><summary><div class="summary-title">A(...)</div></summary><div class="complex-value a"><details open class="pyglove int"><summary><div class="summary-name">x<span class="tooltip">x</span></div><div class="summary-title">int</div></summary><span class="simple-value int">1</span></details><details open class="pyglove contextual-attribute"><summary><div class="summary-name">y<span class="tooltip">y</span></div><div class="summary-title">ContextualAttribute(...)</div></summary><span class="simple-value int">2</span></details></div></details>
|
345
|
+
"""
|
346
|
+
)
|
347
|
+
|
294
348
|
|
295
349
|
if __name__ == '__main__':
|
296
350
|
unittest.main()
|