langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +22 -2
- langfun/core/__init__.py +17 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -28
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +69 -2
- langfun/core/component_test.py +54 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +17 -0
- langfun/core/eval/base.py +767 -140
- langfun/core/eval/base_test.py +238 -53
- langfun/core/eval/matching.py +80 -76
- langfun/core/eval/matching_test.py +19 -9
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +37 -28
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +3 -21
- langfun/core/langfunc_test.py +26 -8
- langfun/core/language_model.py +686 -48
- langfun/core/language_model_test.py +681 -44
- langfun/core/llms/__init__.py +100 -12
- langfun/core/llms/anthropic.py +488 -0
- langfun/core/llms/anthropic_test.py +235 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +88 -28
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +39 -26
- langfun/core/llms/fake_test.py +136 -11
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -197
- langfun/core/llms/groq.py +276 -0
- langfun/core/llms/groq_test.py +64 -0
- langfun/core/llms/llama_cpp.py +15 -40
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +436 -226
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +35 -174
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -23
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +15 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +9 -8
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +278 -0
- langfun/core/structured/function_generation_test.py +399 -0
- langfun/core/structured/mapping.py +150 -46
- langfun/core/structured/mapping_test.py +105 -0
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +71 -22
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
- langfun/core/structured/schema.py +208 -99
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +133 -34
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +240 -11
- langfun/core/template_test.py +146 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +14 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -217
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -13,15 +13,55 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""The base of symbolic mapping methods."""
|
15
15
|
|
16
|
+
import functools
|
16
17
|
import io
|
17
|
-
from typing import Annotated, Any
|
18
|
+
from typing import Annotated, Any, Callable
|
18
19
|
import langfun.core as lf
|
19
20
|
from langfun.core.structured import schema as schema_lib
|
20
21
|
import pyglove as pg
|
21
22
|
|
22
23
|
|
24
|
+
class MappingError(Exception): # pylint: disable=g-bad-exception-name
|
25
|
+
"""Mapping error."""
|
26
|
+
|
27
|
+
def __init__(self, lm_response: lf.Message, cause: Exception):
|
28
|
+
self._lm_response = lm_response
|
29
|
+
self._cause = cause
|
30
|
+
|
31
|
+
@property
|
32
|
+
def lm_response(self) -> lf.Message:
|
33
|
+
"""Returns the LM response that failed to be mapped."""
|
34
|
+
return self._lm_response
|
35
|
+
|
36
|
+
@property
|
37
|
+
def cause(self) -> Exception:
|
38
|
+
"""Returns the cause of the error."""
|
39
|
+
return self._cause
|
40
|
+
|
41
|
+
def __str__(self) -> str:
|
42
|
+
return self.format(include_lm_response=True)
|
43
|
+
|
44
|
+
def format(self, include_lm_response: bool = True) -> str:
|
45
|
+
"""Formats the mapping error."""
|
46
|
+
r = io.StringIO()
|
47
|
+
error_message = str(self.cause).rstrip()
|
48
|
+
r.write(
|
49
|
+
pg.colored(
|
50
|
+
f'{self.cause.__class__.__name__}: {error_message}', 'magenta'
|
51
|
+
)
|
52
|
+
)
|
53
|
+
if include_lm_response:
|
54
|
+
r.write('\n\n')
|
55
|
+
r.write(pg.colored('[LM Response]', 'blue', styles=['bold']))
|
56
|
+
r.write('\n')
|
57
|
+
r.write(pg.colored(self.lm_response.text, 'blue'))
|
58
|
+
return r.getvalue()
|
59
|
+
|
60
|
+
|
23
61
|
@pg.use_init_args(['input', 'output', 'schema', 'context'])
|
24
|
-
class MappingExample(lf.NaturalLanguageFormattable,
|
62
|
+
class MappingExample(lf.NaturalLanguageFormattable,
|
63
|
+
lf.Component,
|
64
|
+
pg.views.HtmlTreeView.Extension):
|
25
65
|
"""Mapping example between text, schema and structured value."""
|
26
66
|
|
27
67
|
input: pg.typing.Annotated[
|
@@ -55,6 +95,15 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
|
|
55
95
|
'The natural language context for this mapping. ',
|
56
96
|
] = None
|
57
97
|
|
98
|
+
metadata: Annotated[
|
99
|
+
dict[str, Any],
|
100
|
+
(
|
101
|
+
'The metadata associated with the mapping example, '
|
102
|
+
'which chould carry structured data, such as tool function input. '
|
103
|
+
'It is a `pg.Dict` object whose keys can be accessed by attributes.'
|
104
|
+
),
|
105
|
+
] = pg.Dict()
|
106
|
+
|
58
107
|
def schema_repr(
|
59
108
|
self, protocol: schema_lib.SchemaProtocol = 'python', **kwargs
|
60
109
|
) -> str:
|
@@ -70,7 +119,11 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
|
|
70
119
|
|
71
120
|
@classmethod
|
72
121
|
def value_repr(
|
73
|
-
cls,
|
122
|
+
cls,
|
123
|
+
value: Any,
|
124
|
+
protocol: schema_lib.SchemaProtocol = 'python',
|
125
|
+
use_modality_ref: bool = False,
|
126
|
+
**kwargs
|
74
127
|
) -> str:
|
75
128
|
if isinstance(value, str):
|
76
129
|
return value
|
@@ -79,7 +132,7 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
|
|
79
132
|
return str(value)
|
80
133
|
|
81
134
|
# Placehold modalities if they are present.
|
82
|
-
if pg.contains(value, type=lf.Modality):
|
135
|
+
if use_modality_ref and pg.contains(value, type=lf.Modality):
|
83
136
|
value = lf.ModalityRef.placehold(value)
|
84
137
|
return schema_lib.value_repr(protocol).repr(value, **kwargs)
|
85
138
|
|
@@ -110,24 +163,83 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
|
|
110
163
|
def natural_language_format(self) -> str:
|
111
164
|
result = io.StringIO()
|
112
165
|
if self.context:
|
113
|
-
result.write(
|
114
|
-
result.write(
|
166
|
+
result.write(pg.colored('[CONTEXT]\n', styles=['bold']))
|
167
|
+
result.write(pg.colored(self.context, color='magenta'))
|
115
168
|
result.write('\n\n')
|
116
169
|
|
117
|
-
result.write(
|
118
|
-
result.write(
|
119
|
-
result.write('\n\n')
|
170
|
+
result.write(pg.colored('[INPUT]\n', styles=['bold']))
|
171
|
+
result.write(pg.colored(self.input_repr(), color='green'))
|
120
172
|
|
121
173
|
if self.schema is not None:
|
122
|
-
result.write(lf.colored('[SCHEMA]\n', styles=['bold']))
|
123
|
-
result.write(lf.colored(self.schema_repr(), color='red'))
|
124
174
|
result.write('\n\n')
|
175
|
+
result.write(pg.colored('[SCHEMA]\n', styles=['bold']))
|
176
|
+
result.write(pg.colored(self.schema_repr(), color='red'))
|
125
177
|
|
126
178
|
if schema_lib.MISSING != self.output:
|
127
|
-
result.write(
|
128
|
-
result.write(
|
179
|
+
result.write('\n\n')
|
180
|
+
result.write(pg.colored('[OUTPUT]\n', styles=['bold']))
|
181
|
+
result.write(pg.colored(self.output_repr(), color='blue'))
|
182
|
+
|
183
|
+
if self.metadata:
|
184
|
+
result.write('\n\n')
|
185
|
+
result.write(pg.colored('[METADATA]\n', styles=['bold']))
|
186
|
+
result.write(pg.colored(str(self.metadata), color='cyan'))
|
129
187
|
return result.getvalue().strip()
|
130
188
|
|
189
|
+
@classmethod
|
190
|
+
@functools.cache
|
191
|
+
def _html_tree_view_config(cls) -> dict[str, Any]:
|
192
|
+
|
193
|
+
def render_value(view, *, value, **kwargs):
|
194
|
+
if isinstance(value, lf.Template):
|
195
|
+
# Make a shallow copy to make sure modalities are rooted by
|
196
|
+
# the input.
|
197
|
+
value = value.clone().render()
|
198
|
+
if value is None:
|
199
|
+
return None
|
200
|
+
return view.render(value, **kwargs)
|
201
|
+
|
202
|
+
return pg.views.HtmlTreeView.get_kwargs(
|
203
|
+
super()._html_tree_view_config(),
|
204
|
+
dict(
|
205
|
+
include_keys=['input', 'output', 'context', 'schema', 'metadata'],
|
206
|
+
extra_flags=dict(
|
207
|
+
render_value_fn=render_value,
|
208
|
+
),
|
209
|
+
child_config=dict(
|
210
|
+
input=dict(
|
211
|
+
collapse_level=1,
|
212
|
+
),
|
213
|
+
output=dict(
|
214
|
+
css_classes=['lf-example-output'],
|
215
|
+
collapse_level=1,
|
216
|
+
),
|
217
|
+
schema=dict(
|
218
|
+
css_classes=['lf-example-schema'],
|
219
|
+
collapse_level=1,
|
220
|
+
),
|
221
|
+
metadata=dict(
|
222
|
+
css_classes=['lf-example-metadata'],
|
223
|
+
collapse_level=1,
|
224
|
+
),
|
225
|
+
),
|
226
|
+
)
|
227
|
+
)
|
228
|
+
|
229
|
+
@classmethod
|
230
|
+
@functools.cache
|
231
|
+
def _html_tree_view_css_styles(cls) -> list[str]:
|
232
|
+
return super()._html_tree_view_css_styles() + [
|
233
|
+
"""
|
234
|
+
.lf-example-output {
|
235
|
+
color: dodgerblue;
|
236
|
+
}
|
237
|
+
.lf-example-schema {
|
238
|
+
color: blue;
|
239
|
+
}
|
240
|
+
"""
|
241
|
+
]
|
242
|
+
|
131
243
|
|
132
244
|
class Mapping(lf.LangFunc):
|
133
245
|
"""Base class for mapping.
|
@@ -206,13 +318,7 @@ class Mapping(lf.LangFunc):
|
|
206
318
|
{{ input_title }}:
|
207
319
|
{{ example.input_repr(protocol, compact=False) | indent(2, True) }}
|
208
320
|
|
209
|
-
{% if
|
210
|
-
{{ modality_refs_title }}:
|
211
|
-
{{ modality_refs_repr(example.input) | indent(2, True) }}
|
212
|
-
|
213
|
-
{% endif -%}
|
214
|
-
|
215
|
-
{%- if example.schema -%}
|
321
|
+
{% if example.schema -%}
|
216
322
|
{{ schema_title }}:
|
217
323
|
{{ example.schema_repr(protocol) | indent(2, True) }}
|
218
324
|
|
@@ -233,10 +339,6 @@ class Mapping(lf.LangFunc):
|
|
233
339
|
|
234
340
|
schema_title: Annotated[str, 'The section title for schema.'] = 'SCHEMA'
|
235
341
|
|
236
|
-
modality_refs_title: Annotated[
|
237
|
-
str, 'The section title for modality refs.'
|
238
|
-
] = 'MODALITY_REFERENCES'
|
239
|
-
|
240
342
|
protocol: Annotated[
|
241
343
|
schema_lib.SchemaProtocol,
|
242
344
|
'The protocol for representing the schema and value.',
|
@@ -278,6 +380,14 @@ class Mapping(lf.LangFunc):
|
|
278
380
|
),
|
279
381
|
] = lf.RAISE_IF_HAS_ERROR
|
280
382
|
|
383
|
+
response_postprocess: Annotated[
|
384
|
+
Callable[[str], str] | None,
|
385
|
+
(
|
386
|
+
'A callable object that post process the raw LLM response before '
|
387
|
+
'parsing it into the output Python object.'
|
388
|
+
)
|
389
|
+
] = None
|
390
|
+
|
281
391
|
#
|
282
392
|
# Key methods for implementing specific mappings.
|
283
393
|
#
|
@@ -296,10 +406,17 @@ class Mapping(lf.LangFunc):
|
|
296
406
|
def transform_output(self, lm_output: lf.Message) -> lf.Message:
|
297
407
|
"""Transforms LM response into structure if schema is present."""
|
298
408
|
try:
|
409
|
+
lm_output = self.postprocess_response(lm_output)
|
299
410
|
lm_output.result = self.postprocess_result(self.parse_result(lm_output))
|
300
411
|
except Exception as e: # pylint: disable=broad-exception-caught
|
412
|
+
if (self.lm.cache is not None
|
413
|
+
and lm_output.lm_input.cache_seed is not None):
|
414
|
+
success = self.lm.cache.delete(
|
415
|
+
self.lm, lm_output.lm_input, lm_output.lm_input.cache_seed
|
416
|
+
)
|
417
|
+
assert success
|
301
418
|
if self.default == lf.RAISE_IF_HAS_ERROR:
|
302
|
-
raise e
|
419
|
+
raise MappingError(lm_output, e) from e
|
303
420
|
lm_output.result = self.default
|
304
421
|
return lm_output
|
305
422
|
|
@@ -316,6 +433,14 @@ class Mapping(lf.LangFunc):
|
|
316
433
|
autofix_lm=self.autofix_lm or self.lm,
|
317
434
|
)
|
318
435
|
|
436
|
+
def postprocess_response(self, response: lf.Message) -> lf.Message:
|
437
|
+
"""Post process LLM response."""
|
438
|
+
if self.response_postprocess is not None:
|
439
|
+
postprocessed_text = self.response_postprocess(response.text)
|
440
|
+
if postprocessed_text != response.text:
|
441
|
+
return lf.AIMessage(postprocessed_text, source=response)
|
442
|
+
return response
|
443
|
+
|
319
444
|
def postprocess_result(self, result: Any) -> Any:
|
320
445
|
"""Post process structured output."""
|
321
446
|
return result
|
@@ -324,24 +449,3 @@ class Mapping(lf.LangFunc):
|
|
324
449
|
"""Gets additional symbol definitions besides schema as globals."""
|
325
450
|
return {'ModalityRef': lf.modality.ModalityRef}
|
326
451
|
|
327
|
-
#
|
328
|
-
# Helper methods for handling modalities.
|
329
|
-
#
|
330
|
-
|
331
|
-
def has_modality_refs(self, value: Any) -> bool:
|
332
|
-
"""Returns true if the value has modalities."""
|
333
|
-
return not isinstance(value, lf.Modality) and pg.contains(
|
334
|
-
value, type=lf.Modality
|
335
|
-
)
|
336
|
-
|
337
|
-
def modalities(self, value: Any) -> dict[str, lf.Modality]:
|
338
|
-
return lf.Modality.from_value(value)
|
339
|
-
|
340
|
-
def modality_refs_repr(self, value: Any) -> str:
|
341
|
-
with lf.modality.format_modality_as_ref(True):
|
342
|
-
return pg.format(
|
343
|
-
self.modalities(value),
|
344
|
-
compact=False,
|
345
|
-
verbose=False,
|
346
|
-
python_format=True,
|
347
|
-
)
|
@@ -14,12 +14,30 @@
|
|
14
14
|
"""Tests for structured mapping example."""
|
15
15
|
|
16
16
|
import inspect
|
17
|
+
from typing import Any
|
17
18
|
import unittest
|
18
19
|
|
20
|
+
import langfun.core as lf
|
19
21
|
from langfun.core.structured import mapping
|
20
22
|
import pyglove as pg
|
21
23
|
|
22
24
|
|
25
|
+
class MappingErrorTest(unittest.TestCase):
|
26
|
+
|
27
|
+
def test_format(self):
|
28
|
+
error = mapping.MappingError(
|
29
|
+
lf.AIMessage('hi'), ValueError('Cannot parse message.')
|
30
|
+
)
|
31
|
+
self.assertEqual(
|
32
|
+
pg.decolor(str(error)),
|
33
|
+
'ValueError: Cannot parse message.\n\n[LM Response]\nhi',
|
34
|
+
)
|
35
|
+
self.assertEqual(
|
36
|
+
pg.decolor(error.format(include_lm_response=False)),
|
37
|
+
'ValueError: Cannot parse message.',
|
38
|
+
)
|
39
|
+
|
40
|
+
|
23
41
|
class MappingExampleTest(unittest.TestCase):
|
24
42
|
|
25
43
|
def test_basics(self):
|
@@ -112,6 +130,33 @@ class MappingExampleTest(unittest.TestCase):
|
|
112
130
|
"""),
|
113
131
|
)
|
114
132
|
|
133
|
+
def test_str_with_metadata(self):
|
134
|
+
self.assertEqual(
|
135
|
+
str(
|
136
|
+
mapping.MappingExample(
|
137
|
+
'1 + 1 = 2',
|
138
|
+
schema=int,
|
139
|
+
context='Give the answer.',
|
140
|
+
metadata={'foo': 'bar'},
|
141
|
+
)
|
142
|
+
),
|
143
|
+
inspect.cleandoc("""
|
144
|
+
\x1b[1m[CONTEXT]
|
145
|
+
\x1b[0m\x1b[35mGive the answer.\x1b[0m
|
146
|
+
|
147
|
+
\x1b[1m[INPUT]
|
148
|
+
\x1b[0m\x1b[32m1 + 1 = 2\x1b[0m
|
149
|
+
|
150
|
+
\x1b[1m[SCHEMA]
|
151
|
+
\x1b[0m\x1b[31mint\x1b[0m
|
152
|
+
|
153
|
+
\x1b[1m[METADATA]
|
154
|
+
\x1b[0m\x1b[36m{
|
155
|
+
foo = 'bar'
|
156
|
+
}\x1b[0m
|
157
|
+
"""),
|
158
|
+
)
|
159
|
+
|
115
160
|
def test_serialization(self):
|
116
161
|
example = mapping.MappingExample(
|
117
162
|
'the answer is 2', 2, int, context='compute 1 + 1'
|
@@ -120,6 +165,66 @@ class MappingExampleTest(unittest.TestCase):
|
|
120
165
|
pg.eq(pg.from_json_str(example.to_json_str()), example)
|
121
166
|
)
|
122
167
|
|
168
|
+
def assert_html_content(self, html, expected):
|
169
|
+
expected = inspect.cleandoc(expected).strip()
|
170
|
+
actual = html.content.strip()
|
171
|
+
if actual != expected:
|
172
|
+
print(actual)
|
173
|
+
self.assertEqual(actual, expected)
|
174
|
+
|
175
|
+
def test_html(self):
|
176
|
+
|
177
|
+
class Answer(pg.Object):
|
178
|
+
answer: int
|
179
|
+
|
180
|
+
class Addition(lf.Template):
|
181
|
+
"""Template Addition.
|
182
|
+
|
183
|
+
{{x}} + {{y}} = ?
|
184
|
+
"""
|
185
|
+
x: Any
|
186
|
+
y: Any
|
187
|
+
|
188
|
+
example = mapping.MappingExample(
|
189
|
+
input=Addition(x=1, y=2),
|
190
|
+
schema=Answer,
|
191
|
+
context='compute 1 + 1',
|
192
|
+
output=Answer(answer=3),
|
193
|
+
metadata={'foo': 'bar'},
|
194
|
+
)
|
195
|
+
self.assert_html_content(
|
196
|
+
example.to_html(
|
197
|
+
enable_summary_tooltip=False,
|
198
|
+
extra_flags=dict(
|
199
|
+
include_message_metadata=False
|
200
|
+
)
|
201
|
+
),
|
202
|
+
"""
|
203
|
+
<details open class="pyglove mapping-example"><summary><div class="summary-title">MappingExample(...)</div></summary><div class="complex-value mapping-example"><details open class="pyglove user-message lf-message"><summary><div class="summary-name lf-message">input<span class="tooltip lf-message">input</span></div><div class="summary-title lf-message">UserMessage(...)</div></summary><div class="complex_value"><div class="message-tags"><span>rendered</span></div><div class="message-text">1 + 2 = ?</div></div></details><details open class="pyglove answer lf-example-output"><summary><div class="summary-name lf-example-output">output<span class="tooltip lf-example-output">output</span></div><div class="summary-title lf-example-output">Answer(...)</div></summary><div class="complex-value answer"><details open class="pyglove int"><summary><div class="summary-name">answer<span class="tooltip">output.answer</span></div><div class="summary-title">int</div></summary><span class="simple-value int">3</span></details></div></details><details open class="pyglove str"><summary><div class="summary-name">context<span class="tooltip">context</span></div><div class="summary-title">str</div></summary><span class="simple-value str">'compute 1 + 1'</span></details><details open class="pyglove schema lf-example-schema"><summary><div class="summary-name lf-example-schema">schema<span class="tooltip lf-example-schema">schema</span></div><div class="summary-title lf-example-schema">Schema(...)</div></summary><div class="lf-schema-definition">Answer
|
204
|
+
|
205
|
+
```python
|
206
|
+
class Answer:
|
207
|
+
answer: int
|
208
|
+
```</div></details><details open class="pyglove dict lf-example-metadata"><summary><div class="summary-name lf-example-metadata">metadata<span class="tooltip lf-example-metadata">metadata</span></div><div class="summary-title lf-example-metadata">Dict(...)</div></summary><div class="complex-value dict"><details open class="pyglove str"><summary><div class="summary-name">foo<span class="tooltip">metadata.foo</span></div><div class="summary-title">str</div></summary><span class="simple-value str">'bar'</span></details></div></details></div></details>
|
209
|
+
"""
|
210
|
+
)
|
211
|
+
|
212
|
+
example = mapping.MappingExample(
|
213
|
+
input=Addition(x=1, y=2),
|
214
|
+
output=Answer(answer=3),
|
215
|
+
)
|
216
|
+
self.assert_html_content(
|
217
|
+
example.to_html(
|
218
|
+
enable_summary_tooltip=False,
|
219
|
+
extra_flags=dict(
|
220
|
+
include_message_metadata=False
|
221
|
+
)
|
222
|
+
),
|
223
|
+
"""
|
224
|
+
<details open class="pyglove mapping-example"><summary><div class="summary-title">MappingExample(...)</div></summary><div class="complex-value mapping-example"><details open class="pyglove user-message lf-message"><summary><div class="summary-name lf-message">input<span class="tooltip lf-message">input</span></div><div class="summary-title lf-message">UserMessage(...)</div></summary><div class="complex_value"><div class="message-tags"><span>rendered</span></div><div class="message-text">1 + 2 = ?</div></div></details><details open class="pyglove answer lf-example-output"><summary><div class="summary-name lf-example-output">output<span class="tooltip lf-example-output">output</span></div><div class="summary-title lf-example-output">Answer(...)</div></summary><div class="complex-value answer"><details open class="pyglove int"><summary><div class="summary-name">answer<span class="tooltip">output.answer</span></div><div class="summary-title">int</div></summary><span class="simple-value int">3</span></details></div></details><details open class="pyglove contextual-attribute lf-example-schema"><summary><div class="summary-name lf-example-schema">schema<span class="tooltip lf-example-schema">schema</span></div><div class="summary-title lf-example-schema">ContextualAttribute(...)</div></summary><span class="simple-value none-type">None</span></details><details open class="pyglove dict lf-example-metadata"><summary><div class="summary-name lf-example-metadata">metadata<span class="tooltip lf-example-metadata">metadata</span></div><div class="summary-title lf-example-metadata">Dict(...)</div></summary><div class="complex-value dict"><span class="empty-container"></span></div></details></div></details>
|
225
|
+
"""
|
226
|
+
)
|
227
|
+
|
123
228
|
|
124
229
|
if __name__ == '__main__':
|
125
230
|
unittest.main()
|
@@ -16,13 +16,13 @@ from typing import Any, Callable, Type, Union
|
|
16
16
|
|
17
17
|
import langfun.core as lf
|
18
18
|
from langfun.core.structured import mapping
|
19
|
-
from langfun.core.structured import
|
19
|
+
from langfun.core.structured import querying
|
20
20
|
from langfun.core.structured import schema as schema_lib
|
21
21
|
import pyglove as pg
|
22
22
|
|
23
23
|
|
24
24
|
@lf.use_init_args(['schema', 'default', 'examples'])
|
25
|
-
class
|
25
|
+
class _ParseStructure(mapping.Mapping):
|
26
26
|
"""Parse an object out from a natural language text."""
|
27
27
|
|
28
28
|
context_title = 'USER_REQUEST'
|
@@ -37,7 +37,7 @@ class ParseStructure(mapping.Mapping):
|
|
37
37
|
]
|
38
38
|
|
39
39
|
|
40
|
-
class
|
40
|
+
class _ParseStructureJson(_ParseStructure):
|
41
41
|
"""Parse an object out from a NL text using JSON as the protocol."""
|
42
42
|
|
43
43
|
preamble = """
|
@@ -53,7 +53,7 @@ class ParseStructureJson(ParseStructure):
|
|
53
53
|
output_title = 'JSON'
|
54
54
|
|
55
55
|
|
56
|
-
class
|
56
|
+
class _ParseStructurePython(_ParseStructure):
|
57
57
|
"""Parse an object out from a NL text using Python as the protocol."""
|
58
58
|
|
59
59
|
preamble = """
|
@@ -87,7 +87,7 @@ def parse(
|
|
87
87
|
returns_message: bool = False,
|
88
88
|
**kwargs,
|
89
89
|
) -> Any:
|
90
|
-
"""Parse a natural
|
90
|
+
"""Parse a natural language message based on schema.
|
91
91
|
|
92
92
|
Examples:
|
93
93
|
|
@@ -270,29 +270,41 @@ def call(
|
|
270
270
|
if schema in (str, None):
|
271
271
|
return lm_output if returns_message else lm_output.text
|
272
272
|
|
273
|
+
def _chain_nl_output_message(parsing_message: lf.Message):
|
274
|
+
"""Chain the source of the parsed output to the LM output."""
|
275
|
+
parsing_message.root.source = lm_output
|
276
|
+
parsing_message.tag('parsing-lm-output')
|
277
|
+
parsing_message.lm_input.tag('parsing-lm-input')
|
278
|
+
|
273
279
|
# Call `parsing_lm` for structured parsing.
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
280
|
+
try:
|
281
|
+
parsing_message = querying.query(
|
282
|
+
lm_output.text,
|
283
|
+
schema,
|
284
|
+
examples=parsing_examples,
|
285
|
+
lm=parsing_lm or lm,
|
286
|
+
include_context=parsing_include_context,
|
287
|
+
cache_seed=cache_seed,
|
288
|
+
autofix=autofix,
|
289
|
+
autofix_lm=autofix_lm or lm,
|
290
|
+
protocol=protocol,
|
291
|
+
returns_message=True,
|
292
|
+
**kwargs,
|
293
|
+
)
|
294
|
+
_chain_nl_output_message(parsing_message)
|
295
|
+
except mapping.MappingError as e:
|
296
|
+
_chain_nl_output_message(e.lm_response)
|
297
|
+
raise e
|
298
|
+
return parsing_message if returns_message else parsing_message.result
|
287
299
|
|
288
300
|
|
289
301
|
def _parse_structure_cls(
|
290
302
|
protocol: schema_lib.SchemaProtocol,
|
291
|
-
) -> Type[
|
303
|
+
) -> Type[_ParseStructure]:
|
292
304
|
if protocol == 'json':
|
293
|
-
return
|
305
|
+
return _ParseStructureJson
|
294
306
|
elif protocol == 'python':
|
295
|
-
return
|
307
|
+
return _ParseStructurePython
|
296
308
|
else:
|
297
309
|
raise ValueError(f'Unknown protocol: {protocol!r}.')
|
298
310
|
|