langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. langfun/__init__.py +2 -0
  2. langfun/core/__init__.py +1 -0
  3. langfun/core/coding/python/correction.py +0 -7
  4. langfun/core/component.py +6 -0
  5. langfun/core/component_test.py +1 -0
  6. langfun/core/eval/__init__.py +2 -0
  7. langfun/core/eval/base.py +202 -23
  8. langfun/core/eval/base_test.py +49 -10
  9. langfun/core/eval/matching.py +26 -9
  10. langfun/core/eval/matching_test.py +2 -1
  11. langfun/core/eval/scoring.py +15 -6
  12. langfun/core/eval/scoring_test.py +2 -1
  13. langfun/core/langfunc.py +0 -5
  14. langfun/core/langfunc_test.py +6 -4
  15. langfun/core/language_model.py +124 -24
  16. langfun/core/language_model_test.py +249 -26
  17. langfun/core/llms/__init__.py +19 -2
  18. langfun/core/llms/anthropic.py +263 -0
  19. langfun/core/llms/anthropic_test.py +167 -0
  20. langfun/core/llms/cache/in_memory_test.py +37 -28
  21. langfun/core/llms/fake.py +31 -22
  22. langfun/core/llms/fake_test.py +122 -11
  23. langfun/core/llms/google_genai_test.py +8 -3
  24. langfun/core/llms/groq.py +260 -0
  25. langfun/core/llms/groq_test.py +170 -0
  26. langfun/core/llms/llama_cpp.py +3 -1
  27. langfun/core/llms/openai.py +97 -79
  28. langfun/core/llms/openai_test.py +285 -59
  29. langfun/core/modalities/video.py +5 -2
  30. langfun/core/structured/__init__.py +3 -0
  31. langfun/core/structured/completion_test.py +2 -2
  32. langfun/core/structured/function_generation.py +245 -0
  33. langfun/core/structured/function_generation_test.py +329 -0
  34. langfun/core/structured/mapping.py +56 -2
  35. langfun/core/structured/mapping_test.py +17 -0
  36. langfun/core/structured/parsing_test.py +18 -13
  37. langfun/core/structured/prompting.py +27 -6
  38. langfun/core/structured/prompting_test.py +79 -12
  39. langfun/core/structured/schema.py +4 -2
  40. langfun/core/structured/schema_generation_test.py +2 -2
  41. langfun/core/structured/schema_test.py +4 -6
  42. langfun/core/template.py +125 -10
  43. langfun/core/template_test.py +75 -0
  44. langfun/core/templates/selfplay_test.py +6 -2
  45. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
  46. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +49 -43
  47. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
  48. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
  49. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,245 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """LLM-based function generation."""
15
+
16
+ import functools
17
+ import inspect
18
+ import re
19
+ from typing import Any, Callable, Optional, Tuple
20
+
21
+ from langfun.core import language_model
22
+ from langfun.core import template
23
+ from langfun.core.coding import python
24
+ from langfun.core.structured import prompting
25
+ import pyglove as pg
26
+
27
+
28
def unittest_gen(signature, lm, num_retries=10):
  """Asks the LM to propose unit tests for a python function signature.

  Args:
    signature: Source text of the function signature (and its docstring).
    lm: The language model to query.
    num_retries: Maximum number of queries. Querying stops at the first
      response that parses into a non-empty list of test cases.

  Returns:
    A list of `(input, expected_output)` tuples, where `input` is a dict of
    argument names to values, or None if no attempt produced test cases.
  """

  class UnitTest(pg.Object):
    """A valid unit test for a python function."""

    input: dict[str, Any]
    expected_output: Any

  class PythonFunctionSignature(pg.Object):
    signature: str

  for _ in range(num_retries):
    # `default=None` makes a failed parse yield None instead of raising.
    proposed_tests = prompting.query(
        PythonFunctionSignature(signature=signature),
        list[UnitTest],
        lm=lm,
        default=None,
    )
    if isinstance(proposed_tests, list) and proposed_tests:
      return [(t.input, t.expected_output) for t in proposed_tests]
  return None
55
+
56
+
57
def unittest_with_test_cases(f, unittests):
  """Runs a python function against a list of `(inputs, expected)` cases.

  How `inputs` is passed to `f` depends on its type: a dict is passed as
  keyword arguments, a tuple as positional arguments, and anything else as a
  single positional argument.

  Args:
    f: The function under test.
    unittests: A non-empty sequence of `(inputs, expected_output)` pairs.

  Raises:
    ValueError: If `unittests` is empty or None.
    AssertionError: On the first case whose actual output does not equal the
      expected output.
  """
  if not unittests:
    raise ValueError(f"No unit tests provided: {unittests}")

  for unit_test in unittests:
    inputs = unit_test[0]
    if isinstance(inputs, dict):
      actual = f(**inputs)
    elif isinstance(inputs, tuple):
      actual = f(*inputs)
    else:
      actual = f(inputs)

    expected = unit_test[1]
    # Raise explicitly instead of using `assert`: assert statements are
    # stripped under `python -O`, which would silently accept any generated
    # implementation. `not (a == b)` mirrors the original assert's semantics.
    if not (actual == expected):  # pylint: disable=unneeded-not
      raise AssertionError(
          f"Test FAILED: Inputs: {inputs}, Expected: {expected}, "
          f"Actual: {actual}"
      )
75
+
76
+
77
def _function_gen(
    func: Callable[..., Any],
    signature: str,
    lm: language_model.LanguageModel,
    num_retries: int = 10,
    unittest: Optional[
        Callable[[Callable[..., Any]], None] | list[Tuple[Any, Any]]
    ] = None,
):
  """Generates a python function with LLM and verifies it with unit testing.

  Args:
    func: The function stub being implemented. A candidate is rejected unless
      its `inspect.signature` equals this function's.
    signature: Source text of the stub, used in the generation prompt.
    lm: The language model used for generation, and for generating unit test
      cases when `unittest` is None.
    num_retries: Maximum number of generation attempts.
    unittest: A callable that raises if a candidate is wrong, a list of
      `(inputs, expected_output)` test cases, or None to let `lm` generate the
      test cases.

  Returns:
    A `(function, source_code)` tuple for the first candidate that passes the
    signature check and the unit tests, or `(None, None)` if none does.
  """

  # The docstring below is the prompt template. It is deliberately NOT a raw
  # string: the `\"` escapes must resolve to plain `"""` so the few-shot example
  # shows valid Python. With an `r` prefix the backslashes would appear
  # literally in the prompt.
  class PythonFunctionPrompt(template.Template):
    """A template for a python function generation.

    Please reply to the last PYTHON_FUNCTION_SIGNATURE with a self-sufficient,
    error-free, and efficiently coded PYTHON_FUNCTION, crafted to the standards
    of a world-class programmer.

    PYTHON_FUNCTION_SIGNATURE:
    ```python
    def calculate_area_circle(radius: float) -> float:
      \"\"\"Calculates the area of a circle given its radius.

      Args:
        radius: The radius of the circle.

      Returns:
        The area of the circle.
      \"\"\"
    ```

    PYTHON_FUNCTION:
    ```python
    def calculate_area_circle(radius: float) -> float:
      \"\"\"Calculates the area of a circle given its radius.

      Args:
        radius: The radius of the circle.

      Returns:
        The area of the circle.
      \"\"\"
      import math

      area = math.pi * radius**2
      return area
    ```

    PYTHON_FUNCTION_SIGNATURE:
    ```python
    {{signature}}
    ```

    PYTHON_FUNCTION:
    """

  if callable(unittest):
    unittest_examples = None
  elif unittest is None:
    unittest_examples = unittest_gen(signature, lm=lm)
  else:
    unittest_examples = unittest

  if not callable(unittest) and not unittest_examples:
    # Without test cases every candidate would be rejected by
    # `unittest_with_test_cases`. Stop here instead of spending LM calls on
    # candidates that cannot be validated.
    pg.logging.warning(
        "No unit test cases available; skipping function generation for "
        "signature:\n%s",
        signature,
    )
    return None, None

  expected_signature = inspect.signature(func)
  for attempt in range(num_retries):
    try:
      source_code = prompting.query(
          PythonFunctionPrompt(signature=signature), lm=lm
      )
      # NOTE(review): this executes LM-generated code in-process.
      candidate = python.evaluate(source_code)

      # Check whether the signatures are the same.
      candidate_signature = inspect.signature(candidate)
      if candidate_signature != expected_signature:
        pg.logging.warning(
            "Generated function signature %s does not match %s; retrying.",
            candidate_signature,
            expected_signature,
        )
        continue

      if callable(unittest):
        unittest(candidate)
      else:
        unittest_with_test_cases(candidate, unittest_examples)

      return candidate, source_code
    except Exception as e:  # pylint: disable=broad-exception-caught
      # Any failure (query, evaluation or unit tests) triggers another
      # attempt. It is logged rather than silently discarded.
      pg.logging.warning(
          "Function generation attempt %d/%d failed: %s",
          attempt + 1,
          num_retries,
          e,
      )

  return None, None
160
+
161
+
162
+ def _process_signature(signature):
163
+ # Remove the decorator.
164
+ pattern = r"^\@.*function_gen.*$"
165
+ signature = re.sub(pattern, "", signature, flags=re.MULTILINE)
166
+ # Remove the possible 'pass' in an empty function.
167
+ pattern = r"^\s*pass\s*$"
168
+ signature = re.sub(pattern, "", signature, flags=re.MULTILINE)
169
+ return signature.strip()
170
+
171
+
172
def function_gen(
    lm: language_model.LanguageModel,
    cache_filename: str | None = None,
    num_retries: int = 10,
    unittest: Optional[
        Callable[[Callable[..., Any]], None] | list[Tuple[Any, Any]]
    ] = None,
):
  """A decorator for automating function generation using a language model.

  This decorator should be applied to functions that are not yet implemented.
  It facilitates the implementation via the specified LLM, ensuring quality
  through unit tests.

  Args:
    lm (lf.LanguageModel): The language model used for generating function
      implementations.
    cache_filename (str | None): Optional. The path of the file where
      generated function implementations are loaded from or saved to.
    num_retries (int): Maximum number of attempts the language model should
      make to generate a suitable function implementation.
    unittest: This optional parameter enables the definition of custom unit
      tests. You can either provide a list of test cases as tuples of inputs
      and outputs, or a function that throws an error if a test fails. If left
      as None (the default setting), the LLM will automatically create the
      unit test cases.

  Returns:
    A decorator. Applied to a function stub, it returns a wrapper. On its first
    call, the wrapper loads an implementation from `cache_filename` if it has
    one for this stub, otherwise it generates one with `lm` (raising
    ValueError if generation fails). All calls are then forwarded to that
    implementation. The wrapper also exposes `source()`, which returns the
    implementation's source code (None until one has been obtained).
  """

  def _decorate(func):
    # The implementation and its source are memoized on the stub itself, so
    # a successful load or generation happens only once.
    func.__function__ = None
    func.__source_code__ = None

    @functools.wraps(func)
    def lm_generated_func(*args, **kwargs):
      if func.__function__ is not None:
        return func.__function__(*args, **kwargs)

      signature = _process_signature(inspect.getsource(func))
      cache = pg.Dict()
      if cache_filename is not None:
        try:
          cache = pg.load(cache_filename)
        except FileNotFoundError:
          pg.logging.warning(
              "Creating a new cache as cache file '%s' does not exist.",
              cache_filename,
          )

      if signature in cache:
        cached_source = cache[signature]
        # NOTE(review): executes cached code; only use trusted cache files.
        # Evaluate before recording, so a failed evaluation leaves no stale
        # source behind.
        func.__function__ = python.evaluate(cached_source)
        func.__source_code__ = cached_source
        return func.__function__(*args, **kwargs)

      func.__function__, func.__source_code__ = _function_gen(
          func, signature, lm, num_retries=num_retries, unittest=unittest
      )
      if func.__function__ is None:
        raise ValueError(f"Function generation failed. Signature:\n{signature}")

      if cache_filename is not None:
        cache[signature] = func.__source_code__
        cache.save(cache_filename)
      return func.__function__(*args, **kwargs)

    # `functools.wraps` already copies __name__, __qualname__, __module__ and
    # __doc__ from `func`, so only the extra `source` accessor is added here.
    lm_generated_func.source = lambda: func.__source_code__
    return lm_generated_func

  return _decorate
@@ -0,0 +1,329 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import inspect
15
+ import os
16
+ import tempfile
17
+ import unittest
18
+ from langfun.core.llms import fake
19
+ from langfun.core.structured import function_generation
20
+ import pyglove as pg
21
+
22
+
23
class FunctionGenerationTest(unittest.TestCase):
  """Tests for `function_generation.function_gen`."""

  # Implementation of `linear_search` returned by the fake LM: the bare
  # source, and the same source wrapped in a markdown code fence.
  _LINEAR_SEARCH_SOURCE = inspect.cleandoc("""
      def linear_search(items, target):
        \"\"\"
        Performs a linear search on a list to find a target value.

        Args:
          items (list): The list to search within.
          target: The value to search for.

        Returns:
          int: The index of the target value if found, otherwise -1.
        \"\"\"
        for i, item in enumerate(items):
          if item == target:
            return i
        return -1
      """)
  _LINEAR_SEARCH_RESPONSE = f'```python\n{_LINEAR_SEARCH_SOURCE}\n```'

  def setUp(self):
    super().setUp()
    # Each test gets its own cache file, so cached implementations never leak
    # between tests or between runs (a fixed path in the shared temp dir did).
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    self._cache_file = os.path.join(tmp_dir.name, 'cache_file.json')

  @staticmethod
  def _check_linear_search(func):
    """Custom unit test applied to generated `linear_search` candidates."""
    assert func([1, 2, 3, 4, 5], 3) == 2
    assert func([1, 2, 3, 4, 5], 6) == -1

  def test_generate_function(self):
    unittest_lm_response = inspect.cleandoc("""
        ```python
        [
          UnitTest(
            input={
              'items': [1, 2, 3, 4, 5],
              'target': 3
            },
            expected_output=2
          ),
          UnitTest(
            input={
              'items': [1, 2, 3, 4, 5],
              'target': 6
            },
            expected_output=-1
          )
        ]
        ```
        """)

    # The first response is consumed by unit test generation, the second by
    # function generation.
    lm = fake.StaticSequence(
        [unittest_lm_response, self._LINEAR_SEARCH_SOURCE]
    )

    @function_generation.function_gen(lm=lm)
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
    self.assertEqual(linear_search.source(), self._LINEAR_SEARCH_SOURCE)

  def test_custom_unittest_examples(self):
    lm = fake.StaticSequence([self._LINEAR_SEARCH_RESPONSE])

    custom_unittest = [(([1, 2, 3, 4, 5], 3), 2)]

    @function_generation.function_gen(lm=lm, unittest=custom_unittest)
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)

  def test_custom_unittest_fn(self):
    lm = fake.StaticSequence([self._LINEAR_SEARCH_RESPONSE])

    @function_generation.function_gen(
        lm=lm, unittest=self._check_linear_search
    )
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)

  def test_load_function_from_cache_file(self):
    # An LM without responses: the implementation must come from the cache.
    lm = fake.StaticSequence([])
    _unittest_fn = self._check_linear_search
    cache_file = self._cache_file

    # The cache key is the decorated function's source (as returned by
    # `inspect.getsource`, then stripped), so it must match the decorated
    # function below character for character, including indentation.
    cache_key = """@function_generation.function_gen(
        lm=lm,
        unittest=_unittest_fn,
        cache_filename=cache_file,
    )
    def linear_search(items, target):  # pylint: disable=unused-argument
      \"\"\"Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      \"\"\""""
    cache = pg.Dict()
    cache[cache_key] = self._LINEAR_SEARCH_RESPONSE
    cache.save(cache_file)

    @function_generation.function_gen(
        lm=lm,
        unittest=_unittest_fn,
        cache_filename=cache_file,
    )
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
    self.assertEqual(linear_search(['a', 'b', 'c'], 'd'), -1)

  def test_empty_cache_file(self):
    lm = fake.StaticSequence([self._LINEAR_SEARCH_RESPONSE])
    cache_file = self._cache_file
    self.assertFalse(os.path.exists(cache_file))

    @function_generation.function_gen(
        lm=lm, unittest=self._check_linear_search, cache_filename=cache_file
    )
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
    # The generated implementation is persisted to the newly created cache.
    self.assertTrue(os.path.exists(cache_file))

  def test_signature_check(self):
    incorrect_signature_lm_response = inspect.cleandoc("""
        ```python
        def dummy():
          pass
        ```
        """)

    # The first candidate has the wrong signature and must be skipped.
    lm = fake.StaticSequence(
        [incorrect_signature_lm_response, self._LINEAR_SEARCH_RESPONSE]
    )

    @function_generation.function_gen(
        lm=lm, unittest=self._check_linear_search
    )
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
326
+
327
+
328
if __name__ == '__main__':
  # Run this module's test cases when the file is executed directly.
  unittest.main()
@@ -14,12 +14,49 @@
14
14
  """The base of symbolic mapping methods."""
15
15
 
16
16
  import io
17
- from typing import Annotated, Any
17
+ from typing import Annotated, Any, Callable
18
18
  import langfun.core as lf
19
19
  from langfun.core.structured import schema as schema_lib
20
20
  import pyglove as pg
21
21
 
22
22
 
23
class MappingError(Exception):  # pylint: disable=g-bad-exception-name
  """Mapping error.

  Raised when an LM response fails to be mapped (e.g. parsed) into the target
  output. It carries both the offending LM response and the underlying cause.
  """

  def __init__(self, lm_response: lf.Message, cause: Exception):
    # Forward the arguments to `Exception` so `self.args` is populated.
    # `copy` and `pickle` rebuild exceptions as `cls(*self.args)`, which raises
    # a TypeError when `args` is left empty.
    super().__init__(lm_response, cause)
    self._lm_response = lm_response
    self._cause = cause

  @property
  def lm_response(self) -> lf.Message:
    """Returns the LM response that failed to be mapped."""
    return self._lm_response

  @property
  def cause(self) -> Exception:
    """Returns the cause of the error."""
    return self._cause

  def __str__(self) -> str:
    return self.format(include_lm_response=True)

  def format(self, include_lm_response: bool = True) -> str:
    """Formats the mapping error.

    Args:
      include_lm_response: Whether to append the LM response text after the
        cause.

    Returns:
      A colored string `<CauseClass>: <cause message>` (magenta). If requested,
      it is followed by a blank line, a bold `[LM Response]` header and the
      response text (blue).
    """
    r = io.StringIO()
    error_message = str(self.cause).rstrip()
    r.write(
        lf.colored(
            f'{self.cause.__class__.__name__}: {error_message}', 'magenta'
        )
    )
    if include_lm_response:
      r.write('\n\n')
      r.write(lf.colored('[LM Response]', 'blue', styles=['bold']))
      r.write('\n')
      r.write(lf.colored(self.lm_response.text, 'blue'))
    return r.getvalue()
58
+
59
+
23
60
  @pg.use_init_args(['input', 'output', 'schema', 'context'])
24
61
  class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
25
62
  """Mapping example between text, schema and structured value."""
@@ -278,6 +315,14 @@ class Mapping(lf.LangFunc):
278
315
  ),
279
316
  ] = lf.RAISE_IF_HAS_ERROR
280
317
 
318
# Optional hook that rewrites the raw LM response text before parsing; it is
# applied by `postprocess_response`.
response_postprocess: Annotated[
    Callable[[str], str] | None,
    'A callable object that post process the raw LLM response before '
    'parsing it into the output Python object.'
] = None
325
+
281
326
  #
282
327
  # Key methods for implementing specific mappings.
283
328
  #
@@ -296,10 +341,11 @@ class Mapping(lf.LangFunc):
296
341
  def transform_output(self, lm_output: lf.Message) -> lf.Message:
297
342
  """Transforms LM response into structure if schema is present."""
298
343
  try:
344
+ lm_output = self.postprocess_response(lm_output)
299
345
  lm_output.result = self.postprocess_result(self.parse_result(lm_output))
300
346
  except Exception as e: # pylint: disable=broad-exception-caught
301
347
  if self.default == lf.RAISE_IF_HAS_ERROR:
302
- raise e
348
+ raise MappingError(lm_output, e) from e
303
349
  lm_output.result = self.default
304
350
  return lm_output
305
351
 
@@ -316,6 +362,14 @@ class Mapping(lf.LangFunc):
316
362
  autofix_lm=self.autofix_lm or self.lm,
317
363
  )
318
364
 
365
def postprocess_response(self, response: lf.Message) -> lf.Message:
  """Applies `response_postprocess`, if set, to the LM response text.

  Returns:
    `response` itself when no post-processor is configured or the text is
    unchanged. Otherwise, a new `lf.AIMessage` holding the rewritten text, with
    `response` as its source.
  """
  postprocess = self.response_postprocess
  if postprocess is None:
    return response
  new_text = postprocess(response.text)
  if new_text == response.text:
    return response
  return lf.AIMessage(new_text, source=response)
372
+
319
373
  def postprocess_result(self, result: Any) -> Any:
    """Post process structured output.

    The base implementation returns `result` unchanged. `transform_output`
    applies it to the value produced by `parse_result`; presumably subclasses
    override it to adjust the parsed value (NOTE(review): confirm).
    """
    return result
@@ -16,10 +16,27 @@
16
16
  import inspect
17
17
  import unittest
18
18
 
19
+ import langfun.core as lf
19
20
  from langfun.core.structured import mapping
20
21
  import pyglove as pg
21
22
 
22
23
 
24
class MappingErrorTest(unittest.TestCase):

  def test_format(self):
    lm_response = lf.AIMessage('hi')
    cause = ValueError('Cannot parse message.')
    error = mapping.MappingError(lm_response, cause)

    # `str()` includes the LM response section.
    full_text = lf.text_formatting.decolored(str(error))
    self.assertEqual(
        full_text, 'ValueError: Cannot parse message.\n\n[LM Response]\nhi'
    )

    # `format(include_lm_response=False)` shows only the cause.
    short_text = lf.text_formatting.decolored(
        error.format(include_lm_response=False)
    )
    self.assertEqual(short_text, 'ValueError: Cannot parse message.')
38
+
39
+
23
40
  class MappingExampleTest(unittest.TestCase):
24
41
 
25
42
  def test_basics(self):