langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501150804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +20 -2
- langfun/core/__init__.py +16 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -21
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +63 -2
- langfun/core/component_test.py +53 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +17 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +16 -1
- langfun/core/eval/base.py +622 -174
- langfun/core/eval/base_test.py +200 -54
- langfun/core/eval/matching.py +63 -76
- langfun/core/eval/matching_test.py +17 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +26 -26
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +4 -17
- langfun/core/langfunc_test.py +22 -6
- langfun/core/language_model.py +577 -39
- langfun/core/language_model_test.py +470 -56
- langfun/core/llms/__init__.py +87 -16
- langfun/core/llms/anthropic.py +312 -87
- langfun/core/llms/anthropic_test.py +71 -3
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +53 -2
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +11 -7
- langfun/core/llms/fake_test.py +14 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -202
- langfun/core/llms/groq.py +160 -144
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +15 -42
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +395 -203
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +30 -395
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -26
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +12 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +7 -6
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +60 -27
- langfun/core/structured/function_generation_test.py +72 -2
- langfun/core/structured/mapping.py +97 -47
- langfun/core/structured/mapping_test.py +90 -2
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +53 -9
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
- langfun/core/structured/schema.py +204 -97
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_test.py +130 -29
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +115 -1
- langfun/core/template_test.py +71 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +10 -2
- langfun-0.1.2.dev202501150804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501150804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -238
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
- langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/top_level.txt +0 -0
@@ -63,6 +63,42 @@ class FunctionGenerationTest(unittest.TestCase):
|
|
63
63
|
|
64
64
|
lm = fake.StaticSequence([unittest_lm_response, function_gen_lm_response])
|
65
65
|
|
66
|
+
@function_generation.function_gen(lm=lm, unittest='auto')
|
67
|
+
def linear_search(items, target): # pylint: disable=unused-argument
|
68
|
+
"""Performs a linear search on a list to find a target value.
|
69
|
+
|
70
|
+
Args:
|
71
|
+
items (list): The list to search within.
|
72
|
+
target: The value to search for.
|
73
|
+
|
74
|
+
Returns:
|
75
|
+
int: The index of the target value if found, otherwise -1.
|
76
|
+
"""
|
77
|
+
|
78
|
+
self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
|
79
|
+
self.assertEqual(linear_search.source(), function_gen_lm_response)
|
80
|
+
|
81
|
+
def test_generate_function_without_unittest(self):
|
82
|
+
function_gen_lm_response = inspect.cleandoc("""
|
83
|
+
def linear_search(items, target):
|
84
|
+
\"\"\"
|
85
|
+
Performs a linear search on a list to find a target value.
|
86
|
+
|
87
|
+
Args:
|
88
|
+
items (list): The list to search within.
|
89
|
+
target: The value to search for.
|
90
|
+
|
91
|
+
Returns:
|
92
|
+
int: The index of the target value if found, otherwise -1.
|
93
|
+
\"\"\"
|
94
|
+
for i, item in enumerate(items):
|
95
|
+
if item == target:
|
96
|
+
return i
|
97
|
+
return -1
|
98
|
+
""")
|
99
|
+
|
100
|
+
lm = fake.StaticSequence([function_gen_lm_response])
|
101
|
+
|
66
102
|
@function_generation.function_gen(lm=lm)
|
67
103
|
def linear_search(items, target): # pylint: disable=unused-argument
|
68
104
|
"""Performs a linear search on a list to find a target value.
|
@@ -258,7 +294,9 @@ class FunctionGenerationTest(unittest.TestCase):
|
|
258
294
|
cache_file = os.path.join(cache_file_dir, 'cache_file.json')
|
259
295
|
|
260
296
|
@function_generation.function_gen(
|
261
|
-
lm=lm,
|
297
|
+
lm=lm,
|
298
|
+
unittest=_unittest_fn,
|
299
|
+
cache_filename=cache_file,
|
262
300
|
)
|
263
301
|
def linear_search(items, target): # pylint: disable=unused-argument
|
264
302
|
"""Performs a linear search on a list to find a target value.
|
@@ -273,6 +311,36 @@ class FunctionGenerationTest(unittest.TestCase):
|
|
273
311
|
|
274
312
|
self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
|
275
313
|
|
314
|
+
def test_context_passthrough(self):
|
315
|
+
|
316
|
+
class Number(pg.Object):
|
317
|
+
value: int
|
318
|
+
|
319
|
+
function_gen_lm_response = inspect.cleandoc("""
|
320
|
+
```python
|
321
|
+
def add(a: Number, b: Number) -> Number:
|
322
|
+
\"\"\"Adds two numbers together.\"\"\"
|
323
|
+
return Number(a.value + b.value)
|
324
|
+
```
|
325
|
+
""")
|
326
|
+
|
327
|
+
lm = fake.StaticSequence(
|
328
|
+
[function_gen_lm_response]
|
329
|
+
)
|
330
|
+
|
331
|
+
def _unittest_fn(func):
|
332
|
+
assert func(Number(1), Number(2)) == Number(3)
|
333
|
+
|
334
|
+
custom_unittest = _unittest_fn
|
335
|
+
|
336
|
+
@function_generation.function_gen(
|
337
|
+
lm=lm, unittest=custom_unittest, num_retries=1
|
338
|
+
)
|
339
|
+
def add(a: Number, b: Number) -> Number: # pylint: disable=unused-argument
|
340
|
+
"""Adds two numbers together."""
|
341
|
+
|
342
|
+
self.assertEqual(add(Number(2), Number(3)), Number(5))
|
343
|
+
|
276
344
|
def test_siganture_check(self):
|
277
345
|
incorrect_signature_lm_response = inspect.cleandoc("""
|
278
346
|
```python
|
@@ -310,7 +378,9 @@ class FunctionGenerationTest(unittest.TestCase):
|
|
310
378
|
|
311
379
|
custom_unittest = _unittest_fn
|
312
380
|
|
313
|
-
@function_generation.function_gen(
|
381
|
+
@function_generation.function_gen(
|
382
|
+
lm=lm, unittest=custom_unittest, num_retries=2
|
383
|
+
)
|
314
384
|
def linear_search(items, target): # pylint: disable=unused-argument
|
315
385
|
"""Performs a linear search on a list to find a target value.
|
316
386
|
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""The base of symbolic mapping methods."""
|
15
15
|
|
16
|
+
import functools
|
16
17
|
import io
|
17
18
|
from typing import Annotated, Any, Callable
|
18
19
|
import langfun.core as lf
|
@@ -45,20 +46,22 @@ class MappingError(Exception): # pylint: disable=g-bad-exception-name
|
|
45
46
|
r = io.StringIO()
|
46
47
|
error_message = str(self.cause).rstrip()
|
47
48
|
r.write(
|
48
|
-
|
49
|
+
pg.colored(
|
49
50
|
f'{self.cause.__class__.__name__}: {error_message}', 'magenta'
|
50
51
|
)
|
51
52
|
)
|
52
53
|
if include_lm_response:
|
53
54
|
r.write('\n\n')
|
54
|
-
r.write(
|
55
|
+
r.write(pg.colored('[LM Response]', 'blue', styles=['bold']))
|
55
56
|
r.write('\n')
|
56
|
-
r.write(
|
57
|
+
r.write(pg.colored(self.lm_response.text, 'blue'))
|
57
58
|
return r.getvalue()
|
58
59
|
|
59
60
|
|
60
61
|
@pg.use_init_args(['input', 'output', 'schema', 'context'])
|
61
|
-
class MappingExample(lf.NaturalLanguageFormattable,
|
62
|
+
class MappingExample(lf.NaturalLanguageFormattable,
|
63
|
+
lf.Component,
|
64
|
+
pg.views.HtmlTreeView.Extension):
|
62
65
|
"""Mapping example between text, schema and structured value."""
|
63
66
|
|
64
67
|
input: pg.typing.Annotated[
|
@@ -92,6 +95,15 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
|
|
92
95
|
'The natural language context for this mapping. ',
|
93
96
|
] = None
|
94
97
|
|
98
|
+
metadata: Annotated[
|
99
|
+
dict[str, Any],
|
100
|
+
(
|
101
|
+
'The metadata associated with the mapping example, '
|
102
|
+
'which chould carry structured data, such as tool function input. '
|
103
|
+
'It is a `pg.Dict` object whose keys can be accessed by attributes.'
|
104
|
+
),
|
105
|
+
] = pg.Dict()
|
106
|
+
|
95
107
|
def schema_repr(
|
96
108
|
self, protocol: schema_lib.SchemaProtocol = 'python', **kwargs
|
97
109
|
) -> str:
|
@@ -107,7 +119,11 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
|
|
107
119
|
|
108
120
|
@classmethod
|
109
121
|
def value_repr(
|
110
|
-
cls,
|
122
|
+
cls,
|
123
|
+
value: Any,
|
124
|
+
protocol: schema_lib.SchemaProtocol = 'python',
|
125
|
+
use_modality_ref: bool = False,
|
126
|
+
**kwargs
|
111
127
|
) -> str:
|
112
128
|
if isinstance(value, str):
|
113
129
|
return value
|
@@ -116,7 +132,7 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
|
|
116
132
|
return str(value)
|
117
133
|
|
118
134
|
# Placehold modalities if they are present.
|
119
|
-
if pg.contains(value, type=lf.Modality):
|
135
|
+
if use_modality_ref and pg.contains(value, type=lf.Modality):
|
120
136
|
value = lf.ModalityRef.placehold(value)
|
121
137
|
return schema_lib.value_repr(protocol).repr(value, **kwargs)
|
122
138
|
|
@@ -147,24 +163,83 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
|
|
147
163
|
def natural_language_format(self) -> str:
|
148
164
|
result = io.StringIO()
|
149
165
|
if self.context:
|
150
|
-
result.write(
|
151
|
-
result.write(
|
166
|
+
result.write(pg.colored('[CONTEXT]\n', styles=['bold']))
|
167
|
+
result.write(pg.colored(self.context, color='magenta'))
|
152
168
|
result.write('\n\n')
|
153
169
|
|
154
|
-
result.write(
|
155
|
-
result.write(
|
156
|
-
result.write('\n\n')
|
170
|
+
result.write(pg.colored('[INPUT]\n', styles=['bold']))
|
171
|
+
result.write(pg.colored(self.input_repr(), color='green'))
|
157
172
|
|
158
173
|
if self.schema is not None:
|
159
|
-
result.write(lf.colored('[SCHEMA]\n', styles=['bold']))
|
160
|
-
result.write(lf.colored(self.schema_repr(), color='red'))
|
161
174
|
result.write('\n\n')
|
175
|
+
result.write(pg.colored('[SCHEMA]\n', styles=['bold']))
|
176
|
+
result.write(pg.colored(self.schema_repr(), color='red'))
|
162
177
|
|
163
178
|
if schema_lib.MISSING != self.output:
|
164
|
-
result.write(
|
165
|
-
result.write(
|
179
|
+
result.write('\n\n')
|
180
|
+
result.write(pg.colored('[OUTPUT]\n', styles=['bold']))
|
181
|
+
result.write(pg.colored(self.output_repr(), color='blue'))
|
182
|
+
|
183
|
+
if self.metadata:
|
184
|
+
result.write('\n\n')
|
185
|
+
result.write(pg.colored('[METADATA]\n', styles=['bold']))
|
186
|
+
result.write(pg.colored(str(self.metadata), color='cyan'))
|
166
187
|
return result.getvalue().strip()
|
167
188
|
|
189
|
+
@classmethod
|
190
|
+
@functools.cache
|
191
|
+
def _html_tree_view_config(cls) -> dict[str, Any]:
|
192
|
+
|
193
|
+
def render_value(view, *, value, **kwargs):
|
194
|
+
if isinstance(value, lf.Template):
|
195
|
+
# Make a shallow copy to make sure modalities are rooted by
|
196
|
+
# the input.
|
197
|
+
value = value.clone().render()
|
198
|
+
if value is None:
|
199
|
+
return None
|
200
|
+
return view.render(value, **kwargs)
|
201
|
+
|
202
|
+
return pg.views.HtmlTreeView.get_kwargs(
|
203
|
+
super()._html_tree_view_config(),
|
204
|
+
dict(
|
205
|
+
include_keys=['input', 'output', 'context', 'schema', 'metadata'],
|
206
|
+
extra_flags=dict(
|
207
|
+
render_value_fn=render_value,
|
208
|
+
),
|
209
|
+
child_config=dict(
|
210
|
+
input=dict(
|
211
|
+
collapse_level=1,
|
212
|
+
),
|
213
|
+
output=dict(
|
214
|
+
css_classes=['lf-example-output'],
|
215
|
+
collapse_level=1,
|
216
|
+
),
|
217
|
+
schema=dict(
|
218
|
+
css_classes=['lf-example-schema'],
|
219
|
+
collapse_level=1,
|
220
|
+
),
|
221
|
+
metadata=dict(
|
222
|
+
css_classes=['lf-example-metadata'],
|
223
|
+
collapse_level=1,
|
224
|
+
),
|
225
|
+
),
|
226
|
+
)
|
227
|
+
)
|
228
|
+
|
229
|
+
@classmethod
|
230
|
+
@functools.cache
|
231
|
+
def _html_tree_view_css_styles(cls) -> list[str]:
|
232
|
+
return super()._html_tree_view_css_styles() + [
|
233
|
+
"""
|
234
|
+
.lf-example-output {
|
235
|
+
color: dodgerblue;
|
236
|
+
}
|
237
|
+
.lf-example-schema {
|
238
|
+
color: blue;
|
239
|
+
}
|
240
|
+
"""
|
241
|
+
]
|
242
|
+
|
168
243
|
|
169
244
|
class Mapping(lf.LangFunc):
|
170
245
|
"""Base class for mapping.
|
@@ -243,13 +318,7 @@ class Mapping(lf.LangFunc):
|
|
243
318
|
{{ input_title }}:
|
244
319
|
{{ example.input_repr(protocol, compact=False) | indent(2, True) }}
|
245
320
|
|
246
|
-
{% if
|
247
|
-
{{ modality_refs_title }}:
|
248
|
-
{{ modality_refs_repr(example.input) | indent(2, True) }}
|
249
|
-
|
250
|
-
{% endif -%}
|
251
|
-
|
252
|
-
{%- if example.schema -%}
|
321
|
+
{% if example.schema -%}
|
253
322
|
{{ schema_title }}:
|
254
323
|
{{ example.schema_repr(protocol) | indent(2, True) }}
|
255
324
|
|
@@ -270,10 +339,6 @@ class Mapping(lf.LangFunc):
|
|
270
339
|
|
271
340
|
schema_title: Annotated[str, 'The section title for schema.'] = 'SCHEMA'
|
272
341
|
|
273
|
-
modality_refs_title: Annotated[
|
274
|
-
str, 'The section title for modality refs.'
|
275
|
-
] = 'MODALITY_REFERENCES'
|
276
|
-
|
277
342
|
protocol: Annotated[
|
278
343
|
schema_lib.SchemaProtocol,
|
279
344
|
'The protocol for representing the schema and value.',
|
@@ -344,6 +409,12 @@ class Mapping(lf.LangFunc):
|
|
344
409
|
lm_output = self.postprocess_response(lm_output)
|
345
410
|
lm_output.result = self.postprocess_result(self.parse_result(lm_output))
|
346
411
|
except Exception as e: # pylint: disable=broad-exception-caught
|
412
|
+
if (self.lm.cache is not None
|
413
|
+
and lm_output.lm_input.cache_seed is not None):
|
414
|
+
success = self.lm.cache.delete(
|
415
|
+
self.lm, lm_output.lm_input, lm_output.lm_input.cache_seed
|
416
|
+
)
|
417
|
+
assert success
|
347
418
|
if self.default == lf.RAISE_IF_HAS_ERROR:
|
348
419
|
raise MappingError(lm_output, e) from e
|
349
420
|
lm_output.result = self.default
|
@@ -378,24 +449,3 @@ class Mapping(lf.LangFunc):
|
|
378
449
|
"""Gets additional symbol definitions besides schema as globals."""
|
379
450
|
return {'ModalityRef': lf.modality.ModalityRef}
|
380
451
|
|
381
|
-
#
|
382
|
-
# Helper methods for handling modalities.
|
383
|
-
#
|
384
|
-
|
385
|
-
def has_modality_refs(self, value: Any) -> bool:
|
386
|
-
"""Returns true if the value has modalities."""
|
387
|
-
return not isinstance(value, lf.Modality) and pg.contains(
|
388
|
-
value, type=lf.Modality
|
389
|
-
)
|
390
|
-
|
391
|
-
def modalities(self, value: Any) -> dict[str, lf.Modality]:
|
392
|
-
return lf.Modality.from_value(value)
|
393
|
-
|
394
|
-
def modality_refs_repr(self, value: Any) -> str:
|
395
|
-
with lf.modality.format_modality_as_ref(True):
|
396
|
-
return pg.format(
|
397
|
-
self.modalities(value),
|
398
|
-
compact=False,
|
399
|
-
verbose=False,
|
400
|
-
python_format=True,
|
401
|
-
)
|
@@ -14,6 +14,7 @@
|
|
14
14
|
"""Tests for structured mapping example."""
|
15
15
|
|
16
16
|
import inspect
|
17
|
+
from typing import Any
|
17
18
|
import unittest
|
18
19
|
|
19
20
|
import langfun.core as lf
|
@@ -28,11 +29,11 @@ class MappingErrorTest(unittest.TestCase):
|
|
28
29
|
lf.AIMessage('hi'), ValueError('Cannot parse message.')
|
29
30
|
)
|
30
31
|
self.assertEqual(
|
31
|
-
|
32
|
+
pg.decolor(str(error)),
|
32
33
|
'ValueError: Cannot parse message.\n\n[LM Response]\nhi',
|
33
34
|
)
|
34
35
|
self.assertEqual(
|
35
|
-
|
36
|
+
pg.decolor(error.format(include_lm_response=False)),
|
36
37
|
'ValueError: Cannot parse message.',
|
37
38
|
)
|
38
39
|
|
@@ -129,6 +130,33 @@ class MappingExampleTest(unittest.TestCase):
|
|
129
130
|
"""),
|
130
131
|
)
|
131
132
|
|
133
|
+
def test_str_with_metadata(self):
|
134
|
+
self.assertEqual(
|
135
|
+
str(
|
136
|
+
mapping.MappingExample(
|
137
|
+
'1 + 1 = 2',
|
138
|
+
schema=int,
|
139
|
+
context='Give the answer.',
|
140
|
+
metadata={'foo': 'bar'},
|
141
|
+
)
|
142
|
+
),
|
143
|
+
inspect.cleandoc("""
|
144
|
+
\x1b[1m[CONTEXT]
|
145
|
+
\x1b[0m\x1b[35mGive the answer.\x1b[0m
|
146
|
+
|
147
|
+
\x1b[1m[INPUT]
|
148
|
+
\x1b[0m\x1b[32m1 + 1 = 2\x1b[0m
|
149
|
+
|
150
|
+
\x1b[1m[SCHEMA]
|
151
|
+
\x1b[0m\x1b[31mint\x1b[0m
|
152
|
+
|
153
|
+
\x1b[1m[METADATA]
|
154
|
+
\x1b[0m\x1b[36m{
|
155
|
+
foo = 'bar'
|
156
|
+
}\x1b[0m
|
157
|
+
"""),
|
158
|
+
)
|
159
|
+
|
132
160
|
def test_serialization(self):
|
133
161
|
example = mapping.MappingExample(
|
134
162
|
'the answer is 2', 2, int, context='compute 1 + 1'
|
@@ -137,6 +165,66 @@ class MappingExampleTest(unittest.TestCase):
|
|
137
165
|
pg.eq(pg.from_json_str(example.to_json_str()), example)
|
138
166
|
)
|
139
167
|
|
168
|
+
def assert_html_content(self, html, expected):
|
169
|
+
expected = inspect.cleandoc(expected).strip()
|
170
|
+
actual = html.content.strip()
|
171
|
+
if actual != expected:
|
172
|
+
print(actual)
|
173
|
+
self.assertEqual(actual, expected)
|
174
|
+
|
175
|
+
def test_html(self):
|
176
|
+
|
177
|
+
class Answer(pg.Object):
|
178
|
+
answer: int
|
179
|
+
|
180
|
+
class Addition(lf.Template):
|
181
|
+
"""Template Addition.
|
182
|
+
|
183
|
+
{{x}} + {{y}} = ?
|
184
|
+
"""
|
185
|
+
x: Any
|
186
|
+
y: Any
|
187
|
+
|
188
|
+
example = mapping.MappingExample(
|
189
|
+
input=Addition(x=1, y=2),
|
190
|
+
schema=Answer,
|
191
|
+
context='compute 1 + 1',
|
192
|
+
output=Answer(answer=3),
|
193
|
+
metadata={'foo': 'bar'},
|
194
|
+
)
|
195
|
+
self.assert_html_content(
|
196
|
+
example.to_html(
|
197
|
+
enable_summary_tooltip=False,
|
198
|
+
extra_flags=dict(
|
199
|
+
include_message_metadata=False
|
200
|
+
)
|
201
|
+
),
|
202
|
+
"""
|
203
|
+
<details open class="pyglove mapping-example"><summary><div class="summary-title">MappingExample(...)</div></summary><div class="complex-value mapping-example"><details open class="pyglove user-message lf-message"><summary><div class="summary-name lf-message">input<span class="tooltip lf-message">input</span></div><div class="summary-title lf-message">UserMessage(...)</div></summary><div class="complex_value"><div class="message-tags"><span>rendered</span></div><div class="message-text">1 + 2 = ?</div></div></details><details open class="pyglove answer lf-example-output"><summary><div class="summary-name lf-example-output">output<span class="tooltip lf-example-output">output</span></div><div class="summary-title lf-example-output">Answer(...)</div></summary><div class="complex-value answer"><details open class="pyglove int"><summary><div class="summary-name">answer<span class="tooltip">output.answer</span></div><div class="summary-title">int</div></summary><span class="simple-value int">3</span></details></div></details><details open class="pyglove str"><summary><div class="summary-name">context<span class="tooltip">context</span></div><div class="summary-title">str</div></summary><span class="simple-value str">'compute 1 + 1'</span></details><details open class="pyglove schema lf-example-schema"><summary><div class="summary-name lf-example-schema">schema<span class="tooltip lf-example-schema">schema</span></div><div class="summary-title lf-example-schema">Schema(...)</div></summary><div class="lf-schema-definition">Answer
|
204
|
+
|
205
|
+
```python
|
206
|
+
class Answer:
|
207
|
+
answer: int
|
208
|
+
```</div></details><details open class="pyglove dict lf-example-metadata"><summary><div class="summary-name lf-example-metadata">metadata<span class="tooltip lf-example-metadata">metadata</span></div><div class="summary-title lf-example-metadata">Dict(...)</div></summary><div class="complex-value dict"><details open class="pyglove str"><summary><div class="summary-name">foo<span class="tooltip">metadata.foo</span></div><div class="summary-title">str</div></summary><span class="simple-value str">'bar'</span></details></div></details></div></details>
|
209
|
+
"""
|
210
|
+
)
|
211
|
+
|
212
|
+
example = mapping.MappingExample(
|
213
|
+
input=Addition(x=1, y=2),
|
214
|
+
output=Answer(answer=3),
|
215
|
+
)
|
216
|
+
self.assert_html_content(
|
217
|
+
example.to_html(
|
218
|
+
enable_summary_tooltip=False,
|
219
|
+
extra_flags=dict(
|
220
|
+
include_message_metadata=False
|
221
|
+
)
|
222
|
+
),
|
223
|
+
"""
|
224
|
+
<details open class="pyglove mapping-example"><summary><div class="summary-title">MappingExample(...)</div></summary><div class="complex-value mapping-example"><details open class="pyglove user-message lf-message"><summary><div class="summary-name lf-message">input<span class="tooltip lf-message">input</span></div><div class="summary-title lf-message">UserMessage(...)</div></summary><div class="complex_value"><div class="message-tags"><span>rendered</span></div><div class="message-text">1 + 2 = ?</div></div></details><details open class="pyglove answer lf-example-output"><summary><div class="summary-name lf-example-output">output<span class="tooltip lf-example-output">output</span></div><div class="summary-title lf-example-output">Answer(...)</div></summary><div class="complex-value answer"><details open class="pyglove int"><summary><div class="summary-name">answer<span class="tooltip">output.answer</span></div><div class="summary-title">int</div></summary><span class="simple-value int">3</span></details></div></details><details open class="pyglove contextual-attribute lf-example-schema"><summary><div class="summary-name lf-example-schema">schema<span class="tooltip lf-example-schema">schema</span></div><div class="summary-title lf-example-schema">ContextualAttribute(...)</div></summary><span class="simple-value none-type">None</span></details><details open class="pyglove dict lf-example-metadata"><summary><div class="summary-name lf-example-metadata">metadata<span class="tooltip lf-example-metadata">metadata</span></div><div class="summary-title lf-example-metadata">Dict(...)</div></summary><div class="complex-value dict"><span class="empty-container"></span></div></details></div></details>
|
225
|
+
"""
|
226
|
+
)
|
227
|
+
|
140
228
|
|
141
229
|
if __name__ == '__main__':
|
142
230
|
unittest.main()
|
@@ -16,13 +16,13 @@ from typing import Any, Callable, Type, Union
|
|
16
16
|
|
17
17
|
import langfun.core as lf
|
18
18
|
from langfun.core.structured import mapping
|
19
|
-
from langfun.core.structured import
|
19
|
+
from langfun.core.structured import querying
|
20
20
|
from langfun.core.structured import schema as schema_lib
|
21
21
|
import pyglove as pg
|
22
22
|
|
23
23
|
|
24
24
|
@lf.use_init_args(['schema', 'default', 'examples'])
|
25
|
-
class
|
25
|
+
class _ParseStructure(mapping.Mapping):
|
26
26
|
"""Parse an object out from a natural language text."""
|
27
27
|
|
28
28
|
context_title = 'USER_REQUEST'
|
@@ -37,7 +37,7 @@ class ParseStructure(mapping.Mapping):
|
|
37
37
|
]
|
38
38
|
|
39
39
|
|
40
|
-
class
|
40
|
+
class _ParseStructureJson(_ParseStructure):
|
41
41
|
"""Parse an object out from a NL text using JSON as the protocol."""
|
42
42
|
|
43
43
|
preamble = """
|
@@ -53,7 +53,7 @@ class ParseStructureJson(ParseStructure):
|
|
53
53
|
output_title = 'JSON'
|
54
54
|
|
55
55
|
|
56
|
-
class
|
56
|
+
class _ParseStructurePython(_ParseStructure):
|
57
57
|
"""Parse an object out from a NL text using Python as the protocol."""
|
58
58
|
|
59
59
|
preamble = """
|
@@ -87,7 +87,7 @@ def parse(
|
|
87
87
|
returns_message: bool = False,
|
88
88
|
**kwargs,
|
89
89
|
) -> Any:
|
90
|
-
"""Parse a natural
|
90
|
+
"""Parse a natural language message based on schema.
|
91
91
|
|
92
92
|
Examples:
|
93
93
|
|
@@ -270,29 +270,41 @@ def call(
|
|
270
270
|
if schema in (str, None):
|
271
271
|
return lm_output if returns_message else lm_output.text
|
272
272
|
|
273
|
+
def _chain_nl_output_message(parsing_message: lf.Message):
|
274
|
+
"""Chain the source of the parsed output to the LM output."""
|
275
|
+
parsing_message.root.source = lm_output
|
276
|
+
parsing_message.tag('parsing-lm-output')
|
277
|
+
parsing_message.lm_input.tag('parsing-lm-input')
|
278
|
+
|
273
279
|
# Call `parsing_lm` for structured parsing.
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
280
|
+
try:
|
281
|
+
parsing_message = querying.query(
|
282
|
+
lm_output.text,
|
283
|
+
schema,
|
284
|
+
examples=parsing_examples,
|
285
|
+
lm=parsing_lm or lm,
|
286
|
+
include_context=parsing_include_context,
|
287
|
+
cache_seed=cache_seed,
|
288
|
+
autofix=autofix,
|
289
|
+
autofix_lm=autofix_lm or lm,
|
290
|
+
protocol=protocol,
|
291
|
+
returns_message=True,
|
292
|
+
**kwargs,
|
293
|
+
)
|
294
|
+
_chain_nl_output_message(parsing_message)
|
295
|
+
except mapping.MappingError as e:
|
296
|
+
_chain_nl_output_message(e.lm_response)
|
297
|
+
raise e
|
298
|
+
return parsing_message if returns_message else parsing_message.result
|
287
299
|
|
288
300
|
|
289
301
|
def _parse_structure_cls(
|
290
302
|
protocol: schema_lib.SchemaProtocol,
|
291
|
-
) -> Type[
|
303
|
+
) -> Type[_ParseStructure]:
|
292
304
|
if protocol == 'json':
|
293
|
-
return
|
305
|
+
return _ParseStructureJson
|
294
306
|
elif protocol == 'python':
|
295
|
-
return
|
307
|
+
return _ParseStructurePython
|
296
308
|
else:
|
297
309
|
raise ValueError(f'Unknown protocol: {protocol!r}.')
|
298
310
|
|
@@ -37,7 +37,7 @@ class Itinerary(pg.Object):
|
|
37
37
|
class ParseStructurePythonTest(unittest.TestCase):
|
38
38
|
|
39
39
|
def test_render_no_examples(self):
|
40
|
-
l = parsing.
|
40
|
+
l = parsing._ParseStructurePython(int)
|
41
41
|
m = lf.AIMessage('12 / 6 + 2 = 4')
|
42
42
|
self.assertEqual(
|
43
43
|
l.render(input=m, context='Compute 12 / 6 + 2.').text,
|
@@ -62,7 +62,7 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
62
62
|
)
|
63
63
|
|
64
64
|
def test_render_no_context(self):
|
65
|
-
l = parsing.
|
65
|
+
l = parsing._ParseStructurePython(int)
|
66
66
|
m = lf.AIMessage('12 / 6 + 2 = 4')
|
67
67
|
|
68
68
|
self.assertEqual(
|
@@ -85,7 +85,7 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
85
85
|
)
|
86
86
|
|
87
87
|
def test_render(self):
|
88
|
-
l = parsing.
|
88
|
+
l = parsing._ParseStructurePython(
|
89
89
|
int,
|
90
90
|
examples=[
|
91
91
|
mapping.MappingExample(
|
@@ -212,7 +212,7 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
212
212
|
),
|
213
213
|
override_attrs=True,
|
214
214
|
):
|
215
|
-
l = parsing.
|
215
|
+
l = parsing._ParseStructurePython(
|
216
216
|
[Itinerary],
|
217
217
|
examples=[
|
218
218
|
mapping.MappingExample(
|
@@ -285,7 +285,7 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
285
285
|
self.assertEqual(
|
286
286
|
r,
|
287
287
|
lf.AIMessage(
|
288
|
-
'1', score=1.0, result=1, logprobs=None,
|
288
|
+
'1', score=1.0, result=1, logprobs=None, is_cached=False,
|
289
289
|
usage=lf.LMSamplingUsage(652, 1, 653),
|
290
290
|
tags=['lm-response', 'lm-output', 'transformed']
|
291
291
|
),
|
@@ -295,7 +295,7 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
295
295
|
class ParseStructureJsonTest(unittest.TestCase):
|
296
296
|
|
297
297
|
def test_render_no_examples(self):
|
298
|
-
l = parsing.
|
298
|
+
l = parsing._ParseStructureJson(int)
|
299
299
|
m = lf.AIMessage('12 / 6 + 2 = 4')
|
300
300
|
self.assertEqual(
|
301
301
|
l.render(input=m, context='Compute 12 / 6 + 2.').text,
|
@@ -320,7 +320,7 @@ class ParseStructureJsonTest(unittest.TestCase):
|
|
320
320
|
)
|
321
321
|
|
322
322
|
def test_render_no_context(self):
|
323
|
-
l = parsing.
|
323
|
+
l = parsing._ParseStructureJson(int)
|
324
324
|
m = lf.AIMessage('12 / 6 + 2 = 4')
|
325
325
|
|
326
326
|
self.assertEqual(
|
@@ -343,7 +343,7 @@ class ParseStructureJsonTest(unittest.TestCase):
|
|
343
343
|
)
|
344
344
|
|
345
345
|
def test_render(self):
|
346
|
-
l = parsing.
|
346
|
+
l = parsing._ParseStructureJson(
|
347
347
|
int,
|
348
348
|
examples=[
|
349
349
|
mapping.MappingExample(
|
@@ -504,7 +504,7 @@ class ParseStructureJsonTest(unittest.TestCase):
|
|
504
504
|
override_attrs=True,
|
505
505
|
):
|
506
506
|
message = lf.LangFunc(lm_input)()
|
507
|
-
l = parsing.
|
507
|
+
l = parsing._ParseStructureJson(
|
508
508
|
[Itinerary],
|
509
509
|
examples=[
|
510
510
|
mapping.MappingExample(
|
@@ -645,6 +645,7 @@ class CallTest(unittest.TestCase):
|
|
645
645
|
result=3,
|
646
646
|
score=1.0,
|
647
647
|
logprobs=None,
|
648
|
+
is_cached=False,
|
648
649
|
usage=lf.LMSamplingUsage(315, 1, 316),
|
649
650
|
tags=['lm-response', 'lm-output', 'transformed']
|
650
651
|
),
|
@@ -669,6 +670,49 @@ class CallTest(unittest.TestCase):
|
|
669
670
|
3,
|
670
671
|
)
|
671
672
|
|
673
|
+
def test_call_with_parsing_message_chaining(self):
|
674
|
+
output = parsing.call(
|
675
|
+
'Compute 1 + 2',
|
676
|
+
int,
|
677
|
+
lm=fake.StaticSequence(['three']),
|
678
|
+
parsing_lm=fake.StaticSequence(['3']),
|
679
|
+
parsing_examples=[
|
680
|
+
mapping.MappingExample(
|
681
|
+
context='Multiple four and five',
|
682
|
+
input='twenty',
|
683
|
+
schema=int,
|
684
|
+
output=20,
|
685
|
+
)
|
686
|
+
],
|
687
|
+
returns_message=True,
|
688
|
+
)
|
689
|
+
self.assertIn('parsing-lm-output', output.tags)
|
690
|
+
self.assertIn('parsing-lm-input', output.source.tags)
|
691
|
+
self.assertEqual(output.root.text, 'Compute 1 + 2')
|
692
|
+
|
693
|
+
def test_call_with_parsing_message_chaining_on_parsing_error(self):
|
694
|
+
try:
|
695
|
+
output = parsing.call(
|
696
|
+
'Compute 1 + 2',
|
697
|
+
int,
|
698
|
+
lm=fake.StaticSequence(['three']),
|
699
|
+
parsing_lm=fake.StaticSequence(['abc']),
|
700
|
+
parsing_examples=[
|
701
|
+
mapping.MappingExample(
|
702
|
+
context='Multiple four and five',
|
703
|
+
input='twenty',
|
704
|
+
schema=int,
|
705
|
+
output=20,
|
706
|
+
)
|
707
|
+
],
|
708
|
+
returns_message=True,
|
709
|
+
)
|
710
|
+
except mapping.MappingError as e:
|
711
|
+
output = e.lm_response
|
712
|
+
self.assertIn('parsing-lm-output', output.tags)
|
713
|
+
self.assertIn('parsing-lm-input', output.source.tags)
|
714
|
+
self.assertEqual(output.root.text, 'Compute 1 + 2')
|
715
|
+
|
672
716
|
def test_call_with_autofix(self):
|
673
717
|
lm = fake.StaticSequence(
|
674
718
|
[
|