langfun 0.1.2.dev202412060804__py3-none-any.whl → 0.1.2.dev202412080804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,7 +16,7 @@
16
16
  import functools
17
17
  import inspect
18
18
  import re
19
- from typing import Any, Callable, Optional, Tuple
19
+ from typing import Any, Callable, Literal, Optional, Tuple
20
20
 
21
21
  from langfun.core import language_model
22
22
  from langfun.core import template
@@ -25,7 +25,7 @@ from langfun.core.structured import prompting
25
25
  import pyglove as pg
26
26
 
27
27
 
28
- def unittest_gen(signature, lm, num_retries=10):
28
+ def unittest_gen(signature, lm, num_retries=1):
29
29
  """Generates unit tests for a python function signature."""
30
30
 
31
31
  class UnitTest(pg.Object):
@@ -76,12 +76,16 @@ def unittest_with_test_cases(f, unittests):
76
76
 
77
77
  def _function_gen(
78
78
  func: Callable[..., Any],
79
+ context: dict[str, Any],
79
80
  signature: str,
80
81
  lm: language_model.LanguageModel,
81
- num_retries: int = 10,
82
+ num_retries: int = 1,
82
83
  unittest: Optional[
83
- Callable[[Callable[..., Any]], None] | list[Tuple[Any, Any]]
84
+ Callable[[Callable[..., Any]], None]
85
+ | list[Tuple[Any, Any]]
86
+ | Literal["auto"]
84
87
  ] = None,
88
+ unittest_num_retries: int = 1,
85
89
  ):
86
90
  """Generates a python function with LLM and verify its quality with unit testing."""
87
91
 
@@ -131,32 +135,43 @@ def _function_gen(
131
135
  """
132
136
 
133
137
  unittest_examples = None
134
- if unittest is None:
135
- unittest_examples = unittest_gen(signature, lm=lm)
136
- elif not callable(unittest):
138
+ if unittest == "auto":
139
+ unittest_examples = unittest_gen(
140
+ signature, lm=lm, num_retries=unittest_num_retries
141
+ )
142
+ elif isinstance(unittest, list):
137
143
  unittest_examples = unittest
138
144
 
145
+ last_error = None
139
146
  for _ in range(num_retries):
140
147
  try:
141
148
  source_code = prompting.query(
142
149
  PythonFunctionPrompt(signature=signature), lm=lm
143
150
  )
144
- f = python.evaluate(source_code)
151
+ f = python.evaluate(source_code, global_vars=context)
145
152
 
146
153
  # Check whether the sigantures are the same.
147
154
  if inspect.signature(f) != inspect.signature(func):
148
- continue
155
+ raise python.CodeError(
156
+ code=source_code,
157
+ cause=TypeError(
158
+ f"Signature mismatch: Expected: {inspect.signature(func)}, "
159
+ f"Actual: {inspect.signature(f)}.",
160
+ ),
161
+ )
149
162
 
150
163
  if callable(unittest):
151
164
  unittest(f)
152
- else:
165
+ elif unittest_examples:
153
166
  unittest_with_test_cases(f, unittest_examples)
154
167
 
155
168
  return f, source_code
156
- except Exception: # pylint: disable=broad-exception-caught
157
- pass
158
-
159
- return None, None
169
+ except python.CodeError as e:
170
+ last_error = e
171
+ pg.logging.warning(
172
+ f"Bad code generated: {e}",
173
+ )
174
+ raise last_error
160
175
 
161
176
 
162
177
  def _process_signature(signature):
@@ -172,10 +187,13 @@ def _process_signature(signature):
172
187
  def function_gen(
173
188
  lm: language_model.LanguageModel,
174
189
  cache_filename: str | None = None,
175
- num_retries: int = 10,
190
+ num_retries: int = 1,
176
191
  unittest: Optional[
177
- Callable[[Callable[..., Any]], None] | list[Tuple[Any, Any]]
192
+ Callable[[Callable[..., Any]], None]
193
+ | list[Tuple[Any, Any]]
194
+ | Literal["auto"]
178
195
  ] = None,
196
+ unittest_num_retries: int = 1,
179
197
  ):
180
198
  """A decorator for automating function generation using a language model.
181
199
 
@@ -192,9 +210,12 @@ def function_gen(
192
210
  make to generate a suitable function implementation.
193
211
  unittest: This optional parameter enables the definition of custom unit
194
212
  tests. You can either provide a list of test cases as tuples of inputs
195
- and outputs, or a function that throws an error if a test fails. If left
196
- as None (the default setting), the LLM will automatically create the
197
- unit test cases.
213
+ and outputs, or a function that throws an error if a test fails, or let
214
+ LLM automatically create the unit test cases. If a generated function is
215
+ and returned, it should pass all the unittests.
216
+ unittest_num_retries: If unittest is set to "auto", this parameter
217
+ specifies the number of times the LLM's attempts to generate unit test
218
+ cases.
198
219
 
199
220
  Returns:
200
221
  The implemented function object.
@@ -204,6 +225,13 @@ def function_gen(
204
225
  setattr(func, "__function__", None)
205
226
  setattr(func, "__source_code__", None)
206
227
 
228
+ # Prepare the globals/locals for the generated code to be evaluated against.
229
+ callstack = inspect.stack()
230
+ assert len(callstack) > 1
231
+ context = dict(callstack[1][0].f_globals)
232
+ context.update(callstack[1][0].f_locals)
233
+ context.pop(func.__name__, None)
234
+
207
235
  @functools.wraps(func)
208
236
  def lm_generated_func(*args, **kwargs):
209
237
  if func.__function__ is not None:
@@ -222,15 +250,20 @@ def function_gen(
222
250
 
223
251
  if signature in cache:
224
252
  func.__source_code__ = cache[signature]
225
- func.__function__ = python.evaluate(func.__source_code__)
253
+ func.__function__ = python.evaluate(
254
+ func.__source_code__, global_vars=context
255
+ )
226
256
  return func.__function__(*args, **kwargs)
227
257
 
228
258
  func.__function__, func.__source_code__ = _function_gen(
229
- func, signature, lm, num_retries=num_retries, unittest=unittest
259
+ func,
260
+ context,
261
+ signature,
262
+ lm,
263
+ num_retries=num_retries,
264
+ unittest=unittest,
265
+ unittest_num_retries=unittest_num_retries,
230
266
  )
231
- if func.__function__ is None:
232
- raise ValueError(f"Function generation failed. Signature:\n{signature}")
233
-
234
267
  if cache_filename is not None:
235
268
  cache[signature] = func.__source_code__
236
269
  cache.save(cache_filename)
@@ -63,6 +63,42 @@ class FunctionGenerationTest(unittest.TestCase):
63
63
 
64
64
  lm = fake.StaticSequence([unittest_lm_response, function_gen_lm_response])
65
65
 
66
+ @function_generation.function_gen(lm=lm, unittest='auto')
67
+ def linear_search(items, target): # pylint: disable=unused-argument
68
+ """Performs a linear search on a list to find a target value.
69
+
70
+ Args:
71
+ items (list): The list to search within.
72
+ target: The value to search for.
73
+
74
+ Returns:
75
+ int: The index of the target value if found, otherwise -1.
76
+ """
77
+
78
+ self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
79
+ self.assertEqual(linear_search.source(), function_gen_lm_response)
80
+
81
+ def test_generate_function_without_unittest(self):
82
+ function_gen_lm_response = inspect.cleandoc("""
83
+ def linear_search(items, target):
84
+ \"\"\"
85
+ Performs a linear search on a list to find a target value.
86
+
87
+ Args:
88
+ items (list): The list to search within.
89
+ target: The value to search for.
90
+
91
+ Returns:
92
+ int: The index of the target value if found, otherwise -1.
93
+ \"\"\"
94
+ for i, item in enumerate(items):
95
+ if item == target:
96
+ return i
97
+ return -1
98
+ """)
99
+
100
+ lm = fake.StaticSequence([function_gen_lm_response])
101
+
66
102
  @function_generation.function_gen(lm=lm)
67
103
  def linear_search(items, target): # pylint: disable=unused-argument
68
104
  """Performs a linear search on a list to find a target value.
@@ -258,7 +294,9 @@ class FunctionGenerationTest(unittest.TestCase):
258
294
  cache_file = os.path.join(cache_file_dir, 'cache_file.json')
259
295
 
260
296
  @function_generation.function_gen(
261
- lm=lm, unittest=_unittest_fn, cache_filename=cache_file
297
+ lm=lm,
298
+ unittest=_unittest_fn,
299
+ cache_filename=cache_file,
262
300
  )
263
301
  def linear_search(items, target): # pylint: disable=unused-argument
264
302
  """Performs a linear search on a list to find a target value.
@@ -273,6 +311,36 @@ class FunctionGenerationTest(unittest.TestCase):
273
311
 
274
312
  self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
275
313
 
314
+ def test_context_passthrough(self):
315
+
316
+ class Number(pg.Object):
317
+ value: int
318
+
319
+ function_gen_lm_response = inspect.cleandoc("""
320
+ ```python
321
+ def add(a: Number, b: Number) -> Number:
322
+ \"\"\"Adds two numbers together.\"\"\"
323
+ return Number(a.value + b.value)
324
+ ```
325
+ """)
326
+
327
+ lm = fake.StaticSequence(
328
+ [function_gen_lm_response]
329
+ )
330
+
331
+ def _unittest_fn(func):
332
+ assert func(Number(1), Number(2)) == Number(3)
333
+
334
+ custom_unittest = _unittest_fn
335
+
336
+ @function_generation.function_gen(
337
+ lm=lm, unittest=custom_unittest, num_retries=1
338
+ )
339
+ def add(a: Number, b: Number) -> Number: # pylint: disable=unused-argument
340
+ """Adds two numbers together."""
341
+
342
+ self.assertEqual(add(Number(2), Number(3)), Number(5))
343
+
276
344
  def test_siganture_check(self):
277
345
  incorrect_signature_lm_response = inspect.cleandoc("""
278
346
  ```python
@@ -310,7 +378,9 @@ class FunctionGenerationTest(unittest.TestCase):
310
378
 
311
379
  custom_unittest = _unittest_fn
312
380
 
313
- @function_generation.function_gen(lm=lm, unittest=custom_unittest)
381
+ @function_generation.function_gen(
382
+ lm=lm, unittest=custom_unittest, num_retries=2
383
+ )
314
384
  def linear_search(items, target): # pylint: disable=unused-argument
315
385
  """Performs a linear search on a list to find a target value.
316
386
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.1.2.dev202412060804
3
+ Version: 0.1.2.dev202412080804
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -123,8 +123,8 @@ langfun/core/structured/completion.py,sha256=cS2PjG7sqzDu5x0xoTk8RmNcoeX55iVwH38
123
123
  langfun/core/structured/completion_test.py,sha256=lendf6nPsNfAmd5A7k3v_HS2At9F_jjbKBcV7OEt94o,19310
124
124
  langfun/core/structured/description.py,sha256=SXW4MJvshFjbR-0gw6rE21o6WXq12UlRXawvDBXMZFA,5211
125
125
  langfun/core/structured/description_test.py,sha256=UtZGjSFUaQ6130t1E5tcL7ODu0xIefkapb53TbnqsK8,7362
126
- langfun/core/structured/function_generation.py,sha256=pFgS3vcRAWiuFBol2x5Eeip3XqoudONsOpeJpWyjT3s,7479
127
- langfun/core/structured/function_generation_test.py,sha256=ZJI-aaGgWWszn92u7h5IZ9Pl70N2DgAGGJrIxPzsvwg,10065
126
+ langfun/core/structured/function_generation.py,sha256=c2KB3M86GE1H-8vSZlilko0OJRem8uqvX7CgqTzLdVw,8558
127
+ langfun/core/structured/function_generation_test.py,sha256=LaXYDXf9GlqUrR6v_gtmK_H4kxzonmU7SYbn7XXMgjU,12128
128
128
  langfun/core/structured/mapping.py,sha256=vLKH79UT-j0qkQdvqlQBO7SkXXuM-yr2Idm8_HH8qwM,13649
129
129
  langfun/core/structured/mapping_test.py,sha256=bHm2ZCXBITq_G8Lvw_olFHeUUc4s_lGXZm9v9JhoPB4,9630
130
130
  langfun/core/structured/parsing.py,sha256=D58wBWOC6r6DCJNychCDkiHPrsy1XJfBDCDDZtug00k,11765
@@ -148,8 +148,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
148
148
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
149
149
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
150
150
  langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
151
- langfun-0.1.2.dev202412060804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
152
- langfun-0.1.2.dev202412060804.dist-info/METADATA,sha256=7jXlkJbHKaJuKLg8E5_BZ0jS99pZozItRnVWolNqAtg,8281
153
- langfun-0.1.2.dev202412060804.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
154
- langfun-0.1.2.dev202412060804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
155
- langfun-0.1.2.dev202412060804.dist-info/RECORD,,
151
+ langfun-0.1.2.dev202412080804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
152
+ langfun-0.1.2.dev202412080804.dist-info/METADATA,sha256=71KlK3WFqJkT60wgcB53JeY2zHOqHqX1ucAyX6r05CI,8281
153
+ langfun-0.1.2.dev202412080804.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
154
+ langfun-0.1.2.dev202412080804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
155
+ langfun-0.1.2.dev202412080804.dist-info/RECORD,,