langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/__init__.py +7 -0
- langfun/core/__init__.py +1 -0
- langfun/core/coding/python/correction.py +0 -7
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +15 -0
- langfun/core/eval/base.py +665 -95
- langfun/core/eval/base_test.py +224 -53
- langfun/core/eval/matching.py +48 -30
- langfun/core/eval/matching_test.py +25 -3
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +19 -10
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/langfunc.py +1 -22
- langfun/core/langfunc_test.py +10 -4
- langfun/core/language_model.py +130 -24
- langfun/core/language_model_test.py +249 -26
- langfun/core/llms/__init__.py +27 -2
- langfun/core/llms/anthropic.py +263 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/cache/in_memory_test.py +37 -28
- langfun/core/llms/fake.py +34 -25
- langfun/core/llms/fake_test.py +122 -11
- langfun/core/llms/google_genai.py +8 -0
- langfun/core/llms/google_genai_test.py +8 -3
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/llama_cpp.py +3 -1
- langfun/core/llms/openai.py +100 -81
- langfun/core/llms/openai_test.py +287 -60
- langfun/core/llms/vertexai.py +291 -0
- langfun/core/llms/vertexai_test.py +233 -0
- langfun/core/modalities/image.py +1 -3
- langfun/core/modalities/mime.py +6 -0
- langfun/core/modalities/video.py +6 -5
- langfun/core/structured/__init__.py +5 -0
- langfun/core/structured/completion_test.py +2 -2
- langfun/core/structured/function_generation.py +245 -0
- langfun/core/structured/function_generation_test.py +329 -0
- langfun/core/structured/mapping.py +61 -3
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing_test.py +18 -13
- langfun/core/structured/prompting.py +61 -12
- langfun/core/structured/prompting_test.py +122 -12
- langfun/core/structured/schema.py +38 -6
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +36 -7
- langfun/core/structured/scoring.py +4 -1
- langfun/core/structured/scoring_test.py +6 -0
- langfun/core/template.py +147 -11
- langfun/core/template_test.py +75 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +3 -2
- langfun-0.0.2.dev20240511.dist-info/RECORD +112 -0
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,245 @@
|
|
1
|
+
# Copyright 2023 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""LLM-based function generation."""
|
15
|
+
|
16
|
+
import functools
|
17
|
+
import inspect
|
18
|
+
import re
|
19
|
+
from typing import Any, Callable, Optional, Tuple
|
20
|
+
|
21
|
+
from langfun.core import language_model
|
22
|
+
from langfun.core import template
|
23
|
+
from langfun.core.coding import python
|
24
|
+
from langfun.core.structured import prompting
|
25
|
+
import pyglove as pg
|
26
|
+
|
27
|
+
|
28
|
+
def unittest_gen(signature, lm, num_retries=10):
  """Asks `lm` to propose unit tests for a Python function signature.

  Args:
    signature: Source text of the function signature (with docstring) to
      generate test cases for.
    lm: The language model that proposes the test cases.
    num_retries: Maximum number of queries before giving up.

  Returns:
    A list of `(input, expected_output)` tuples taken from the first
    non-empty list of `UnitTest` objects returned by `lm`, or None if no
    attempt produced one.
  """

  # NOTE: the class names, field names and docstring below are rendered into
  # the prompt, so they are part of the LM-facing contract.
  class UnitTest(pg.Object):
    """A valid unit test for a python function."""

    input: dict[str, Any]
    expected_output: Any

  class PythonFunctionSignature(pg.Object):
    signature: str

  attempts_left = num_retries
  while attempts_left > 0:
    attempts_left -= 1
    proposed = prompting.query(
        PythonFunctionSignature(signature=signature),
        list[UnitTest],
        lm=lm,
        default=None,
    )
    # `default=None` means a parse failure comes back as None; an empty list
    # is treated as a failed attempt as well.
    if not isinstance(proposed, list) or not proposed:
      continue
    return [(case.input, case.expected_output) for case in proposed]
  return None
|
55
|
+
|
56
|
+
|
57
|
+
def unittest_with_test_cases(f, unittests):
  """Runs `f` against `(inputs, expected_output)` test cases.

  Args:
    f: The callable under test.
    unittests: A non-empty sequence of test cases, each indexable as
      `(inputs, expected_output)`. A dict `inputs` is passed as keyword
      arguments, a tuple `inputs` as positional arguments, and any other
      value as the single positional argument.

  Raises:
    ValueError: If `unittests` is empty or None.
    AssertionError: On the first case whose actual output does not equal the
      expected output. It is raised explicitly rather than through an
      `assert` statement, so the check still runs under `python -O`.
  """
  if not unittests:
    raise ValueError(f"No unit tests provided: {unittests}")

  for unit_test in unittests:
    inputs = unit_test[0]
    if isinstance(inputs, dict):
      actual = f(**inputs)
    elif isinstance(inputs, tuple):
      actual = f(*inputs)
    else:
      actual = f(inputs)

    expected = unit_test[1]
    # Use `==` (not `!=`) to keep exactly the comparison the test relies on.
    if not actual == expected:  # pylint: disable=unneeded-not
      raise AssertionError(
          f"Test FAILED: Inputs: {inputs}, Expected: {expected}, "
          f"Actual: {actual}"
      )
|
75
|
+
|
76
|
+
|
77
|
+
def _function_gen(
    func: Callable[..., Any],
    signature: str,
    lm: language_model.LanguageModel,
    num_retries: int = 10,
    unittest: Optional[
        Callable[[Callable[..., Any]], None] | list[Tuple[Any, Any]]
    ] = None,
):
  """Generates a python function with LLM and verifies it with unit testing.

  Args:
    func: The user-declared (unimplemented) function. Its signature is used to
      reject generated candidates whose signature differs.
    signature: The processed source text of `func` that is sent to the LM.
    lm: The language model used for function (and unit test) generation.
    num_retries: Maximum number of function-generation attempts.
    unittest: A callable that raises if a candidate is wrong, a list of
      `(inputs, expected_output)` cases, or None to let `lm` generate cases.

  Returns:
    A `(function, source_code)` tuple for the first candidate with the same
    signature as `func` that passes the unit tests, or `(None, None)` if no
    candidate succeeds or no test cases are available.
  """

  class PythonFunctionPrompt(template.Template):
    r"""A template for a python function generation.

    Please reply to the last PYTHON_FUNCTION_SIGNATURE with a self-sufficient,
    error-free, and efficiently coded PYTHON_FUNCTION, crafted to the standards
    of a world-class programmer.

    PYTHON_FUNCTION_SIGNATURE:
    ```python
    def calculate_area_circle(radius: float) -> float:
      \"\"\"Calculates the area of a circle given its radius.

      Args:
        radius: The radius of the circle.

      Returns:
        The area of the circle.
      \"\"\"
    ```

    PYTHON_FUNCTION:
    ```python
    def calculate_area_circle(radius: float) -> float:
      \"\"\"Calculates the area of a circle given its radius.

      Args:
        radius: The radius of the circle.

      Returns:
        The area of the circle.
      \"\"\"
      import math

      area = math.pi * radius**2
      return area
    ```

    PYTHON_FUNCTION_SIGNATURE:
    ```python
    {{signature}}
    ```

    PYTHON_FUNCTION:
    """

  unittest_examples = None
  if unittest is None:
    unittest_examples = unittest_gen(signature, lm=lm)
  elif not callable(unittest):
    unittest_examples = unittest

  # Without a callable checker or any test cases, every candidate would fail in
  # `unittest_with_test_cases`, so bail out before spending LM calls on it.
  if not callable(unittest) and not unittest_examples:
    pg.logging.warning(
        "No unit tests are available for function generation. Signature:\n%s",
        signature,
    )
    return None, None

  for attempt in range(num_retries):
    try:
      source_code = prompting.query(
          PythonFunctionPrompt(signature=signature), lm=lm
      )
      # NOTE: this executes LM-generated code in the current process.
      f = python.evaluate(source_code)

      # Check whether the signatures are the same.
      if inspect.signature(f) != inspect.signature(func):
        pg.logging.warning(
            "Function generation attempt %d/%d produced a mismatched "
            "signature: %s",
            attempt + 1,
            num_retries,
            inspect.signature(f),
        )
        continue

      if callable(unittest):
        unittest(f)
      else:
        unittest_with_test_cases(f, unittest_examples)

      return f, source_code
    except Exception as e:  # pylint: disable=broad-exception-caught
      # Best effort: any failure (LM error, invalid code, failed test) just
      # triggers another attempt, but record why for debugging.
      pg.logging.warning(
          "Function generation attempt %d/%d failed: %s: %s",
          attempt + 1,
          num_retries,
          e.__class__.__name__,
          e,
      )

  return None, None
|
160
|
+
|
161
|
+
|
162
|
+
def _process_signature(signature):
|
163
|
+
# Remove the decorator.
|
164
|
+
pattern = r"^\@.*function_gen.*$"
|
165
|
+
signature = re.sub(pattern, "", signature, flags=re.MULTILINE)
|
166
|
+
# Remove the possible 'pass' in an empty function.
|
167
|
+
pattern = r"^\s*pass\s*$"
|
168
|
+
signature = re.sub(pattern, "", signature, flags=re.MULTILINE)
|
169
|
+
return signature.strip()
|
170
|
+
|
171
|
+
|
172
|
+
def function_gen(
    lm: language_model.LanguageModel,
    cache_filename: str | None = None,
    num_retries: int = 10,
    unittest: Optional[
        Callable[[Callable[..., Any]], None] | list[Tuple[Any, Any]]
    ] = None,
):
  """A decorator for automating function generation using a language model.

  This decorator should be applied to functions that are not yet implemented.
  It facilitates the implementation via the specified LLM, ensuring quality
  through unit tests.

  Args:
    lm (lf.LanguageModel): The language model used for generating function
      implementations.
    cache_filename (str | None): Optional. The path of the file where
      generated function implementations are loaded from or saved to.
    num_retries (int): Maximum number of attempts the language model should
      make to generate a suitable function implementation.
    unittest: This optional parameter enables the definition of custom unit
      tests. You can either provide a list of test cases as tuples of inputs
      and outputs, or a function that throws an error if a test fails. If left
      as None (the default setting), the LLM will automatically create the
      unit test cases.

  Returns:
    A decorator. Applying it to an unimplemented function returns a wrapper
    that, on its first call, loads the implementation from `cache_filename`
    (when the processed signature is cached) or generates and verifies one
    with `lm`. It then forwards every call to that implementation. The wrapper
    also exposes `source()`, which returns the implementation's source code
    (None until the first successful call). If generation fails, the call
    raises `ValueError`.
  """

  def _decorate(func):
    setattr(func, "__function__", None)
    setattr(func, "__source_code__", None)

    # `functools.wraps` also copies `__name__`, `__qualname__`, `__module__`
    # and `__doc__` from `func`.
    @functools.wraps(func)
    def lm_generated_func(*args, **kwargs):
      if func.__function__ is not None:
        return func.__function__(*args, **kwargs)

      signature = _process_signature(inspect.getsource(func))
      cache = pg.Dict()
      if cache_filename is not None:
        try:
          cache = pg.load(cache_filename)
        except FileNotFoundError:
          pg.logging.warning(
              "Creating a new cache as cache file '%s' does not exist.",
              cache_filename,
          )

      if signature in cache:
        # NOTE: cached source is trusted. It is executed as-is, without the
        # signature check or unit tests applied to freshly generated code.
        func.__source_code__ = cache[signature]
        func.__function__ = python.evaluate(func.__source_code__)
        return func.__function__(*args, **kwargs)

      func.__function__, func.__source_code__ = _function_gen(
          func, signature, lm, num_retries=num_retries, unittest=unittest
      )
      if func.__function__ is None:
        raise ValueError(f"Function generation failed. Signature:\n{signature}")

      if cache_filename is not None:
        cache[signature] = func.__source_code__
        cache.save(cache_filename)
      return func.__function__(*args, **kwargs)

    lm_generated_func.source = lambda: func.__source_code__
    return lm_generated_func

  return _decorate
|
@@ -0,0 +1,329 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import inspect
|
15
|
+
import os
|
16
|
+
import tempfile
|
17
|
+
import unittest
|
18
|
+
from langfun.core.llms import fake
|
19
|
+
from langfun.core.structured import function_generation
|
20
|
+
import pyglove as pg
|
21
|
+
|
22
|
+
|
23
|
+
class FunctionGenerationTest(unittest.TestCase):
  """Tests for the `function_generation.function_gen` decorator.

  Every test drives the decorator with `fake.StaticSequence`, which presumably
  replays the given responses in order. Each response list is therefore the
  exact sequence of LM replies the decorator is expected to consume.
  """

  def test_generate_function(self):
    """LM generates both the unit tests and the implementation."""
    # Unfenced implementation; `linear_search.source()` must return it as-is.
    function_gen_lm_response = inspect.cleandoc("""
        def linear_search(items, target):
            \"\"\"
            Performs a linear search on a list to find a target value.

            Args:
                items (list): The list to search within.
                target: The value to search for.

            Returns:
                int: The index of the target value if found, otherwise -1.
            \"\"\"
            for i, item in enumerate(items):
                if item == target:
                    return i
            return -1
        """)
    # Reply to the unit-test generation query: a list of `UnitTest` objects.
    unittest_lm_response = inspect.cleandoc("""
        ```python
        [
            UnitTest(
                input={
                    'items': [1, 2, 3, 4, 5],
                    'target': 3
                },
                expected_output=2
            ),
            UnitTest(
                input={
                    'items': [1, 2, 3, 4, 5],
                    'target': 6
                },
                expected_output=-1
            )
        ]
        ```
        """)

    # With `unittest=None`, unit tests are requested before the function.
    lm = fake.StaticSequence([unittest_lm_response, function_gen_lm_response])

    @function_generation.function_gen(lm=lm)
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
    self.assertEqual(linear_search.source(), function_gen_lm_response)

  def test_custom_unittest_examples(self):
    """User-provided `(inputs, expected)` cases replace LM-generated tests."""
    function_gen_lm_response = inspect.cleandoc("""
        ```python
        def linear_search(items, target):
            \"\"\"
            Performs a linear search on a list to find a target value.

            Args:
                items (list): The list to search within.
                target: The value to search for.

            Returns:
                int: The index of the target value if found, otherwise -1.
            \"\"\"
            for i, item in enumerate(items):
                if item == target:
                    return i
            return -1
        ```
        """)

    # Only one reply is needed: no unit-test query is made for custom cases.
    lm = fake.StaticSequence([function_gen_lm_response])

    # Tuple inputs are passed positionally: linear_search([1, ..., 5], 3) == 2.
    custom_unittest = [(([1, 2, 3, 4, 5], 3), 2)]

    @function_generation.function_gen(lm=lm, unittest=custom_unittest)
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)

  def test_custom_unittest_fn(self):
    """A user-provided checker function replaces LM-generated tests."""
    function_gen_lm_response = inspect.cleandoc("""
        ```python
        def linear_search(items, target):
            \"\"\"
            Performs a linear search on a list to find a target value.

            Args:
                items (list): The list to search within.
                target: The value to search for.

            Returns:
                int: The index of the target value if found, otherwise -1.
            \"\"\"
            for i, item in enumerate(items):
                if item == target:
                    return i
            return -1
        ```
        """)

    lm = fake.StaticSequence([function_gen_lm_response])

    # Raises AssertionError to reject a candidate implementation.
    def _unittest_fn(func):
      assert func([1, 2, 3, 4, 5], 3) == 2
      assert func([1, 2, 3, 4, 5], 6) == -1

    custom_unittest = _unittest_fn

    @function_generation.function_gen(lm=lm, unittest=custom_unittest)
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)

  def test_load_function_from_cache_file(self):
    """The implementation is served from the cache file without LM calls."""
    # No canned responses: presumably any LM call fails, so the test only
    # passes if the cached implementation is used.
    lm = fake.StaticSequence([])

    def _unittest_fn(func):
      assert func([1, 2, 3, 4, 5], 3) == 2
      assert func([1, 2, 3, 4, 5], 6) == -1

    cache_file_dir = tempfile.gettempdir()
    cache_file = os.path.join(cache_file_dir, 'cache_file.json')

    # The key must equal `_process_signature(inspect.getsource(...))` for the
    # decorated function below. Its decorator is indented, so the column-0
    # decorator regex does not strip it and it stays part of the key. Keep
    # this text in sync with the decorated function's source.
    cache_key = """@function_generation.function_gen(
        lm=lm,
        unittest=_unittest_fn,
        cache_filename=cache_file,
    )
    def linear_search(items, target):  # pylint: disable=unused-argument
      \"\"\"Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      \"\"\""""
    cache_value = """
        ```python
        def linear_search(items, target):
            \"\"\"
            Performs a linear search on a list to find a target value.

            Args:
                items (list): The list to search within.
                target: The value to search for.

            Returns:
                int: The index of the target value if found, otherwise -1.
            \"\"\"
            for i, item in enumerate(items):
                if item == target:
                    return i
            return -1
        ```
        """
    cache = pg.Dict()
    cache[cache_key] = cache_value
    cache.save(cache_file)

    @function_generation.function_gen(
        lm=lm,
        unittest=_unittest_fn,
        cache_filename=cache_file,
    )
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
    self.assertEqual(linear_search(['a', 'b', 'c'], 'd'), -1)

  def test_empty_cache_file(self):
    """A cache miss generates the function and then saves it to the cache."""
    function_gen_lm_response = inspect.cleandoc("""
        ```python
        def linear_search(items, target):
            \"\"\"
            Performs a linear search on a list to find a target value.

            Args:
                items (list): The list to search within.
                target: The value to search for.

            Returns:
                int: The index of the target value if found, otherwise -1.
            \"\"\"
            for i, item in enumerate(items):
                if item == target:
                    return i
            return -1
        ```
        """)

    lm = fake.StaticSequence([function_gen_lm_response])

    def _unittest_fn(func):
      assert func([1, 2, 3, 4, 5], 3) == 2
      assert func([1, 2, 3, 4, 5], 6) == -1

    # NOTE(review): this reuses the same temp path as
    # test_load_function_from_cache_file and never deletes it, so the file may
    # already exist. The test still passes only because this function's source
    # (and hence its cache key) differs. Consider tempfile.TemporaryDirectory.
    cache_file_dir = tempfile.gettempdir()
    cache_file = os.path.join(cache_file_dir, 'cache_file.json')

    @function_generation.function_gen(
        lm=lm, unittest=_unittest_fn, cache_filename=cache_file
    )
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)

  # NOTE(review): the method name misspells "signature".
  def test_siganture_check(self):
    """A candidate whose signature differs from the declaration is rejected."""
    # First reply: wrong signature (no parameters), so it must be skipped.
    incorrect_signature_lm_response = inspect.cleandoc("""
        ```python
        def dummy():
            pass
        ```
        """)
    function_gen_lm_response = inspect.cleandoc("""
        ```python
        def linear_search(items, target):
            \"\"\"
            Performs a linear search on a list to find a target value.

            Args:
                items (list): The list to search within.
                target: The value to search for.

            Returns:
                int: The index of the target value if found, otherwise -1.
            \"\"\"
            for i, item in enumerate(items):
                if item == target:
                    return i
            return -1
        ```
        """)

    lm = fake.StaticSequence(
        [incorrect_signature_lm_response, function_gen_lm_response]
    )

    def _unittest_fn(func):
      assert func([1, 2, 3, 4, 5], 3) == 2
      assert func([1, 2, 3, 4, 5], 6) == -1

    custom_unittest = _unittest_fn

    @function_generation.function_gen(lm=lm, unittest=custom_unittest)
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
|
326
|
+
|
327
|
+
|
328
|
+
if __name__ == '__main__':
  # Running this file directly executes the test cases defined in it.
  unittest.main()
|
@@ -14,12 +14,49 @@
|
|
14
14
|
"""The base of symbolic mapping methods."""
|
15
15
|
|
16
16
|
import io
|
17
|
-
from typing import Annotated, Any
|
17
|
+
from typing import Annotated, Any, Callable
|
18
18
|
import langfun.core as lf
|
19
19
|
from langfun.core.structured import schema as schema_lib
|
20
20
|
import pyglove as pg
|
21
21
|
|
22
22
|
|
23
|
+
class MappingError(Exception):  # pylint: disable=g-bad-exception-name
  """Mapping error.

  Raised by `Mapping.transform_output` when an LM response cannot be turned
  into the structured result. It carries both the offending response and the
  underlying exception.
  """

  def __init__(self, lm_response: lf.Message, cause: Exception):
    # Forward the arguments to `Exception` so `self.args` is populated.
    # Without this, pickling rebuilds the error as `MappingError()` (which
    # fails), and `repr()` shows no arguments.
    super().__init__(lm_response, cause)
    self._lm_response = lm_response
    self._cause = cause

  @property
  def lm_response(self) -> lf.Message:
    """Returns the LM response that failed to be mapped."""
    return self._lm_response

  @property
  def cause(self) -> Exception:
    """Returns the cause of the error."""
    return self._cause

  def __str__(self) -> str:
    return self.format(include_lm_response=True)

  def format(self, include_lm_response: bool = True) -> str:
    """Formats the mapping error.

    Args:
      include_lm_response: If True, append the raw LM response text below the
        error line.

    Returns:
      A colored string: `<CauseClass>: <cause message>`, optionally followed
      by a blank line, a `[LM Response]` header and the response text.
    """
    r = io.StringIO()
    error_message = str(self.cause).rstrip()
    r.write(
        lf.colored(
            f'{self.cause.__class__.__name__}: {error_message}', 'magenta'
        )
    )
    if include_lm_response:
      r.write('\n\n')
      r.write(lf.colored('[LM Response]', 'blue', styles=['bold']))
      r.write('\n')
      r.write(lf.colored(self.lm_response.text, 'blue'))
    return r.getvalue()
|
58
|
+
|
59
|
+
|
23
60
|
@pg.use_init_args(['input', 'output', 'schema', 'context'])
|
24
61
|
class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
|
25
62
|
"""Mapping example between text, schema and structured value."""
|
@@ -214,7 +251,7 @@ class Mapping(lf.LangFunc):
|
|
214
251
|
|
215
252
|
{%- if example.schema -%}
|
216
253
|
{{ schema_title }}:
|
217
|
-
{{ example.schema_repr(protocol) | indent(2, True) }}
|
254
|
+
{{ example.schema_repr(protocol, include_methods=include_methods) | indent(2, True) }}
|
218
255
|
|
219
256
|
{% endif -%}
|
220
257
|
|
@@ -242,6 +279,10 @@ class Mapping(lf.LangFunc):
|
|
242
279
|
'The protocol for representing the schema and value.',
|
243
280
|
] = 'python'
|
244
281
|
|
282
|
+
include_methods: Annotated[
|
283
|
+
bool, 'If True, include method definitions in the schema.'
|
284
|
+
] = False
|
285
|
+
|
245
286
|
#
|
246
287
|
# Other user-provided flags.
|
247
288
|
#
|
@@ -278,6 +319,14 @@ class Mapping(lf.LangFunc):
|
|
278
319
|
),
|
279
320
|
] = lf.RAISE_IF_HAS_ERROR
|
280
321
|
|
322
|
+
response_postprocess: Annotated[
|
323
|
+
Callable[[str], str] | None,
|
324
|
+
(
|
325
|
+
'A callable object that post process the raw LLM response before '
|
326
|
+
'parsing it into the output Python object.'
|
327
|
+
)
|
328
|
+
] = None
|
329
|
+
|
281
330
|
#
|
282
331
|
# Key methods for implementing specific mappings.
|
283
332
|
#
|
@@ -296,10 +345,11 @@ class Mapping(lf.LangFunc):
|
|
296
345
|
def transform_output(self, lm_output: lf.Message) -> lf.Message:
  """Parses the LM response into a structured `result` on the message.

  The response first goes through `postprocess_response`, which may return a
  new message. It is then parsed with `parse_result` and refined by
  `postprocess_result`. On any failure, a `MappingError` chained to the
  original exception is raised when `default` is `lf.RAISE_IF_HAS_ERROR`.
  Otherwise `default` is stored as the result.

  Args:
    lm_output: The raw LM response message.

  Returns:
    The (possibly post-processed) message carrying the `result` attribute.
  """
  try:
    lm_output = self.postprocess_response(lm_output)
    parsed = self.parse_result(lm_output)
    lm_output.result = self.postprocess_result(parsed)
  except Exception as error:  # pylint: disable=broad-exception-caught
    should_raise = self.default == lf.RAISE_IF_HAS_ERROR
    if should_raise:
      raise MappingError(lm_output, error) from error
    lm_output.result = self.default
  return lm_output
|
305
355
|
|
@@ -316,6 +366,14 @@ class Mapping(lf.LangFunc):
|
|
316
366
|
autofix_lm=self.autofix_lm or self.lm,
|
317
367
|
)
|
318
368
|
|
369
|
+
def postprocess_response(self, response: lf.Message) -> lf.Message:
  """Applies the optional `response_postprocess` callable to the LM reply.

  Args:
    response: The raw LM response message.

  Returns:
    `response` itself when no post-processor is configured or when it leaves
    the text unchanged. Otherwise, a new `lf.AIMessage` with the rewritten
    text whose `source` is the original response.
  """
  postprocessor = self.response_postprocess
  if postprocessor is None:
    return response
  rewritten = postprocessor(response.text)
  if rewritten == response.text:
    return response
  return lf.AIMessage(rewritten, source=response)
|
376
|
+
|
319
377
|
def postprocess_result(self, result: Any) -> Any:
  """Post process structured output.

  Called by `transform_output` on the value produced by `parse_result`. The
  base implementation returns `result` unchanged; subclasses can override it
  to adjust the parsed value.

  Args:
    result: The structured value parsed from the LM response.

  Returns:
    The value to store as the message's `result`.
  """
  return result
|
@@ -16,10 +16,27 @@
|
|
16
16
|
import inspect
|
17
17
|
import unittest
|
18
18
|
|
19
|
+
import langfun.core as lf
|
19
20
|
from langfun.core.structured import mapping
|
20
21
|
import pyglove as pg
|
21
22
|
|
22
23
|
|
24
|
+
class MappingErrorTest(unittest.TestCase):
  """Tests the text rendering of `mapping.MappingError`."""

  def test_format(self):
    cause = ValueError('Cannot parse message.')
    error = mapping.MappingError(lf.AIMessage('hi'), cause)
    decolored = lf.text_formatting.decolored

    # Without the LM response only the cause line is rendered.
    self.assertEqual(
        decolored(error.format(include_lm_response=False)),
        'ValueError: Cannot parse message.',
    )
    # `str()` includes the LM response section.
    self.assertEqual(
        decolored(str(error)),
        'ValueError: Cannot parse message.\n\n[LM Response]\nhi',
    )
|
38
|
+
|
39
|
+
|
23
40
|
class MappingExampleTest(unittest.TestCase):
|
24
41
|
|
25
42
|
def test_basics(self):
|