langfun 0.1.2.dev202409020804__py3-none-any.whl → 0.1.2.dev202409080803__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -81,23 +81,27 @@ def run_with_correction(
81
81
  # pylint: enable=g-import-not-at-top
82
82
 
83
83
  if max_attempts == 0:
84
- result = execution.run(
85
- code,
86
- global_vars=global_vars,
87
- sandbox=sandbox,
88
- timeout=timeout,
89
- outputs_intermediate=outputs_intermediate,
84
+ result = _maybe_custom_validate(
85
+ execution.run(
86
+ code,
87
+ global_vars=global_vars,
88
+ sandbox=sandbox,
89
+ timeout=timeout,
90
+ outputs_intermediate=outputs_intermediate,
91
+ )
90
92
  )
91
93
  return (result, code) if returns_code else result
92
94
 
93
95
  def result_and_error(code: str) -> tuple[Any, str | None]:
94
96
  try:
95
- result = execution.run(
96
- code,
97
- global_vars=global_vars,
98
- sandbox=sandbox,
99
- timeout=timeout,
100
- outputs_intermediate=outputs_intermediate,
97
+ result = _maybe_custom_validate(
98
+ execution.run(
99
+ code,
100
+ global_vars=global_vars,
101
+ sandbox=sandbox,
102
+ timeout=timeout,
103
+ outputs_intermediate=outputs_intermediate,
104
+ )
101
105
  )
102
106
  return (result, None)
103
107
  except Exception as e: # pylint: disable=broad-exception-caught
@@ -190,3 +194,15 @@ def _error_feedback_str(error: Exception) -> str:
190
194
  )
191
195
  else:
192
196
  return f"Encountered {error.__class__.__name__}: {error}"
197
+
198
+
199
def _maybe_custom_validate(result: Any) -> Any:
  """Applies custom validation via the result's `__validate__` method, if any.

  The value inspected is `result["__result__"]` when `result` is a dict that
  carries that key (the shape produced when intermediate outputs are
  requested), otherwise `result` itself. If that value is an instance exposing
  a `__validate__` method, the method is called; it may raise to reject the
  value, or mutate the value in place.

  Classes are not validated: accessing `__validate__` on a class yields a
  plain function, and calling it without an instance would raise `TypeError`
  even though the generated code is perfectly valid.

  Args:
    result: The value returned by `execution.run`.

  Returns:
    The same `result` object that was passed in.

  Raises:
    Any exception raised by the inspected value's `__validate__` method.
  """
  if isinstance(result, dict) and "__result__" in result:
    r = result["__result__"]
  else:
    r = result

  # Only invoke the hook on instances; a class object would expose the
  # unbound function and fail with a missing `self` argument.
  if not isinstance(r, type) and hasattr(r, "__validate__"):
    r.__validate__()
  return result
@@ -19,6 +19,7 @@ import unittest
19
19
  from langfun.core.coding.python import correction
20
20
  from langfun.core.coding.python import errors
21
21
  from langfun.core.llms import fake
22
+ import pyglove as pg
22
23
 
23
24
 
24
25
  class RunWithCorrectionTest(unittest.TestCase):
@@ -45,6 +46,32 @@ class RunWithCorrectionTest(unittest.TestCase):
45
46
  )
46
47
  self.assertEqual(result, 4)
47
48
 
49
def test_run_with_correction_upon_custom_validation(self):
  """A failing `__validate__` triggers correction; the fix is validated too."""

  class Foo(pg.Object):
    x: int

    def __validate__(self):
      # Anything above 1 is rejected, which forces a correction round.
      if self.x > 1:
        raise ValueError('value should be less or equal than 1.')
      # Negative values are clamped to 0 in place instead of being rejected.
      if self.x < 0:
        self.rebind(x=0, skip_notification=True)

  faulty_code = inspect.cleandoc("""
      Foo(x=2)
      """)
  lm_correction = inspect.cleandoc("""
      CorrectedCode(
          corrected_code='Foo(x=-1)',
      )
      """)

  result = correction.run_with_correction(
      faulty_code,
      global_vars=dict(Foo=Foo),
      lm=fake.StaticSequence([lm_correction]),
  )
  # The corrected Foo(x=-1) passes validation after being clamped to x=0.
  self.assertEqual(result, Foo(0))
74
+
48
75
  def test_run_without_correction(self):
49
76
  result = correction.run_with_correction(
50
77
  inspect.cleandoc("""
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.1.2.dev202409020804
3
+ Version: 0.1.2.dev202409080803
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -31,8 +31,8 @@ langfun/core/text_formatting.py,sha256=d7t9vaY6aCn1dkfkikpNYnBy5E_i93vHbfyDWFclG
31
31
  langfun/core/text_formatting_test.py,sha256=ck0Xzdd4YF4CtCUj7VE0GybfbAyKQ8p3xkM1FBGrqIk,2096
32
32
  langfun/core/coding/__init__.py,sha256=5utju_fwEsImaiftx4oXKl9FAM8p281k8-Esdh_-m1w,835
33
33
  langfun/core/coding/python/__init__.py,sha256=MJ-vubliz-ebrZH3OBRKBwMi0S9-FrhGCp8YQLR6_I4,1776
34
- langfun/core/coding/python/correction.py,sha256=a2aFUt9ocbXTCR6Z6OGNjQZDI1LfU0PBkSe7hJB8dEM,6589
35
- langfun/core/coding/python/correction_test.py,sha256=yLqmQ9BPORsnREkrS10PnljEaLR3BoydTVeT3OGoqfU,3507
34
+ langfun/core/coding/python/correction.py,sha256=WiBdoScL-6C___iA3Tg3vizuYtJWI-_4wy9zcMfVpj8,7020
35
+ langfun/core/coding/python/correction_test.py,sha256=qGxXuHaO32onF6cAoTfO1_sH_lM7-3dE9UqaaU8Myxs,4215
36
36
  langfun/core/coding/python/errors.py,sha256=fX3Du63uGm25YFXW9D-bV2gntTdTAX3hBFtAnRlmg14,3166
37
37
  langfun/core/coding/python/errors_test.py,sha256=_ZbWJCFIb-FkCK7K1zCuH8W3x_NFt-jNe3dfP8yqaD4,2323
38
38
  langfun/core/coding/python/execution.py,sha256=raZix62g2fwt6Lgykll2DFzkLlEjVqN9E73q0iaVdak,10185
@@ -119,8 +119,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
119
119
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
120
120
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
121
121
  langfun/core/templates/selfplay_test.py,sha256=rBW2Qr8yi-aWYwoTwRR-n1peKyMX9QXPZXURjLgoiRs,2264
122
- langfun-0.1.2.dev202409020804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
123
- langfun-0.1.2.dev202409020804.dist-info/METADATA,sha256=KpnJzDncl3sbag1lmUIqEc_WJvsT58Mvug2_BgJfNbQ,8825
124
- langfun-0.1.2.dev202409020804.dist-info/WHEEL,sha256=UvcQYKBHoFqaQd6LKyqHw9fxEolWLQnlzP0h_LgJAfI,91
125
- langfun-0.1.2.dev202409020804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
126
- langfun-0.1.2.dev202409020804.dist-info/RECORD,,
122
+ langfun-0.1.2.dev202409080803.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
123
+ langfun-0.1.2.dev202409080803.dist-info/METADATA,sha256=owbYfUp61s8ymVIPTduW2Y0BOBn9uw3JqGB_J1_5rmc,8825
124
+ langfun-0.1.2.dev202409080803.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
125
+ langfun-0.1.2.dev202409080803.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
126
+ langfun-0.1.2.dev202409080803.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (74.0.0)
2
+ Generator: setuptools (74.1.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5