langfun 0.1.2.dev202412070804__py3-none-any.whl → 0.1.2.dev202412080804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -76,6 +76,7 @@ def unittest_with_test_cases(f, unittests):
76
76
 
77
77
  def _function_gen(
78
78
  func: Callable[..., Any],
79
+ context: dict[str, Any],
79
80
  signature: str,
80
81
  lm: language_model.LanguageModel,
81
82
  num_retries: int = 1,
@@ -141,21 +142,23 @@ def _function_gen(
141
142
  elif isinstance(unittest, list):
142
143
  unittest_examples = unittest
143
144
 
145
+ last_error = None
144
146
  for _ in range(num_retries):
145
147
  try:
146
148
  source_code = prompting.query(
147
149
  PythonFunctionPrompt(signature=signature), lm=lm
148
150
  )
149
- f = python.evaluate(source_code)
151
+ f = python.evaluate(source_code, global_vars=context)
150
152
 
151
153
  # Check whether the sigantures are the same.
152
154
  if inspect.signature(f) != inspect.signature(func):
153
- pg.logging.warning(
154
- "Signature mismatch. Expected: %s, Actual: %s",
155
- inspect.signature(func),
156
- inspect.signature(f),
155
+ raise python.CodeError(
156
+ code=source_code,
157
+ cause=TypeError(
158
+ f"Signature mismatch: Expected: {inspect.signature(func)}, "
159
+ f"Actual: {inspect.signature(f)}.",
160
+ ),
157
161
  )
158
- continue
159
162
 
160
163
  if callable(unittest):
161
164
  unittest(f)
@@ -163,10 +166,12 @@ def _function_gen(
163
166
  unittest_with_test_cases(f, unittest_examples)
164
167
 
165
168
  return f, source_code
166
- except Exception: # pylint: disable=broad-exception-caught
167
- pass
168
-
169
- return None, None
169
+ except python.CodeError as e:
170
+ last_error = e
171
+ pg.logging.warning(
172
+ f"Bad code generated: {e}",
173
+ )
174
+ raise last_error
170
175
 
171
176
 
172
177
  def _process_signature(signature):
@@ -220,6 +225,13 @@ def function_gen(
220
225
  setattr(func, "__function__", None)
221
226
  setattr(func, "__source_code__", None)
222
227
 
228
+ # Prepare the globals/locals for the generated code to be evaluated against.
229
+ callstack = inspect.stack()
230
+ assert len(callstack) > 1
231
+ context = dict(callstack[1][0].f_globals)
232
+ context.update(callstack[1][0].f_locals)
233
+ context.pop(func.__name__, None)
234
+
223
235
  @functools.wraps(func)
224
236
  def lm_generated_func(*args, **kwargs):
225
237
  if func.__function__ is not None:
@@ -238,20 +250,20 @@ def function_gen(
238
250
 
239
251
  if signature in cache:
240
252
  func.__source_code__ = cache[signature]
241
- func.__function__ = python.evaluate(func.__source_code__)
253
+ func.__function__ = python.evaluate(
254
+ func.__source_code__, global_vars=context
255
+ )
242
256
  return func.__function__(*args, **kwargs)
243
257
 
244
258
  func.__function__, func.__source_code__ = _function_gen(
245
259
  func,
260
+ context,
246
261
  signature,
247
262
  lm,
248
263
  num_retries=num_retries,
249
264
  unittest=unittest,
250
265
  unittest_num_retries=unittest_num_retries,
251
266
  )
252
- if func.__function__ is None:
253
- raise ValueError(f"Function generation failed. Signature:\n{signature}")
254
-
255
267
  if cache_filename is not None:
256
268
  cache[signature] = func.__source_code__
257
269
  cache.save(cache_filename)
@@ -311,6 +311,36 @@ class FunctionGenerationTest(unittest.TestCase):
311
311
 
312
312
  self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
313
313
 
314
+ def test_context_passthrough(self):
315
+
316
+ class Number(pg.Object):
317
+ value: int
318
+
319
+ function_gen_lm_response = inspect.cleandoc("""
320
+ ```python
321
+ def add(a: Number, b: Number) -> Number:
322
+ \"\"\"Adds two numbers together.\"\"\"
323
+ return Number(a.value + b.value)
324
+ ```
325
+ """)
326
+
327
+ lm = fake.StaticSequence(
328
+ [function_gen_lm_response]
329
+ )
330
+
331
+ def _unittest_fn(func):
332
+ assert func(Number(1), Number(2)) == Number(3)
333
+
334
+ custom_unittest = _unittest_fn
335
+
336
+ @function_generation.function_gen(
337
+ lm=lm, unittest=custom_unittest, num_retries=1
338
+ )
339
+ def add(a: Number, b: Number) -> Number: # pylint: disable=unused-argument
340
+ """Adds two numbers together."""
341
+
342
+ self.assertEqual(add(Number(2), Number(3)), Number(5))
343
+
314
344
  def test_siganture_check(self):
315
345
  incorrect_signature_lm_response = inspect.cleandoc("""
316
346
  ```python
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.1.2.dev202412070804
3
+ Version: 0.1.2.dev202412080804
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -123,8 +123,8 @@ langfun/core/structured/completion.py,sha256=cS2PjG7sqzDu5x0xoTk8RmNcoeX55iVwH38
123
123
  langfun/core/structured/completion_test.py,sha256=lendf6nPsNfAmd5A7k3v_HS2At9F_jjbKBcV7OEt94o,19310
124
124
  langfun/core/structured/description.py,sha256=SXW4MJvshFjbR-0gw6rE21o6WXq12UlRXawvDBXMZFA,5211
125
125
  langfun/core/structured/description_test.py,sha256=UtZGjSFUaQ6130t1E5tcL7ODu0xIefkapb53TbnqsK8,7362
126
- langfun/core/structured/function_generation.py,sha256=gOV5B4KXzN6ng1P1QtZ8aOAEQB8eAbgwWGj57tnzWJY,8159
127
- langfun/core/structured/function_generation_test.py,sha256=1OtstouOYyYOd_gmZtL8RRbh-FcYGEvBNju6lNrJrOA,11331
126
+ langfun/core/structured/function_generation.py,sha256=c2KB3M86GE1H-8vSZlilko0OJRem8uqvX7CgqTzLdVw,8558
127
+ langfun/core/structured/function_generation_test.py,sha256=LaXYDXf9GlqUrR6v_gtmK_H4kxzonmU7SYbn7XXMgjU,12128
128
128
  langfun/core/structured/mapping.py,sha256=vLKH79UT-j0qkQdvqlQBO7SkXXuM-yr2Idm8_HH8qwM,13649
129
129
  langfun/core/structured/mapping_test.py,sha256=bHm2ZCXBITq_G8Lvw_olFHeUUc4s_lGXZm9v9JhoPB4,9630
130
130
  langfun/core/structured/parsing.py,sha256=D58wBWOC6r6DCJNychCDkiHPrsy1XJfBDCDDZtug00k,11765
@@ -148,8 +148,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
148
148
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
149
149
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
150
150
  langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
151
- langfun-0.1.2.dev202412070804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
152
- langfun-0.1.2.dev202412070804.dist-info/METADATA,sha256=Ao9kzm8AFw7749nP7p_m3k41rVcBBzG8_QuMnLLN9_U,8281
153
- langfun-0.1.2.dev202412070804.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
154
- langfun-0.1.2.dev202412070804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
155
- langfun-0.1.2.dev202412070804.dist-info/RECORD,,
151
+ langfun-0.1.2.dev202412080804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
152
+ langfun-0.1.2.dev202412080804.dist-info/METADATA,sha256=71KlK3WFqJkT60wgcB53JeY2zHOqHqX1ucAyX6r05CI,8281
153
+ langfun-0.1.2.dev202412080804.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
154
+ langfun-0.1.2.dev202412080804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
155
+ langfun-0.1.2.dev202412080804.dist-info/RECORD,,