langfun 0.1.2.dev202412070804__py3-none-any.whl → 0.1.2.dev202412080804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/structured/function_generation.py +26 -14
- langfun/core/structured/function_generation_test.py +30 -0
- {langfun-0.1.2.dev202412070804.dist-info → langfun-0.1.2.dev202412080804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202412070804.dist-info → langfun-0.1.2.dev202412080804.dist-info}/RECORD +7 -7
- {langfun-0.1.2.dev202412070804.dist-info → langfun-0.1.2.dev202412080804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202412070804.dist-info → langfun-0.1.2.dev202412080804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202412070804.dist-info → langfun-0.1.2.dev202412080804.dist-info}/top_level.txt +0 -0
@@ -76,6 +76,7 @@ def unittest_with_test_cases(f, unittests):
|
|
76
76
|
|
77
77
|
def _function_gen(
|
78
78
|
func: Callable[..., Any],
|
79
|
+
context: dict[str, Any],
|
79
80
|
signature: str,
|
80
81
|
lm: language_model.LanguageModel,
|
81
82
|
num_retries: int = 1,
|
@@ -141,21 +142,23 @@ def _function_gen(
|
|
141
142
|
elif isinstance(unittest, list):
|
142
143
|
unittest_examples = unittest
|
143
144
|
|
145
|
+
last_error = None
|
144
146
|
for _ in range(num_retries):
|
145
147
|
try:
|
146
148
|
source_code = prompting.query(
|
147
149
|
PythonFunctionPrompt(signature=signature), lm=lm
|
148
150
|
)
|
149
|
-
f = python.evaluate(source_code)
|
151
|
+
f = python.evaluate(source_code, global_vars=context)
|
150
152
|
|
151
153
|
# Check whether the sigantures are the same.
|
152
154
|
if inspect.signature(f) != inspect.signature(func):
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
155
|
+
raise python.CodeError(
|
156
|
+
code=source_code,
|
157
|
+
cause=TypeError(
|
158
|
+
f"Signature mismatch: Expected: {inspect.signature(func)}, "
|
159
|
+
f"Actual: {inspect.signature(f)}.",
|
160
|
+
),
|
157
161
|
)
|
158
|
-
continue
|
159
162
|
|
160
163
|
if callable(unittest):
|
161
164
|
unittest(f)
|
@@ -163,10 +166,12 @@ def _function_gen(
|
|
163
166
|
unittest_with_test_cases(f, unittest_examples)
|
164
167
|
|
165
168
|
return f, source_code
|
166
|
-
except
|
167
|
-
|
168
|
-
|
169
|
-
|
169
|
+
except python.CodeError as e:
|
170
|
+
last_error = e
|
171
|
+
pg.logging.warning(
|
172
|
+
f"Bad code generated: {e}",
|
173
|
+
)
|
174
|
+
raise last_error
|
170
175
|
|
171
176
|
|
172
177
|
def _process_signature(signature):
|
@@ -220,6 +225,13 @@ def function_gen(
|
|
220
225
|
setattr(func, "__function__", None)
|
221
226
|
setattr(func, "__source_code__", None)
|
222
227
|
|
228
|
+
# Prepare the globals/locals for the generated code to be evaluated against.
|
229
|
+
callstack = inspect.stack()
|
230
|
+
assert len(callstack) > 1
|
231
|
+
context = dict(callstack[1][0].f_globals)
|
232
|
+
context.update(callstack[1][0].f_locals)
|
233
|
+
context.pop(func.__name__, None)
|
234
|
+
|
223
235
|
@functools.wraps(func)
|
224
236
|
def lm_generated_func(*args, **kwargs):
|
225
237
|
if func.__function__ is not None:
|
@@ -238,20 +250,20 @@ def function_gen(
|
|
238
250
|
|
239
251
|
if signature in cache:
|
240
252
|
func.__source_code__ = cache[signature]
|
241
|
-
func.__function__ = python.evaluate(
|
253
|
+
func.__function__ = python.evaluate(
|
254
|
+
func.__source_code__, global_vars=context
|
255
|
+
)
|
242
256
|
return func.__function__(*args, **kwargs)
|
243
257
|
|
244
258
|
func.__function__, func.__source_code__ = _function_gen(
|
245
259
|
func,
|
260
|
+
context,
|
246
261
|
signature,
|
247
262
|
lm,
|
248
263
|
num_retries=num_retries,
|
249
264
|
unittest=unittest,
|
250
265
|
unittest_num_retries=unittest_num_retries,
|
251
266
|
)
|
252
|
-
if func.__function__ is None:
|
253
|
-
raise ValueError(f"Function generation failed. Signature:\n{signature}")
|
254
|
-
|
255
267
|
if cache_filename is not None:
|
256
268
|
cache[signature] = func.__source_code__
|
257
269
|
cache.save(cache_filename)
|
@@ -311,6 +311,36 @@ class FunctionGenerationTest(unittest.TestCase):
|
|
311
311
|
|
312
312
|
self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
|
313
313
|
|
314
|
+
def test_context_passthrough(self):
|
315
|
+
|
316
|
+
class Number(pg.Object):
|
317
|
+
value: int
|
318
|
+
|
319
|
+
function_gen_lm_response = inspect.cleandoc("""
|
320
|
+
```python
|
321
|
+
def add(a: Number, b: Number) -> Number:
|
322
|
+
\"\"\"Adds two numbers together.\"\"\"
|
323
|
+
return Number(a.value + b.value)
|
324
|
+
```
|
325
|
+
""")
|
326
|
+
|
327
|
+
lm = fake.StaticSequence(
|
328
|
+
[function_gen_lm_response]
|
329
|
+
)
|
330
|
+
|
331
|
+
def _unittest_fn(func):
|
332
|
+
assert func(Number(1), Number(2)) == Number(3)
|
333
|
+
|
334
|
+
custom_unittest = _unittest_fn
|
335
|
+
|
336
|
+
@function_generation.function_gen(
|
337
|
+
lm=lm, unittest=custom_unittest, num_retries=1
|
338
|
+
)
|
339
|
+
def add(a: Number, b: Number) -> Number: # pylint: disable=unused-argument
|
340
|
+
"""Adds two numbers together."""
|
341
|
+
|
342
|
+
self.assertEqual(add(Number(2), Number(3)), Number(5))
|
343
|
+
|
314
344
|
def test_siganture_check(self):
|
315
345
|
incorrect_signature_lm_response = inspect.cleandoc("""
|
316
346
|
```python
|
@@ -123,8 +123,8 @@ langfun/core/structured/completion.py,sha256=cS2PjG7sqzDu5x0xoTk8RmNcoeX55iVwH38
|
|
123
123
|
langfun/core/structured/completion_test.py,sha256=lendf6nPsNfAmd5A7k3v_HS2At9F_jjbKBcV7OEt94o,19310
|
124
124
|
langfun/core/structured/description.py,sha256=SXW4MJvshFjbR-0gw6rE21o6WXq12UlRXawvDBXMZFA,5211
|
125
125
|
langfun/core/structured/description_test.py,sha256=UtZGjSFUaQ6130t1E5tcL7ODu0xIefkapb53TbnqsK8,7362
|
126
|
-
langfun/core/structured/function_generation.py,sha256=
|
127
|
-
langfun/core/structured/function_generation_test.py,sha256=
|
126
|
+
langfun/core/structured/function_generation.py,sha256=c2KB3M86GE1H-8vSZlilko0OJRem8uqvX7CgqTzLdVw,8558
|
127
|
+
langfun/core/structured/function_generation_test.py,sha256=LaXYDXf9GlqUrR6v_gtmK_H4kxzonmU7SYbn7XXMgjU,12128
|
128
128
|
langfun/core/structured/mapping.py,sha256=vLKH79UT-j0qkQdvqlQBO7SkXXuM-yr2Idm8_HH8qwM,13649
|
129
129
|
langfun/core/structured/mapping_test.py,sha256=bHm2ZCXBITq_G8Lvw_olFHeUUc4s_lGXZm9v9JhoPB4,9630
|
130
130
|
langfun/core/structured/parsing.py,sha256=D58wBWOC6r6DCJNychCDkiHPrsy1XJfBDCDDZtug00k,11765
|
@@ -148,8 +148,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
148
148
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
149
149
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
150
150
|
langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
|
151
|
-
langfun-0.1.2.
|
152
|
-
langfun-0.1.2.
|
153
|
-
langfun-0.1.2.
|
154
|
-
langfun-0.1.2.
|
155
|
-
langfun-0.1.2.
|
151
|
+
langfun-0.1.2.dev202412080804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
152
|
+
langfun-0.1.2.dev202412080804.dist-info/METADATA,sha256=71KlK3WFqJkT60wgcB53JeY2zHOqHqX1ucAyX6r05CI,8281
|
153
|
+
langfun-0.1.2.dev202412080804.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
154
|
+
langfun-0.1.2.dev202412080804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
155
|
+
langfun-0.1.2.dev202412080804.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{langfun-0.1.2.dev202412070804.dist-info → langfun-0.1.2.dev202412080804.dist-info}/top_level.txt
RENAMED
File without changes
|