langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501150804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +17 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501150804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501150804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/top_level.txt +0 -0
@@ -63,6 +63,42 @@ class FunctionGenerationTest(unittest.TestCase):
63
63
 
64
64
  lm = fake.StaticSequence([unittest_lm_response, function_gen_lm_response])
65
65
 
66
+ @function_generation.function_gen(lm=lm, unittest='auto')
67
+ def linear_search(items, target): # pylint: disable=unused-argument
68
+ """Performs a linear search on a list to find a target value.
69
+
70
+ Args:
71
+ items (list): The list to search within.
72
+ target: The value to search for.
73
+
74
+ Returns:
75
+ int: The index of the target value if found, otherwise -1.
76
+ """
77
+
78
+ self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
79
+ self.assertEqual(linear_search.source(), function_gen_lm_response)
80
+
81
+ def test_generate_function_without_unittest(self):
82
+ function_gen_lm_response = inspect.cleandoc("""
83
+ def linear_search(items, target):
84
+ \"\"\"
85
+ Performs a linear search on a list to find a target value.
86
+
87
+ Args:
88
+ items (list): The list to search within.
89
+ target: The value to search for.
90
+
91
+ Returns:
92
+ int: The index of the target value if found, otherwise -1.
93
+ \"\"\"
94
+ for i, item in enumerate(items):
95
+ if item == target:
96
+ return i
97
+ return -1
98
+ """)
99
+
100
+ lm = fake.StaticSequence([function_gen_lm_response])
101
+
66
102
  @function_generation.function_gen(lm=lm)
67
103
  def linear_search(items, target): # pylint: disable=unused-argument
68
104
  """Performs a linear search on a list to find a target value.
@@ -258,7 +294,9 @@ class FunctionGenerationTest(unittest.TestCase):
258
294
  cache_file = os.path.join(cache_file_dir, 'cache_file.json')
259
295
 
260
296
  @function_generation.function_gen(
261
- lm=lm, unittest=_unittest_fn, cache_filename=cache_file
297
+ lm=lm,
298
+ unittest=_unittest_fn,
299
+ cache_filename=cache_file,
262
300
  )
263
301
  def linear_search(items, target): # pylint: disable=unused-argument
264
302
  """Performs a linear search on a list to find a target value.
@@ -273,6 +311,36 @@ class FunctionGenerationTest(unittest.TestCase):
273
311
 
274
312
  self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
275
313
 
314
+ def test_context_passthrough(self):
315
+
316
+ class Number(pg.Object):
317
+ value: int
318
+
319
+ function_gen_lm_response = inspect.cleandoc("""
320
+ ```python
321
+ def add(a: Number, b: Number) -> Number:
322
+ \"\"\"Adds two numbers together.\"\"\"
323
+ return Number(a.value + b.value)
324
+ ```
325
+ """)
326
+
327
+ lm = fake.StaticSequence(
328
+ [function_gen_lm_response]
329
+ )
330
+
331
+ def _unittest_fn(func):
332
+ assert func(Number(1), Number(2)) == Number(3)
333
+
334
+ custom_unittest = _unittest_fn
335
+
336
+ @function_generation.function_gen(
337
+ lm=lm, unittest=custom_unittest, num_retries=1
338
+ )
339
+ def add(a: Number, b: Number) -> Number: # pylint: disable=unused-argument
340
+ """Adds two numbers together."""
341
+
342
+ self.assertEqual(add(Number(2), Number(3)), Number(5))
343
+
276
344
  def test_siganture_check(self):
277
345
  incorrect_signature_lm_response = inspect.cleandoc("""
278
346
  ```python
@@ -310,7 +378,9 @@ class FunctionGenerationTest(unittest.TestCase):
310
378
 
311
379
  custom_unittest = _unittest_fn
312
380
 
313
- @function_generation.function_gen(lm=lm, unittest=custom_unittest)
381
+ @function_generation.function_gen(
382
+ lm=lm, unittest=custom_unittest, num_retries=2
383
+ )
314
384
  def linear_search(items, target): # pylint: disable=unused-argument
315
385
  """Performs a linear search on a list to find a target value.
316
386
 
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  """The base of symbolic mapping methods."""
15
15
 
16
+ import functools
16
17
  import io
17
18
  from typing import Annotated, Any, Callable
18
19
  import langfun.core as lf
@@ -45,20 +46,22 @@ class MappingError(Exception): # pylint: disable=g-bad-exception-name
45
46
  r = io.StringIO()
46
47
  error_message = str(self.cause).rstrip()
47
48
  r.write(
48
- lf.colored(
49
+ pg.colored(
49
50
  f'{self.cause.__class__.__name__}: {error_message}', 'magenta'
50
51
  )
51
52
  )
52
53
  if include_lm_response:
53
54
  r.write('\n\n')
54
- r.write(lf.colored('[LM Response]', 'blue', styles=['bold']))
55
+ r.write(pg.colored('[LM Response]', 'blue', styles=['bold']))
55
56
  r.write('\n')
56
- r.write(lf.colored(self.lm_response.text, 'blue'))
57
+ r.write(pg.colored(self.lm_response.text, 'blue'))
57
58
  return r.getvalue()
58
59
 
59
60
 
60
61
  @pg.use_init_args(['input', 'output', 'schema', 'context'])
61
- class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
62
+ class MappingExample(lf.NaturalLanguageFormattable,
63
+ lf.Component,
64
+ pg.views.HtmlTreeView.Extension):
62
65
  """Mapping example between text, schema and structured value."""
63
66
 
64
67
  input: pg.typing.Annotated[
@@ -92,6 +95,15 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
92
95
  'The natural language context for this mapping. ',
93
96
  ] = None
94
97
 
98
+ metadata: Annotated[
99
+ dict[str, Any],
100
+ (
101
+ 'The metadata associated with the mapping example, '
102
+ 'which chould carry structured data, such as tool function input. '
103
+ 'It is a `pg.Dict` object whose keys can be accessed by attributes.'
104
+ ),
105
+ ] = pg.Dict()
106
+
95
107
  def schema_repr(
96
108
  self, protocol: schema_lib.SchemaProtocol = 'python', **kwargs
97
109
  ) -> str:
@@ -107,7 +119,11 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
107
119
 
108
120
  @classmethod
109
121
  def value_repr(
110
- cls, value: Any, protocol: schema_lib.SchemaProtocol = 'python', **kwargs
122
+ cls,
123
+ value: Any,
124
+ protocol: schema_lib.SchemaProtocol = 'python',
125
+ use_modality_ref: bool = False,
126
+ **kwargs
111
127
  ) -> str:
112
128
  if isinstance(value, str):
113
129
  return value
@@ -116,7 +132,7 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
116
132
  return str(value)
117
133
 
118
134
  # Placehold modalities if they are present.
119
- if pg.contains(value, type=lf.Modality):
135
+ if use_modality_ref and pg.contains(value, type=lf.Modality):
120
136
  value = lf.ModalityRef.placehold(value)
121
137
  return schema_lib.value_repr(protocol).repr(value, **kwargs)
122
138
 
@@ -147,24 +163,83 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
147
163
  def natural_language_format(self) -> str:
148
164
  result = io.StringIO()
149
165
  if self.context:
150
- result.write(lf.colored('[CONTEXT]\n', styles=['bold']))
151
- result.write(lf.colored(self.context, color='magenta'))
166
+ result.write(pg.colored('[CONTEXT]\n', styles=['bold']))
167
+ result.write(pg.colored(self.context, color='magenta'))
152
168
  result.write('\n\n')
153
169
 
154
- result.write(lf.colored('[INPUT]\n', styles=['bold']))
155
- result.write(lf.colored(self.input_repr(), color='green'))
156
- result.write('\n\n')
170
+ result.write(pg.colored('[INPUT]\n', styles=['bold']))
171
+ result.write(pg.colored(self.input_repr(), color='green'))
157
172
 
158
173
  if self.schema is not None:
159
- result.write(lf.colored('[SCHEMA]\n', styles=['bold']))
160
- result.write(lf.colored(self.schema_repr(), color='red'))
161
174
  result.write('\n\n')
175
+ result.write(pg.colored('[SCHEMA]\n', styles=['bold']))
176
+ result.write(pg.colored(self.schema_repr(), color='red'))
162
177
 
163
178
  if schema_lib.MISSING != self.output:
164
- result.write(lf.colored('[OUTPUT]\n', styles=['bold']))
165
- result.write(lf.colored(self.output_repr(), color='blue'))
179
+ result.write('\n\n')
180
+ result.write(pg.colored('[OUTPUT]\n', styles=['bold']))
181
+ result.write(pg.colored(self.output_repr(), color='blue'))
182
+
183
+ if self.metadata:
184
+ result.write('\n\n')
185
+ result.write(pg.colored('[METADATA]\n', styles=['bold']))
186
+ result.write(pg.colored(str(self.metadata), color='cyan'))
166
187
  return result.getvalue().strip()
167
188
 
189
+ @classmethod
190
+ @functools.cache
191
+ def _html_tree_view_config(cls) -> dict[str, Any]:
192
+
193
+ def render_value(view, *, value, **kwargs):
194
+ if isinstance(value, lf.Template):
195
+ # Make a shallow copy to make sure modalities are rooted by
196
+ # the input.
197
+ value = value.clone().render()
198
+ if value is None:
199
+ return None
200
+ return view.render(value, **kwargs)
201
+
202
+ return pg.views.HtmlTreeView.get_kwargs(
203
+ super()._html_tree_view_config(),
204
+ dict(
205
+ include_keys=['input', 'output', 'context', 'schema', 'metadata'],
206
+ extra_flags=dict(
207
+ render_value_fn=render_value,
208
+ ),
209
+ child_config=dict(
210
+ input=dict(
211
+ collapse_level=1,
212
+ ),
213
+ output=dict(
214
+ css_classes=['lf-example-output'],
215
+ collapse_level=1,
216
+ ),
217
+ schema=dict(
218
+ css_classes=['lf-example-schema'],
219
+ collapse_level=1,
220
+ ),
221
+ metadata=dict(
222
+ css_classes=['lf-example-metadata'],
223
+ collapse_level=1,
224
+ ),
225
+ ),
226
+ )
227
+ )
228
+
229
+ @classmethod
230
+ @functools.cache
231
+ def _html_tree_view_css_styles(cls) -> list[str]:
232
+ return super()._html_tree_view_css_styles() + [
233
+ """
234
+ .lf-example-output {
235
+ color: dodgerblue;
236
+ }
237
+ .lf-example-schema {
238
+ color: blue;
239
+ }
240
+ """
241
+ ]
242
+
168
243
 
169
244
  class Mapping(lf.LangFunc):
170
245
  """Base class for mapping.
@@ -243,13 +318,7 @@ class Mapping(lf.LangFunc):
243
318
  {{ input_title }}:
244
319
  {{ example.input_repr(protocol, compact=False) | indent(2, True) }}
245
320
 
246
- {% if has_modality_refs(example.input) -%}
247
- {{ modality_refs_title }}:
248
- {{ modality_refs_repr(example.input) | indent(2, True) }}
249
-
250
- {% endif -%}
251
-
252
- {%- if example.schema -%}
321
+ {% if example.schema -%}
253
322
  {{ schema_title }}:
254
323
  {{ example.schema_repr(protocol) | indent(2, True) }}
255
324
 
@@ -270,10 +339,6 @@ class Mapping(lf.LangFunc):
270
339
 
271
340
  schema_title: Annotated[str, 'The section title for schema.'] = 'SCHEMA'
272
341
 
273
- modality_refs_title: Annotated[
274
- str, 'The section title for modality refs.'
275
- ] = 'MODALITY_REFERENCES'
276
-
277
342
  protocol: Annotated[
278
343
  schema_lib.SchemaProtocol,
279
344
  'The protocol for representing the schema and value.',
@@ -344,6 +409,12 @@ class Mapping(lf.LangFunc):
344
409
  lm_output = self.postprocess_response(lm_output)
345
410
  lm_output.result = self.postprocess_result(self.parse_result(lm_output))
346
411
  except Exception as e: # pylint: disable=broad-exception-caught
412
+ if (self.lm.cache is not None
413
+ and lm_output.lm_input.cache_seed is not None):
414
+ success = self.lm.cache.delete(
415
+ self.lm, lm_output.lm_input, lm_output.lm_input.cache_seed
416
+ )
417
+ assert success
347
418
  if self.default == lf.RAISE_IF_HAS_ERROR:
348
419
  raise MappingError(lm_output, e) from e
349
420
  lm_output.result = self.default
@@ -378,24 +449,3 @@ class Mapping(lf.LangFunc):
378
449
  """Gets additional symbol definitions besides schema as globals."""
379
450
  return {'ModalityRef': lf.modality.ModalityRef}
380
451
 
381
- #
382
- # Helper methods for handling modalities.
383
- #
384
-
385
- def has_modality_refs(self, value: Any) -> bool:
386
- """Returns true if the value has modalities."""
387
- return not isinstance(value, lf.Modality) and pg.contains(
388
- value, type=lf.Modality
389
- )
390
-
391
- def modalities(self, value: Any) -> dict[str, lf.Modality]:
392
- return lf.Modality.from_value(value)
393
-
394
- def modality_refs_repr(self, value: Any) -> str:
395
- with lf.modality.format_modality_as_ref(True):
396
- return pg.format(
397
- self.modalities(value),
398
- compact=False,
399
- verbose=False,
400
- python_format=True,
401
- )
@@ -14,6 +14,7 @@
14
14
  """Tests for structured mapping example."""
15
15
 
16
16
  import inspect
17
+ from typing import Any
17
18
  import unittest
18
19
 
19
20
  import langfun.core as lf
@@ -28,11 +29,11 @@ class MappingErrorTest(unittest.TestCase):
28
29
  lf.AIMessage('hi'), ValueError('Cannot parse message.')
29
30
  )
30
31
  self.assertEqual(
31
- lf.text_formatting.decolored(str(error)),
32
+ pg.decolor(str(error)),
32
33
  'ValueError: Cannot parse message.\n\n[LM Response]\nhi',
33
34
  )
34
35
  self.assertEqual(
35
- lf.text_formatting.decolored(error.format(include_lm_response=False)),
36
+ pg.decolor(error.format(include_lm_response=False)),
36
37
  'ValueError: Cannot parse message.',
37
38
  )
38
39
 
@@ -129,6 +130,33 @@ class MappingExampleTest(unittest.TestCase):
129
130
  """),
130
131
  )
131
132
 
133
+ def test_str_with_metadata(self):
134
+ self.assertEqual(
135
+ str(
136
+ mapping.MappingExample(
137
+ '1 + 1 = 2',
138
+ schema=int,
139
+ context='Give the answer.',
140
+ metadata={'foo': 'bar'},
141
+ )
142
+ ),
143
+ inspect.cleandoc("""
144
+ \x1b[1m[CONTEXT]
145
+ \x1b[0m\x1b[35mGive the answer.\x1b[0m
146
+
147
+ \x1b[1m[INPUT]
148
+ \x1b[0m\x1b[32m1 + 1 = 2\x1b[0m
149
+
150
+ \x1b[1m[SCHEMA]
151
+ \x1b[0m\x1b[31mint\x1b[0m
152
+
153
+ \x1b[1m[METADATA]
154
+ \x1b[0m\x1b[36m{
155
+ foo = 'bar'
156
+ }\x1b[0m
157
+ """),
158
+ )
159
+
132
160
  def test_serialization(self):
133
161
  example = mapping.MappingExample(
134
162
  'the answer is 2', 2, int, context='compute 1 + 1'
@@ -137,6 +165,66 @@ class MappingExampleTest(unittest.TestCase):
137
165
  pg.eq(pg.from_json_str(example.to_json_str()), example)
138
166
  )
139
167
 
168
+ def assert_html_content(self, html, expected):
169
+ expected = inspect.cleandoc(expected).strip()
170
+ actual = html.content.strip()
171
+ if actual != expected:
172
+ print(actual)
173
+ self.assertEqual(actual, expected)
174
+
175
+ def test_html(self):
176
+
177
+ class Answer(pg.Object):
178
+ answer: int
179
+
180
+ class Addition(lf.Template):
181
+ """Template Addition.
182
+
183
+ {{x}} + {{y}} = ?
184
+ """
185
+ x: Any
186
+ y: Any
187
+
188
+ example = mapping.MappingExample(
189
+ input=Addition(x=1, y=2),
190
+ schema=Answer,
191
+ context='compute 1 + 1',
192
+ output=Answer(answer=3),
193
+ metadata={'foo': 'bar'},
194
+ )
195
+ self.assert_html_content(
196
+ example.to_html(
197
+ enable_summary_tooltip=False,
198
+ extra_flags=dict(
199
+ include_message_metadata=False
200
+ )
201
+ ),
202
+ """
203
+ <details open class="pyglove mapping-example"><summary><div class="summary-title">MappingExample(...)</div></summary><div class="complex-value mapping-example"><details open class="pyglove user-message lf-message"><summary><div class="summary-name lf-message">input<span class="tooltip lf-message">input</span></div><div class="summary-title lf-message">UserMessage(...)</div></summary><div class="complex_value"><div class="message-tags"><span>rendered</span></div><div class="message-text">1 + 2 = ?</div></div></details><details open class="pyglove answer lf-example-output"><summary><div class="summary-name lf-example-output">output<span class="tooltip lf-example-output">output</span></div><div class="summary-title lf-example-output">Answer(...)</div></summary><div class="complex-value answer"><details open class="pyglove int"><summary><div class="summary-name">answer<span class="tooltip">output.answer</span></div><div class="summary-title">int</div></summary><span class="simple-value int">3</span></details></div></details><details open class="pyglove str"><summary><div class="summary-name">context<span class="tooltip">context</span></div><div class="summary-title">str</div></summary><span class="simple-value str">&#x27;compute 1 + 1&#x27;</span></details><details open class="pyglove schema lf-example-schema"><summary><div class="summary-name lf-example-schema">schema<span class="tooltip lf-example-schema">schema</span></div><div class="summary-title lf-example-schema">Schema(...)</div></summary><div class="lf-schema-definition">Answer
204
+
205
+ ```python
206
+ class Answer:
207
+ answer: int
208
+ ```</div></details><details open class="pyglove dict lf-example-metadata"><summary><div class="summary-name lf-example-metadata">metadata<span class="tooltip lf-example-metadata">metadata</span></div><div class="summary-title lf-example-metadata">Dict(...)</div></summary><div class="complex-value dict"><details open class="pyglove str"><summary><div class="summary-name">foo<span class="tooltip">metadata.foo</span></div><div class="summary-title">str</div></summary><span class="simple-value str">&#x27;bar&#x27;</span></details></div></details></div></details>
209
+ """
210
+ )
211
+
212
+ example = mapping.MappingExample(
213
+ input=Addition(x=1, y=2),
214
+ output=Answer(answer=3),
215
+ )
216
+ self.assert_html_content(
217
+ example.to_html(
218
+ enable_summary_tooltip=False,
219
+ extra_flags=dict(
220
+ include_message_metadata=False
221
+ )
222
+ ),
223
+ """
224
+ <details open class="pyglove mapping-example"><summary><div class="summary-title">MappingExample(...)</div></summary><div class="complex-value mapping-example"><details open class="pyglove user-message lf-message"><summary><div class="summary-name lf-message">input<span class="tooltip lf-message">input</span></div><div class="summary-title lf-message">UserMessage(...)</div></summary><div class="complex_value"><div class="message-tags"><span>rendered</span></div><div class="message-text">1 + 2 = ?</div></div></details><details open class="pyglove answer lf-example-output"><summary><div class="summary-name lf-example-output">output<span class="tooltip lf-example-output">output</span></div><div class="summary-title lf-example-output">Answer(...)</div></summary><div class="complex-value answer"><details open class="pyglove int"><summary><div class="summary-name">answer<span class="tooltip">output.answer</span></div><div class="summary-title">int</div></summary><span class="simple-value int">3</span></details></div></details><details open class="pyglove contextual-attribute lf-example-schema"><summary><div class="summary-name lf-example-schema">schema<span class="tooltip lf-example-schema">schema</span></div><div class="summary-title lf-example-schema">ContextualAttribute(...)</div></summary><span class="simple-value none-type">None</span></details><details open class="pyglove dict lf-example-metadata"><summary><div class="summary-name lf-example-metadata">metadata<span class="tooltip lf-example-metadata">metadata</span></div><div class="summary-title lf-example-metadata">Dict(...)</div></summary><div class="complex-value dict"><span class="empty-container"></span></div></details></div></details>
225
+ """
226
+ )
227
+
140
228
 
141
229
  if __name__ == '__main__':
142
230
  unittest.main()
@@ -16,13 +16,13 @@ from typing import Any, Callable, Type, Union
16
16
 
17
17
  import langfun.core as lf
18
18
  from langfun.core.structured import mapping
19
- from langfun.core.structured import prompting
19
+ from langfun.core.structured import querying
20
20
  from langfun.core.structured import schema as schema_lib
21
21
  import pyglove as pg
22
22
 
23
23
 
24
24
  @lf.use_init_args(['schema', 'default', 'examples'])
25
- class ParseStructure(mapping.Mapping):
25
+ class _ParseStructure(mapping.Mapping):
26
26
  """Parse an object out from a natural language text."""
27
27
 
28
28
  context_title = 'USER_REQUEST'
@@ -37,7 +37,7 @@ class ParseStructure(mapping.Mapping):
37
37
  ]
38
38
 
39
39
 
40
- class ParseStructureJson(ParseStructure):
40
+ class _ParseStructureJson(_ParseStructure):
41
41
  """Parse an object out from a NL text using JSON as the protocol."""
42
42
 
43
43
  preamble = """
@@ -53,7 +53,7 @@ class ParseStructureJson(ParseStructure):
53
53
  output_title = 'JSON'
54
54
 
55
55
 
56
- class ParseStructurePython(ParseStructure):
56
+ class _ParseStructurePython(_ParseStructure):
57
57
  """Parse an object out from a NL text using Python as the protocol."""
58
58
 
59
59
  preamble = """
@@ -87,7 +87,7 @@ def parse(
87
87
  returns_message: bool = False,
88
88
  **kwargs,
89
89
  ) -> Any:
90
- """Parse a natural langugage message based on schema.
90
+ """Parse a natural language message based on schema.
91
91
 
92
92
  Examples:
93
93
 
@@ -270,29 +270,41 @@ def call(
270
270
  if schema in (str, None):
271
271
  return lm_output if returns_message else lm_output.text
272
272
 
273
+ def _chain_nl_output_message(parsing_message: lf.Message):
274
+ """Chain the source of the parsed output to the LM output."""
275
+ parsing_message.root.source = lm_output
276
+ parsing_message.tag('parsing-lm-output')
277
+ parsing_message.lm_input.tag('parsing-lm-input')
278
+
273
279
  # Call `parsing_lm` for structured parsing.
274
- return prompting.query(
275
- lm_output,
276
- schema,
277
- examples=parsing_examples,
278
- lm=parsing_lm or lm,
279
- include_context=parsing_include_context,
280
- cache_seed=cache_seed,
281
- autofix=autofix,
282
- autofix_lm=autofix_lm or lm,
283
- protocol=protocol,
284
- returns_message=returns_message,
285
- **kwargs,
286
- )
280
+ try:
281
+ parsing_message = querying.query(
282
+ lm_output.text,
283
+ schema,
284
+ examples=parsing_examples,
285
+ lm=parsing_lm or lm,
286
+ include_context=parsing_include_context,
287
+ cache_seed=cache_seed,
288
+ autofix=autofix,
289
+ autofix_lm=autofix_lm or lm,
290
+ protocol=protocol,
291
+ returns_message=True,
292
+ **kwargs,
293
+ )
294
+ _chain_nl_output_message(parsing_message)
295
+ except mapping.MappingError as e:
296
+ _chain_nl_output_message(e.lm_response)
297
+ raise e
298
+ return parsing_message if returns_message else parsing_message.result
287
299
 
288
300
 
289
301
  def _parse_structure_cls(
290
302
  protocol: schema_lib.SchemaProtocol,
291
- ) -> Type[ParseStructure]:
303
+ ) -> Type[_ParseStructure]:
292
304
  if protocol == 'json':
293
- return ParseStructureJson
305
+ return _ParseStructureJson
294
306
  elif protocol == 'python':
295
- return ParseStructurePython
307
+ return _ParseStructurePython
296
308
  else:
297
309
  raise ValueError(f'Unknown protocol: {protocol!r}.')
298
310
 
@@ -37,7 +37,7 @@ class Itinerary(pg.Object):
37
37
  class ParseStructurePythonTest(unittest.TestCase):
38
38
 
39
39
  def test_render_no_examples(self):
40
- l = parsing.ParseStructurePython(int)
40
+ l = parsing._ParseStructurePython(int)
41
41
  m = lf.AIMessage('12 / 6 + 2 = 4')
42
42
  self.assertEqual(
43
43
  l.render(input=m, context='Compute 12 / 6 + 2.').text,
@@ -62,7 +62,7 @@ class ParseStructurePythonTest(unittest.TestCase):
62
62
  )
63
63
 
64
64
  def test_render_no_context(self):
65
- l = parsing.ParseStructurePython(int)
65
+ l = parsing._ParseStructurePython(int)
66
66
  m = lf.AIMessage('12 / 6 + 2 = 4')
67
67
 
68
68
  self.assertEqual(
@@ -85,7 +85,7 @@ class ParseStructurePythonTest(unittest.TestCase):
85
85
  )
86
86
 
87
87
  def test_render(self):
88
- l = parsing.ParseStructurePython(
88
+ l = parsing._ParseStructurePython(
89
89
  int,
90
90
  examples=[
91
91
  mapping.MappingExample(
@@ -212,7 +212,7 @@ class ParseStructurePythonTest(unittest.TestCase):
212
212
  ),
213
213
  override_attrs=True,
214
214
  ):
215
- l = parsing.ParseStructurePython(
215
+ l = parsing._ParseStructurePython(
216
216
  [Itinerary],
217
217
  examples=[
218
218
  mapping.MappingExample(
@@ -285,7 +285,7 @@ class ParseStructurePythonTest(unittest.TestCase):
285
285
  self.assertEqual(
286
286
  r,
287
287
  lf.AIMessage(
288
- '1', score=1.0, result=1, logprobs=None,
288
+ '1', score=1.0, result=1, logprobs=None, is_cached=False,
289
289
  usage=lf.LMSamplingUsage(652, 1, 653),
290
290
  tags=['lm-response', 'lm-output', 'transformed']
291
291
  ),
@@ -295,7 +295,7 @@ class ParseStructurePythonTest(unittest.TestCase):
295
295
  class ParseStructureJsonTest(unittest.TestCase):
296
296
 
297
297
  def test_render_no_examples(self):
298
- l = parsing.ParseStructureJson(int)
298
+ l = parsing._ParseStructureJson(int)
299
299
  m = lf.AIMessage('12 / 6 + 2 = 4')
300
300
  self.assertEqual(
301
301
  l.render(input=m, context='Compute 12 / 6 + 2.').text,
@@ -320,7 +320,7 @@ class ParseStructureJsonTest(unittest.TestCase):
320
320
  )
321
321
 
322
322
  def test_render_no_context(self):
323
- l = parsing.ParseStructureJson(int)
323
+ l = parsing._ParseStructureJson(int)
324
324
  m = lf.AIMessage('12 / 6 + 2 = 4')
325
325
 
326
326
  self.assertEqual(
@@ -343,7 +343,7 @@ class ParseStructureJsonTest(unittest.TestCase):
343
343
  )
344
344
 
345
345
  def test_render(self):
346
- l = parsing.ParseStructureJson(
346
+ l = parsing._ParseStructureJson(
347
347
  int,
348
348
  examples=[
349
349
  mapping.MappingExample(
@@ -504,7 +504,7 @@ class ParseStructureJsonTest(unittest.TestCase):
504
504
  override_attrs=True,
505
505
  ):
506
506
  message = lf.LangFunc(lm_input)()
507
- l = parsing.ParseStructureJson(
507
+ l = parsing._ParseStructureJson(
508
508
  [Itinerary],
509
509
  examples=[
510
510
  mapping.MappingExample(
@@ -645,6 +645,7 @@ class CallTest(unittest.TestCase):
645
645
  result=3,
646
646
  score=1.0,
647
647
  logprobs=None,
648
+ is_cached=False,
648
649
  usage=lf.LMSamplingUsage(315, 1, 316),
649
650
  tags=['lm-response', 'lm-output', 'transformed']
650
651
  ),
@@ -669,6 +670,49 @@ class CallTest(unittest.TestCase):
669
670
  3,
670
671
  )
671
672
 
673
+ def test_call_with_parsing_message_chaining(self):
674
+ output = parsing.call(
675
+ 'Compute 1 + 2',
676
+ int,
677
+ lm=fake.StaticSequence(['three']),
678
+ parsing_lm=fake.StaticSequence(['3']),
679
+ parsing_examples=[
680
+ mapping.MappingExample(
681
+ context='Multiple four and five',
682
+ input='twenty',
683
+ schema=int,
684
+ output=20,
685
+ )
686
+ ],
687
+ returns_message=True,
688
+ )
689
+ self.assertIn('parsing-lm-output', output.tags)
690
+ self.assertIn('parsing-lm-input', output.source.tags)
691
+ self.assertEqual(output.root.text, 'Compute 1 + 2')
692
+
693
+ def test_call_with_parsing_message_chaining_on_parsing_error(self):
694
+ try:
695
+ output = parsing.call(
696
+ 'Compute 1 + 2',
697
+ int,
698
+ lm=fake.StaticSequence(['three']),
699
+ parsing_lm=fake.StaticSequence(['abc']),
700
+ parsing_examples=[
701
+ mapping.MappingExample(
702
+ context='Multiple four and five',
703
+ input='twenty',
704
+ schema=int,
705
+ output=20,
706
+ )
707
+ ],
708
+ returns_message=True,
709
+ )
710
+ except mapping.MappingError as e:
711
+ output = e.lm_response
712
+ self.assertIn('parsing-lm-output', output.tags)
713
+ self.assertIn('parsing-lm-input', output.source.tags)
714
+ self.assertEqual(output.root.text, 'Compute 1 + 2')
715
+
672
716
  def test_call_with_autofix(self):
673
717
  lm = fake.StaticSequence(
674
718
  [