langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -13,15 +13,55 @@
13
13
  # limitations under the License.
14
14
  """The base of symbolic mapping methods."""
15
15
 
16
+ import functools
16
17
  import io
17
- from typing import Annotated, Any
18
+ from typing import Annotated, Any, Callable
18
19
  import langfun.core as lf
19
20
  from langfun.core.structured import schema as schema_lib
20
21
  import pyglove as pg
21
22
 
22
23
 
24
+ class MappingError(Exception): # pylint: disable=g-bad-exception-name
25
+ """Mapping error."""
26
+
27
+ def __init__(self, lm_response: lf.Message, cause: Exception):
28
+ self._lm_response = lm_response
29
+ self._cause = cause
30
+
31
+ @property
32
+ def lm_response(self) -> lf.Message:
33
+ """Returns the LM response that failed to be mapped."""
34
+ return self._lm_response
35
+
36
+ @property
37
+ def cause(self) -> Exception:
38
+ """Returns the cause of the error."""
39
+ return self._cause
40
+
41
+ def __str__(self) -> str:
42
+ return self.format(include_lm_response=True)
43
+
44
+ def format(self, include_lm_response: bool = True) -> str:
45
+ """Formats the mapping error."""
46
+ r = io.StringIO()
47
+ error_message = str(self.cause).rstrip()
48
+ r.write(
49
+ pg.colored(
50
+ f'{self.cause.__class__.__name__}: {error_message}', 'magenta'
51
+ )
52
+ )
53
+ if include_lm_response:
54
+ r.write('\n\n')
55
+ r.write(pg.colored('[LM Response]', 'blue', styles=['bold']))
56
+ r.write('\n')
57
+ r.write(pg.colored(self.lm_response.text, 'blue'))
58
+ return r.getvalue()
59
+
60
+
23
61
  @pg.use_init_args(['input', 'output', 'schema', 'context'])
24
- class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
62
+ class MappingExample(lf.NaturalLanguageFormattable,
63
+ lf.Component,
64
+ pg.views.HtmlTreeView.Extension):
25
65
  """Mapping example between text, schema and structured value."""
26
66
 
27
67
  input: pg.typing.Annotated[
@@ -55,6 +95,15 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
55
95
  'The natural language context for this mapping. ',
56
96
  ] = None
57
97
 
98
+ metadata: Annotated[
99
+ dict[str, Any],
100
+ (
101
+ 'The metadata associated with the mapping example, '
102
+ 'which chould carry structured data, such as tool function input. '
103
+ 'It is a `pg.Dict` object whose keys can be accessed by attributes.'
104
+ ),
105
+ ] = pg.Dict()
106
+
58
107
  def schema_repr(
59
108
  self, protocol: schema_lib.SchemaProtocol = 'python', **kwargs
60
109
  ) -> str:
@@ -70,7 +119,11 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
70
119
 
71
120
  @classmethod
72
121
  def value_repr(
73
- cls, value: Any, protocol: schema_lib.SchemaProtocol = 'python', **kwargs
122
+ cls,
123
+ value: Any,
124
+ protocol: schema_lib.SchemaProtocol = 'python',
125
+ use_modality_ref: bool = False,
126
+ **kwargs
74
127
  ) -> str:
75
128
  if isinstance(value, str):
76
129
  return value
@@ -79,7 +132,7 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
79
132
  return str(value)
80
133
 
81
134
  # Placehold modalities if they are present.
82
- if pg.contains(value, type=lf.Modality):
135
+ if use_modality_ref and pg.contains(value, type=lf.Modality):
83
136
  value = lf.ModalityRef.placehold(value)
84
137
  return schema_lib.value_repr(protocol).repr(value, **kwargs)
85
138
 
@@ -110,24 +163,83 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
110
163
  def natural_language_format(self) -> str:
111
164
  result = io.StringIO()
112
165
  if self.context:
113
- result.write(lf.colored('[CONTEXT]\n', styles=['bold']))
114
- result.write(lf.colored(self.context, color='magenta'))
166
+ result.write(pg.colored('[CONTEXT]\n', styles=['bold']))
167
+ result.write(pg.colored(self.context, color='magenta'))
115
168
  result.write('\n\n')
116
169
 
117
- result.write(lf.colored('[INPUT]\n', styles=['bold']))
118
- result.write(lf.colored(self.input_repr(), color='green'))
119
- result.write('\n\n')
170
+ result.write(pg.colored('[INPUT]\n', styles=['bold']))
171
+ result.write(pg.colored(self.input_repr(), color='green'))
120
172
 
121
173
  if self.schema is not None:
122
- result.write(lf.colored('[SCHEMA]\n', styles=['bold']))
123
- result.write(lf.colored(self.schema_repr(), color='red'))
124
174
  result.write('\n\n')
175
+ result.write(pg.colored('[SCHEMA]\n', styles=['bold']))
176
+ result.write(pg.colored(self.schema_repr(), color='red'))
125
177
 
126
178
  if schema_lib.MISSING != self.output:
127
- result.write(lf.colored('[OUTPUT]\n', styles=['bold']))
128
- result.write(lf.colored(self.output_repr(), color='blue'))
179
+ result.write('\n\n')
180
+ result.write(pg.colored('[OUTPUT]\n', styles=['bold']))
181
+ result.write(pg.colored(self.output_repr(), color='blue'))
182
+
183
+ if self.metadata:
184
+ result.write('\n\n')
185
+ result.write(pg.colored('[METADATA]\n', styles=['bold']))
186
+ result.write(pg.colored(str(self.metadata), color='cyan'))
129
187
  return result.getvalue().strip()
130
188
 
189
+ @classmethod
190
+ @functools.cache
191
+ def _html_tree_view_config(cls) -> dict[str, Any]:
192
+
193
+ def render_value(view, *, value, **kwargs):
194
+ if isinstance(value, lf.Template):
195
+ # Make a shallow copy to make sure modalities are rooted by
196
+ # the input.
197
+ value = value.clone().render()
198
+ if value is None:
199
+ return None
200
+ return view.render(value, **kwargs)
201
+
202
+ return pg.views.HtmlTreeView.get_kwargs(
203
+ super()._html_tree_view_config(),
204
+ dict(
205
+ include_keys=['input', 'output', 'context', 'schema', 'metadata'],
206
+ extra_flags=dict(
207
+ render_value_fn=render_value,
208
+ ),
209
+ child_config=dict(
210
+ input=dict(
211
+ collapse_level=1,
212
+ ),
213
+ output=dict(
214
+ css_classes=['lf-example-output'],
215
+ collapse_level=1,
216
+ ),
217
+ schema=dict(
218
+ css_classes=['lf-example-schema'],
219
+ collapse_level=1,
220
+ ),
221
+ metadata=dict(
222
+ css_classes=['lf-example-metadata'],
223
+ collapse_level=1,
224
+ ),
225
+ ),
226
+ )
227
+ )
228
+
229
+ @classmethod
230
+ @functools.cache
231
+ def _html_tree_view_css_styles(cls) -> list[str]:
232
+ return super()._html_tree_view_css_styles() + [
233
+ """
234
+ .lf-example-output {
235
+ color: dodgerblue;
236
+ }
237
+ .lf-example-schema {
238
+ color: blue;
239
+ }
240
+ """
241
+ ]
242
+
131
243
 
132
244
  class Mapping(lf.LangFunc):
133
245
  """Base class for mapping.
@@ -206,13 +318,7 @@ class Mapping(lf.LangFunc):
206
318
  {{ input_title }}:
207
319
  {{ example.input_repr(protocol, compact=False) | indent(2, True) }}
208
320
 
209
- {% if has_modality_refs(example.input) -%}
210
- {{ modality_refs_title }}:
211
- {{ modality_refs_repr(example.input) | indent(2, True) }}
212
-
213
- {% endif -%}
214
-
215
- {%- if example.schema -%}
321
+ {% if example.schema -%}
216
322
  {{ schema_title }}:
217
323
  {{ example.schema_repr(protocol) | indent(2, True) }}
218
324
 
@@ -233,10 +339,6 @@ class Mapping(lf.LangFunc):
233
339
 
234
340
  schema_title: Annotated[str, 'The section title for schema.'] = 'SCHEMA'
235
341
 
236
- modality_refs_title: Annotated[
237
- str, 'The section title for modality refs.'
238
- ] = 'MODALITY_REFERENCES'
239
-
240
342
  protocol: Annotated[
241
343
  schema_lib.SchemaProtocol,
242
344
  'The protocol for representing the schema and value.',
@@ -278,6 +380,14 @@ class Mapping(lf.LangFunc):
278
380
  ),
279
381
  ] = lf.RAISE_IF_HAS_ERROR
280
382
 
383
+ response_postprocess: Annotated[
384
+ Callable[[str], str] | None,
385
+ (
386
+ 'A callable object that post process the raw LLM response before '
387
+ 'parsing it into the output Python object.'
388
+ )
389
+ ] = None
390
+
281
391
  #
282
392
  # Key methods for implementing specific mappings.
283
393
  #
@@ -296,10 +406,17 @@ class Mapping(lf.LangFunc):
296
406
  def transform_output(self, lm_output: lf.Message) -> lf.Message:
297
407
  """Transforms LM response into structure if schema is present."""
298
408
  try:
409
+ lm_output = self.postprocess_response(lm_output)
299
410
  lm_output.result = self.postprocess_result(self.parse_result(lm_output))
300
411
  except Exception as e: # pylint: disable=broad-exception-caught
412
+ if (self.lm.cache is not None
413
+ and lm_output.lm_input.cache_seed is not None):
414
+ success = self.lm.cache.delete(
415
+ self.lm, lm_output.lm_input, lm_output.lm_input.cache_seed
416
+ )
417
+ assert success
301
418
  if self.default == lf.RAISE_IF_HAS_ERROR:
302
- raise e
419
+ raise MappingError(lm_output, e) from e
303
420
  lm_output.result = self.default
304
421
  return lm_output
305
422
 
@@ -316,6 +433,14 @@ class Mapping(lf.LangFunc):
316
433
  autofix_lm=self.autofix_lm or self.lm,
317
434
  )
318
435
 
436
+ def postprocess_response(self, response: lf.Message) -> lf.Message:
437
+ """Post process LLM response."""
438
+ if self.response_postprocess is not None:
439
+ postprocessed_text = self.response_postprocess(response.text)
440
+ if postprocessed_text != response.text:
441
+ return lf.AIMessage(postprocessed_text, source=response)
442
+ return response
443
+
319
444
  def postprocess_result(self, result: Any) -> Any:
320
445
  """Post process structured output."""
321
446
  return result
@@ -324,24 +449,3 @@ class Mapping(lf.LangFunc):
324
449
  """Gets additional symbol definitions besides schema as globals."""
325
450
  return {'ModalityRef': lf.modality.ModalityRef}
326
451
 
327
- #
328
- # Helper methods for handling modalities.
329
- #
330
-
331
- def has_modality_refs(self, value: Any) -> bool:
332
- """Returns true if the value has modalities."""
333
- return not isinstance(value, lf.Modality) and pg.contains(
334
- value, type=lf.Modality
335
- )
336
-
337
- def modalities(self, value: Any) -> dict[str, lf.Modality]:
338
- return lf.Modality.from_value(value)
339
-
340
- def modality_refs_repr(self, value: Any) -> str:
341
- with lf.modality.format_modality_as_ref(True):
342
- return pg.format(
343
- self.modalities(value),
344
- compact=False,
345
- verbose=False,
346
- python_format=True,
347
- )
@@ -14,12 +14,30 @@
14
14
  """Tests for structured mapping example."""
15
15
 
16
16
  import inspect
17
+ from typing import Any
17
18
  import unittest
18
19
 
20
+ import langfun.core as lf
19
21
  from langfun.core.structured import mapping
20
22
  import pyglove as pg
21
23
 
22
24
 
25
+ class MappingErrorTest(unittest.TestCase):
26
+
27
+ def test_format(self):
28
+ error = mapping.MappingError(
29
+ lf.AIMessage('hi'), ValueError('Cannot parse message.')
30
+ )
31
+ self.assertEqual(
32
+ pg.decolor(str(error)),
33
+ 'ValueError: Cannot parse message.\n\n[LM Response]\nhi',
34
+ )
35
+ self.assertEqual(
36
+ pg.decolor(error.format(include_lm_response=False)),
37
+ 'ValueError: Cannot parse message.',
38
+ )
39
+
40
+
23
41
  class MappingExampleTest(unittest.TestCase):
24
42
 
25
43
  def test_basics(self):
@@ -112,6 +130,33 @@ class MappingExampleTest(unittest.TestCase):
112
130
  """),
113
131
  )
114
132
 
133
+ def test_str_with_metadata(self):
134
+ self.assertEqual(
135
+ str(
136
+ mapping.MappingExample(
137
+ '1 + 1 = 2',
138
+ schema=int,
139
+ context='Give the answer.',
140
+ metadata={'foo': 'bar'},
141
+ )
142
+ ),
143
+ inspect.cleandoc("""
144
+ \x1b[1m[CONTEXT]
145
+ \x1b[0m\x1b[35mGive the answer.\x1b[0m
146
+
147
+ \x1b[1m[INPUT]
148
+ \x1b[0m\x1b[32m1 + 1 = 2\x1b[0m
149
+
150
+ \x1b[1m[SCHEMA]
151
+ \x1b[0m\x1b[31mint\x1b[0m
152
+
153
+ \x1b[1m[METADATA]
154
+ \x1b[0m\x1b[36m{
155
+ foo = 'bar'
156
+ }\x1b[0m
157
+ """),
158
+ )
159
+
115
160
  def test_serialization(self):
116
161
  example = mapping.MappingExample(
117
162
  'the answer is 2', 2, int, context='compute 1 + 1'
@@ -120,6 +165,66 @@ class MappingExampleTest(unittest.TestCase):
120
165
  pg.eq(pg.from_json_str(example.to_json_str()), example)
121
166
  )
122
167
 
168
+ def assert_html_content(self, html, expected):
169
+ expected = inspect.cleandoc(expected).strip()
170
+ actual = html.content.strip()
171
+ if actual != expected:
172
+ print(actual)
173
+ self.assertEqual(actual, expected)
174
+
175
+ def test_html(self):
176
+
177
+ class Answer(pg.Object):
178
+ answer: int
179
+
180
+ class Addition(lf.Template):
181
+ """Template Addition.
182
+
183
+ {{x}} + {{y}} = ?
184
+ """
185
+ x: Any
186
+ y: Any
187
+
188
+ example = mapping.MappingExample(
189
+ input=Addition(x=1, y=2),
190
+ schema=Answer,
191
+ context='compute 1 + 1',
192
+ output=Answer(answer=3),
193
+ metadata={'foo': 'bar'},
194
+ )
195
+ self.assert_html_content(
196
+ example.to_html(
197
+ enable_summary_tooltip=False,
198
+ extra_flags=dict(
199
+ include_message_metadata=False
200
+ )
201
+ ),
202
+ """
203
+ <details open class="pyglove mapping-example"><summary><div class="summary-title">MappingExample(...)</div></summary><div class="complex-value mapping-example"><details open class="pyglove user-message lf-message"><summary><div class="summary-name lf-message">input<span class="tooltip lf-message">input</span></div><div class="summary-title lf-message">UserMessage(...)</div></summary><div class="complex_value"><div class="message-tags"><span>rendered</span></div><div class="message-text">1 + 2 = ?</div></div></details><details open class="pyglove answer lf-example-output"><summary><div class="summary-name lf-example-output">output<span class="tooltip lf-example-output">output</span></div><div class="summary-title lf-example-output">Answer(...)</div></summary><div class="complex-value answer"><details open class="pyglove int"><summary><div class="summary-name">answer<span class="tooltip">output.answer</span></div><div class="summary-title">int</div></summary><span class="simple-value int">3</span></details></div></details><details open class="pyglove str"><summary><div class="summary-name">context<span class="tooltip">context</span></div><div class="summary-title">str</div></summary><span class="simple-value str">&#x27;compute 1 + 1&#x27;</span></details><details open class="pyglove schema lf-example-schema"><summary><div class="summary-name lf-example-schema">schema<span class="tooltip lf-example-schema">schema</span></div><div class="summary-title lf-example-schema">Schema(...)</div></summary><div class="lf-schema-definition">Answer
204
+
205
+ ```python
206
+ class Answer:
207
+ answer: int
208
+ ```</div></details><details open class="pyglove dict lf-example-metadata"><summary><div class="summary-name lf-example-metadata">metadata<span class="tooltip lf-example-metadata">metadata</span></div><div class="summary-title lf-example-metadata">Dict(...)</div></summary><div class="complex-value dict"><details open class="pyglove str"><summary><div class="summary-name">foo<span class="tooltip">metadata.foo</span></div><div class="summary-title">str</div></summary><span class="simple-value str">&#x27;bar&#x27;</span></details></div></details></div></details>
209
+ """
210
+ )
211
+
212
+ example = mapping.MappingExample(
213
+ input=Addition(x=1, y=2),
214
+ output=Answer(answer=3),
215
+ )
216
+ self.assert_html_content(
217
+ example.to_html(
218
+ enable_summary_tooltip=False,
219
+ extra_flags=dict(
220
+ include_message_metadata=False
221
+ )
222
+ ),
223
+ """
224
+ <details open class="pyglove mapping-example"><summary><div class="summary-title">MappingExample(...)</div></summary><div class="complex-value mapping-example"><details open class="pyglove user-message lf-message"><summary><div class="summary-name lf-message">input<span class="tooltip lf-message">input</span></div><div class="summary-title lf-message">UserMessage(...)</div></summary><div class="complex_value"><div class="message-tags"><span>rendered</span></div><div class="message-text">1 + 2 = ?</div></div></details><details open class="pyglove answer lf-example-output"><summary><div class="summary-name lf-example-output">output<span class="tooltip lf-example-output">output</span></div><div class="summary-title lf-example-output">Answer(...)</div></summary><div class="complex-value answer"><details open class="pyglove int"><summary><div class="summary-name">answer<span class="tooltip">output.answer</span></div><div class="summary-title">int</div></summary><span class="simple-value int">3</span></details></div></details><details open class="pyglove contextual-attribute lf-example-schema"><summary><div class="summary-name lf-example-schema">schema<span class="tooltip lf-example-schema">schema</span></div><div class="summary-title lf-example-schema">ContextualAttribute(...)</div></summary><span class="simple-value none-type">None</span></details><details open class="pyglove dict lf-example-metadata"><summary><div class="summary-name lf-example-metadata">metadata<span class="tooltip lf-example-metadata">metadata</span></div><div class="summary-title lf-example-metadata">Dict(...)</div></summary><div class="complex-value dict"><span class="empty-container"></span></div></details></div></details>
225
+ """
226
+ )
227
+
123
228
 
124
229
  if __name__ == '__main__':
125
230
  unittest.main()
@@ -16,13 +16,13 @@ from typing import Any, Callable, Type, Union
16
16
 
17
17
  import langfun.core as lf
18
18
  from langfun.core.structured import mapping
19
- from langfun.core.structured import prompting
19
+ from langfun.core.structured import querying
20
20
  from langfun.core.structured import schema as schema_lib
21
21
  import pyglove as pg
22
22
 
23
23
 
24
24
  @lf.use_init_args(['schema', 'default', 'examples'])
25
- class ParseStructure(mapping.Mapping):
25
+ class _ParseStructure(mapping.Mapping):
26
26
  """Parse an object out from a natural language text."""
27
27
 
28
28
  context_title = 'USER_REQUEST'
@@ -37,7 +37,7 @@ class ParseStructure(mapping.Mapping):
37
37
  ]
38
38
 
39
39
 
40
- class ParseStructureJson(ParseStructure):
40
+ class _ParseStructureJson(_ParseStructure):
41
41
  """Parse an object out from a NL text using JSON as the protocol."""
42
42
 
43
43
  preamble = """
@@ -53,7 +53,7 @@ class ParseStructureJson(ParseStructure):
53
53
  output_title = 'JSON'
54
54
 
55
55
 
56
- class ParseStructurePython(ParseStructure):
56
+ class _ParseStructurePython(_ParseStructure):
57
57
  """Parse an object out from a NL text using Python as the protocol."""
58
58
 
59
59
  preamble = """
@@ -87,7 +87,7 @@ def parse(
87
87
  returns_message: bool = False,
88
88
  **kwargs,
89
89
  ) -> Any:
90
- """Parse a natural langugage message based on schema.
90
+ """Parse a natural language message based on schema.
91
91
 
92
92
  Examples:
93
93
 
@@ -270,29 +270,41 @@ def call(
270
270
  if schema in (str, None):
271
271
  return lm_output if returns_message else lm_output.text
272
272
 
273
+ def _chain_nl_output_message(parsing_message: lf.Message):
274
+ """Chain the source of the parsed output to the LM output."""
275
+ parsing_message.root.source = lm_output
276
+ parsing_message.tag('parsing-lm-output')
277
+ parsing_message.lm_input.tag('parsing-lm-input')
278
+
273
279
  # Call `parsing_lm` for structured parsing.
274
- return prompting.query(
275
- lm_output,
276
- schema,
277
- examples=parsing_examples,
278
- lm=parsing_lm or lm,
279
- include_context=parsing_include_context,
280
- cache_seed=cache_seed,
281
- autofix=autofix,
282
- autofix_lm=autofix_lm or lm,
283
- protocol=protocol,
284
- returns_message=returns_message,
285
- **kwargs,
286
- )
280
+ try:
281
+ parsing_message = querying.query(
282
+ lm_output.text,
283
+ schema,
284
+ examples=parsing_examples,
285
+ lm=parsing_lm or lm,
286
+ include_context=parsing_include_context,
287
+ cache_seed=cache_seed,
288
+ autofix=autofix,
289
+ autofix_lm=autofix_lm or lm,
290
+ protocol=protocol,
291
+ returns_message=True,
292
+ **kwargs,
293
+ )
294
+ _chain_nl_output_message(parsing_message)
295
+ except mapping.MappingError as e:
296
+ _chain_nl_output_message(e.lm_response)
297
+ raise e
298
+ return parsing_message if returns_message else parsing_message.result
287
299
 
288
300
 
289
301
  def _parse_structure_cls(
290
302
  protocol: schema_lib.SchemaProtocol,
291
- ) -> Type[ParseStructure]:
303
+ ) -> Type[_ParseStructure]:
292
304
  if protocol == 'json':
293
- return ParseStructureJson
305
+ return _ParseStructureJson
294
306
  elif protocol == 'python':
295
- return ParseStructurePython
307
+ return _ParseStructurePython
296
308
  else:
297
309
  raise ValueError(f'Unknown protocol: {protocol!r}.')
298
310