langfun 0.1.2.dev202412060804__py3-none-any.whl → 0.1.2.dev202412080804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/structured/function_generation.py +57 -24
- langfun/core/structured/function_generation_test.py +72 -2
- {langfun-0.1.2.dev202412060804.dist-info → langfun-0.1.2.dev202412080804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202412060804.dist-info → langfun-0.1.2.dev202412080804.dist-info}/RECORD +7 -7
- {langfun-0.1.2.dev202412060804.dist-info → langfun-0.1.2.dev202412080804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202412060804.dist-info → langfun-0.1.2.dev202412080804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202412060804.dist-info → langfun-0.1.2.dev202412080804.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,7 @@
|
|
16
16
|
import functools
|
17
17
|
import inspect
|
18
18
|
import re
|
19
|
-
from typing import Any, Callable, Optional, Tuple
|
19
|
+
from typing import Any, Callable, Literal, Optional, Tuple
|
20
20
|
|
21
21
|
from langfun.core import language_model
|
22
22
|
from langfun.core import template
|
@@ -25,7 +25,7 @@ from langfun.core.structured import prompting
|
|
25
25
|
import pyglove as pg
|
26
26
|
|
27
27
|
|
28
|
-
def unittest_gen(signature, lm, num_retries=
|
28
|
+
def unittest_gen(signature, lm, num_retries=1):
|
29
29
|
"""Generates unit tests for a python function signature."""
|
30
30
|
|
31
31
|
class UnitTest(pg.Object):
|
@@ -76,12 +76,16 @@ def unittest_with_test_cases(f, unittests):
|
|
76
76
|
|
77
77
|
def _function_gen(
|
78
78
|
func: Callable[..., Any],
|
79
|
+
context: dict[str, Any],
|
79
80
|
signature: str,
|
80
81
|
lm: language_model.LanguageModel,
|
81
|
-
num_retries: int =
|
82
|
+
num_retries: int = 1,
|
82
83
|
unittest: Optional[
|
83
|
-
Callable[[Callable[..., Any]], None]
|
84
|
+
Callable[[Callable[..., Any]], None]
|
85
|
+
| list[Tuple[Any, Any]]
|
86
|
+
| Literal["auto"]
|
84
87
|
] = None,
|
88
|
+
unittest_num_retries: int = 1,
|
85
89
|
):
|
86
90
|
"""Generates a python function with LLM and verify its quality with unit testing."""
|
87
91
|
|
@@ -131,32 +135,43 @@ def _function_gen(
|
|
131
135
|
"""
|
132
136
|
|
133
137
|
unittest_examples = None
|
134
|
-
if unittest
|
135
|
-
unittest_examples = unittest_gen(
|
136
|
-
|
138
|
+
if unittest == "auto":
|
139
|
+
unittest_examples = unittest_gen(
|
140
|
+
signature, lm=lm, num_retries=unittest_num_retries
|
141
|
+
)
|
142
|
+
elif isinstance(unittest, list):
|
137
143
|
unittest_examples = unittest
|
138
144
|
|
145
|
+
last_error = None
|
139
146
|
for _ in range(num_retries):
|
140
147
|
try:
|
141
148
|
source_code = prompting.query(
|
142
149
|
PythonFunctionPrompt(signature=signature), lm=lm
|
143
150
|
)
|
144
|
-
f = python.evaluate(source_code)
|
151
|
+
f = python.evaluate(source_code, global_vars=context)
|
145
152
|
|
146
153
|
# Check whether the sigantures are the same.
|
147
154
|
if inspect.signature(f) != inspect.signature(func):
|
148
|
-
|
155
|
+
raise python.CodeError(
|
156
|
+
code=source_code,
|
157
|
+
cause=TypeError(
|
158
|
+
f"Signature mismatch: Expected: {inspect.signature(func)}, "
|
159
|
+
f"Actual: {inspect.signature(f)}.",
|
160
|
+
),
|
161
|
+
)
|
149
162
|
|
150
163
|
if callable(unittest):
|
151
164
|
unittest(f)
|
152
|
-
|
165
|
+
elif unittest_examples:
|
153
166
|
unittest_with_test_cases(f, unittest_examples)
|
154
167
|
|
155
168
|
return f, source_code
|
156
|
-
except
|
157
|
-
|
158
|
-
|
159
|
-
|
169
|
+
except python.CodeError as e:
|
170
|
+
last_error = e
|
171
|
+
pg.logging.warning(
|
172
|
+
f"Bad code generated: {e}",
|
173
|
+
)
|
174
|
+
raise last_error
|
160
175
|
|
161
176
|
|
162
177
|
def _process_signature(signature):
|
@@ -172,10 +187,13 @@ def _process_signature(signature):
|
|
172
187
|
def function_gen(
|
173
188
|
lm: language_model.LanguageModel,
|
174
189
|
cache_filename: str | None = None,
|
175
|
-
num_retries: int =
|
190
|
+
num_retries: int = 1,
|
176
191
|
unittest: Optional[
|
177
|
-
Callable[[Callable[..., Any]], None]
|
192
|
+
Callable[[Callable[..., Any]], None]
|
193
|
+
| list[Tuple[Any, Any]]
|
194
|
+
| Literal["auto"]
|
178
195
|
] = None,
|
196
|
+
unittest_num_retries: int = 1,
|
179
197
|
):
|
180
198
|
"""A decorator for automating function generation using a language model.
|
181
199
|
|
@@ -192,9 +210,12 @@ def function_gen(
|
|
192
210
|
make to generate a suitable function implementation.
|
193
211
|
unittest: This optional parameter enables the definition of custom unit
|
194
212
|
tests. You can either provide a list of test cases as tuples of inputs
|
195
|
-
and outputs, or a function that throws an error if a test fails
|
196
|
-
|
197
|
-
|
213
|
+
and outputs, or a function that throws an error if a test fails, or let
|
214
|
+
LLM automatically create the unit test cases. If a generated function is
|
215
|
+
and returned, it should pass all the unittests.
|
216
|
+
unittest_num_retries: If unittest is set to "auto", this parameter
|
217
|
+
specifies the number of times the LLM's attempts to generate unit test
|
218
|
+
cases.
|
198
219
|
|
199
220
|
Returns:
|
200
221
|
The implemented function object.
|
@@ -204,6 +225,13 @@ def function_gen(
|
|
204
225
|
setattr(func, "__function__", None)
|
205
226
|
setattr(func, "__source_code__", None)
|
206
227
|
|
228
|
+
# Prepare the globals/locals for the generated code to be evaluated against.
|
229
|
+
callstack = inspect.stack()
|
230
|
+
assert len(callstack) > 1
|
231
|
+
context = dict(callstack[1][0].f_globals)
|
232
|
+
context.update(callstack[1][0].f_locals)
|
233
|
+
context.pop(func.__name__, None)
|
234
|
+
|
207
235
|
@functools.wraps(func)
|
208
236
|
def lm_generated_func(*args, **kwargs):
|
209
237
|
if func.__function__ is not None:
|
@@ -222,15 +250,20 @@ def function_gen(
|
|
222
250
|
|
223
251
|
if signature in cache:
|
224
252
|
func.__source_code__ = cache[signature]
|
225
|
-
func.__function__ = python.evaluate(
|
253
|
+
func.__function__ = python.evaluate(
|
254
|
+
func.__source_code__, global_vars=context
|
255
|
+
)
|
226
256
|
return func.__function__(*args, **kwargs)
|
227
257
|
|
228
258
|
func.__function__, func.__source_code__ = _function_gen(
|
229
|
-
func,
|
259
|
+
func,
|
260
|
+
context,
|
261
|
+
signature,
|
262
|
+
lm,
|
263
|
+
num_retries=num_retries,
|
264
|
+
unittest=unittest,
|
265
|
+
unittest_num_retries=unittest_num_retries,
|
230
266
|
)
|
231
|
-
if func.__function__ is None:
|
232
|
-
raise ValueError(f"Function generation failed. Signature:\n{signature}")
|
233
|
-
|
234
267
|
if cache_filename is not None:
|
235
268
|
cache[signature] = func.__source_code__
|
236
269
|
cache.save(cache_filename)
|
@@ -63,6 +63,42 @@ class FunctionGenerationTest(unittest.TestCase):
|
|
63
63
|
|
64
64
|
lm = fake.StaticSequence([unittest_lm_response, function_gen_lm_response])
|
65
65
|
|
66
|
+
@function_generation.function_gen(lm=lm, unittest='auto')
|
67
|
+
def linear_search(items, target): # pylint: disable=unused-argument
|
68
|
+
"""Performs a linear search on a list to find a target value.
|
69
|
+
|
70
|
+
Args:
|
71
|
+
items (list): The list to search within.
|
72
|
+
target: The value to search for.
|
73
|
+
|
74
|
+
Returns:
|
75
|
+
int: The index of the target value if found, otherwise -1.
|
76
|
+
"""
|
77
|
+
|
78
|
+
self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
|
79
|
+
self.assertEqual(linear_search.source(), function_gen_lm_response)
|
80
|
+
|
81
|
+
def test_generate_function_without_unittest(self):
|
82
|
+
function_gen_lm_response = inspect.cleandoc("""
|
83
|
+
def linear_search(items, target):
|
84
|
+
\"\"\"
|
85
|
+
Performs a linear search on a list to find a target value.
|
86
|
+
|
87
|
+
Args:
|
88
|
+
items (list): The list to search within.
|
89
|
+
target: The value to search for.
|
90
|
+
|
91
|
+
Returns:
|
92
|
+
int: The index of the target value if found, otherwise -1.
|
93
|
+
\"\"\"
|
94
|
+
for i, item in enumerate(items):
|
95
|
+
if item == target:
|
96
|
+
return i
|
97
|
+
return -1
|
98
|
+
""")
|
99
|
+
|
100
|
+
lm = fake.StaticSequence([function_gen_lm_response])
|
101
|
+
|
66
102
|
@function_generation.function_gen(lm=lm)
|
67
103
|
def linear_search(items, target): # pylint: disable=unused-argument
|
68
104
|
"""Performs a linear search on a list to find a target value.
|
@@ -258,7 +294,9 @@ class FunctionGenerationTest(unittest.TestCase):
|
|
258
294
|
cache_file = os.path.join(cache_file_dir, 'cache_file.json')
|
259
295
|
|
260
296
|
@function_generation.function_gen(
|
261
|
-
lm=lm,
|
297
|
+
lm=lm,
|
298
|
+
unittest=_unittest_fn,
|
299
|
+
cache_filename=cache_file,
|
262
300
|
)
|
263
301
|
def linear_search(items, target): # pylint: disable=unused-argument
|
264
302
|
"""Performs a linear search on a list to find a target value.
|
@@ -273,6 +311,36 @@ class FunctionGenerationTest(unittest.TestCase):
|
|
273
311
|
|
274
312
|
self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
|
275
313
|
|
314
|
+
def test_context_passthrough(self):
|
315
|
+
|
316
|
+
class Number(pg.Object):
|
317
|
+
value: int
|
318
|
+
|
319
|
+
function_gen_lm_response = inspect.cleandoc("""
|
320
|
+
```python
|
321
|
+
def add(a: Number, b: Number) -> Number:
|
322
|
+
\"\"\"Adds two numbers together.\"\"\"
|
323
|
+
return Number(a.value + b.value)
|
324
|
+
```
|
325
|
+
""")
|
326
|
+
|
327
|
+
lm = fake.StaticSequence(
|
328
|
+
[function_gen_lm_response]
|
329
|
+
)
|
330
|
+
|
331
|
+
def _unittest_fn(func):
|
332
|
+
assert func(Number(1), Number(2)) == Number(3)
|
333
|
+
|
334
|
+
custom_unittest = _unittest_fn
|
335
|
+
|
336
|
+
@function_generation.function_gen(
|
337
|
+
lm=lm, unittest=custom_unittest, num_retries=1
|
338
|
+
)
|
339
|
+
def add(a: Number, b: Number) -> Number: # pylint: disable=unused-argument
|
340
|
+
"""Adds two numbers together."""
|
341
|
+
|
342
|
+
self.assertEqual(add(Number(2), Number(3)), Number(5))
|
343
|
+
|
276
344
|
def test_siganture_check(self):
|
277
345
|
incorrect_signature_lm_response = inspect.cleandoc("""
|
278
346
|
```python
|
@@ -310,7 +378,9 @@ class FunctionGenerationTest(unittest.TestCase):
|
|
310
378
|
|
311
379
|
custom_unittest = _unittest_fn
|
312
380
|
|
313
|
-
@function_generation.function_gen(
|
381
|
+
@function_generation.function_gen(
|
382
|
+
lm=lm, unittest=custom_unittest, num_retries=2
|
383
|
+
)
|
314
384
|
def linear_search(items, target): # pylint: disable=unused-argument
|
315
385
|
"""Performs a linear search on a list to find a target value.
|
316
386
|
|
@@ -123,8 +123,8 @@ langfun/core/structured/completion.py,sha256=cS2PjG7sqzDu5x0xoTk8RmNcoeX55iVwH38
|
|
123
123
|
langfun/core/structured/completion_test.py,sha256=lendf6nPsNfAmd5A7k3v_HS2At9F_jjbKBcV7OEt94o,19310
|
124
124
|
langfun/core/structured/description.py,sha256=SXW4MJvshFjbR-0gw6rE21o6WXq12UlRXawvDBXMZFA,5211
|
125
125
|
langfun/core/structured/description_test.py,sha256=UtZGjSFUaQ6130t1E5tcL7ODu0xIefkapb53TbnqsK8,7362
|
126
|
-
langfun/core/structured/function_generation.py,sha256=
|
127
|
-
langfun/core/structured/function_generation_test.py,sha256=
|
126
|
+
langfun/core/structured/function_generation.py,sha256=c2KB3M86GE1H-8vSZlilko0OJRem8uqvX7CgqTzLdVw,8558
|
127
|
+
langfun/core/structured/function_generation_test.py,sha256=LaXYDXf9GlqUrR6v_gtmK_H4kxzonmU7SYbn7XXMgjU,12128
|
128
128
|
langfun/core/structured/mapping.py,sha256=vLKH79UT-j0qkQdvqlQBO7SkXXuM-yr2Idm8_HH8qwM,13649
|
129
129
|
langfun/core/structured/mapping_test.py,sha256=bHm2ZCXBITq_G8Lvw_olFHeUUc4s_lGXZm9v9JhoPB4,9630
|
130
130
|
langfun/core/structured/parsing.py,sha256=D58wBWOC6r6DCJNychCDkiHPrsy1XJfBDCDDZtug00k,11765
|
@@ -148,8 +148,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
148
148
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
149
149
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
150
150
|
langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
|
151
|
-
langfun-0.1.2.
|
152
|
-
langfun-0.1.2.
|
153
|
-
langfun-0.1.2.
|
154
|
-
langfun-0.1.2.
|
155
|
-
langfun-0.1.2.
|
151
|
+
langfun-0.1.2.dev202412080804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
152
|
+
langfun-0.1.2.dev202412080804.dist-info/METADATA,sha256=71KlK3WFqJkT60wgcB53JeY2zHOqHqX1ucAyX6r05CI,8281
|
153
|
+
langfun-0.1.2.dev202412080804.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
154
|
+
langfun-0.1.2.dev202412080804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
155
|
+
langfun-0.1.2.dev202412080804.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{langfun-0.1.2.dev202412060804.dist-info → langfun-0.1.2.dev202412080804.dist-info}/top_level.txt
RENAMED
File without changes
|