langfun 0.1.2.dev202409070803__py3-none-any.whl → 0.1.2.dev202409080803__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/coding/python/correction.py +28 -12
- langfun/core/coding/python/correction_test.py +27 -0
- {langfun-0.1.2.dev202409070803.dist-info → langfun-0.1.2.dev202409080803.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202409070803.dist-info → langfun-0.1.2.dev202409080803.dist-info}/RECORD +7 -7
- {langfun-0.1.2.dev202409070803.dist-info → langfun-0.1.2.dev202409080803.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202409070803.dist-info → langfun-0.1.2.dev202409080803.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202409070803.dist-info → langfun-0.1.2.dev202409080803.dist-info}/top_level.txt +0 -0
@@ -81,23 +81,27 @@ def run_with_correction(
|
|
81
81
|
# pylint: enable=g-import-not-at-top
|
82
82
|
|
83
83
|
if max_attempts == 0:
|
84
|
-
result =
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
84
|
+
result = _maybe_custom_validate(
|
85
|
+
execution.run(
|
86
|
+
code,
|
87
|
+
global_vars=global_vars,
|
88
|
+
sandbox=sandbox,
|
89
|
+
timeout=timeout,
|
90
|
+
outputs_intermediate=outputs_intermediate,
|
91
|
+
)
|
90
92
|
)
|
91
93
|
return (result, code) if returns_code else result
|
92
94
|
|
93
95
|
def result_and_error(code: str) -> tuple[Any, str | None]:
|
94
96
|
try:
|
95
|
-
result =
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
97
|
+
result = _maybe_custom_validate(
|
98
|
+
execution.run(
|
99
|
+
code,
|
100
|
+
global_vars=global_vars,
|
101
|
+
sandbox=sandbox,
|
102
|
+
timeout=timeout,
|
103
|
+
outputs_intermediate=outputs_intermediate,
|
104
|
+
)
|
101
105
|
)
|
102
106
|
return (result, None)
|
103
107
|
except Exception as e: # pylint: disable=broad-exception-caught
|
@@ -190,3 +194,15 @@ def _error_feedback_str(error: Exception) -> str:
|
|
190
194
|
)
|
191
195
|
else:
|
192
196
|
return f"Encountered {error.__class__.__name__}: {error}"
|
197
|
+
|
198
|
+
|
199
|
+
def _maybe_custom_validate(result: Any) -> Any:
|
200
|
+
"""Apply custom validation through __validate_generation__ method."""
|
201
|
+
if isinstance(result, dict) and "__result__" in result:
|
202
|
+
r = result["__result__"]
|
203
|
+
else:
|
204
|
+
r = result
|
205
|
+
|
206
|
+
if hasattr(r, "__validate__"):
|
207
|
+
r.__validate__()
|
208
|
+
return result
|
@@ -19,6 +19,7 @@ import unittest
|
|
19
19
|
from langfun.core.coding.python import correction
|
20
20
|
from langfun.core.coding.python import errors
|
21
21
|
from langfun.core.llms import fake
|
22
|
+
import pyglove as pg
|
22
23
|
|
23
24
|
|
24
25
|
class RunWithCorrectionTest(unittest.TestCase):
|
@@ -45,6 +46,32 @@ class RunWithCorrectionTest(unittest.TestCase):
|
|
45
46
|
)
|
46
47
|
self.assertEqual(result, 4)
|
47
48
|
|
49
|
+
def test_run_with_correction_upon_custom_validation(self):
  """A failing `__validate__` hook triggers correction of the generated code.

  The initial code builds `Foo(x=2)`, which the hook rejects. The fake LM then
  proposes `Foo(x=-1)`, which the hook accepts after rebinding `x` to 0.
  """

  class Foo(pg.Object):
    x: int

    def __validate__(self):
      # The two conditions are mutually exclusive, so their order is free.
      if self.x < 0:
        self.rebind(x=0, skip_notification=True)
      elif self.x > 1:
        raise ValueError('value should be less or equal than 1.')

  initial_code = inspect.cleandoc("""
      Foo(x=2)
      """)
  lm_response = inspect.cleandoc("""
      CorrectedCode(
          corrected_code='Foo(x=-1)',
      )
      """)

  result = correction.run_with_correction(
      initial_code,
      global_vars={'Foo': Foo},
      lm=fake.StaticSequence([lm_response]),
  )
  self.assertEqual(result, Foo(x=0))
|
74
|
+
|
48
75
|
def test_run_without_correction(self):
|
49
76
|
result = correction.run_with_correction(
|
50
77
|
inspect.cleandoc("""
|
@@ -31,8 +31,8 @@ langfun/core/text_formatting.py,sha256=d7t9vaY6aCn1dkfkikpNYnBy5E_i93vHbfyDWFclG
|
|
31
31
|
langfun/core/text_formatting_test.py,sha256=ck0Xzdd4YF4CtCUj7VE0GybfbAyKQ8p3xkM1FBGrqIk,2096
|
32
32
|
langfun/core/coding/__init__.py,sha256=5utju_fwEsImaiftx4oXKl9FAM8p281k8-Esdh_-m1w,835
|
33
33
|
langfun/core/coding/python/__init__.py,sha256=MJ-vubliz-ebrZH3OBRKBwMi0S9-FrhGCp8YQLR6_I4,1776
|
34
|
-
langfun/core/coding/python/correction.py,sha256=
|
35
|
-
langfun/core/coding/python/correction_test.py,sha256=
|
34
|
+
langfun/core/coding/python/correction.py,sha256=WiBdoScL-6C___iA3Tg3vizuYtJWI-_4wy9zcMfVpj8,7020
|
35
|
+
langfun/core/coding/python/correction_test.py,sha256=qGxXuHaO32onF6cAoTfO1_sH_lM7-3dE9UqaaU8Myxs,4215
|
36
36
|
langfun/core/coding/python/errors.py,sha256=fX3Du63uGm25YFXW9D-bV2gntTdTAX3hBFtAnRlmg14,3166
|
37
37
|
langfun/core/coding/python/errors_test.py,sha256=_ZbWJCFIb-FkCK7K1zCuH8W3x_NFt-jNe3dfP8yqaD4,2323
|
38
38
|
langfun/core/coding/python/execution.py,sha256=raZix62g2fwt6Lgykll2DFzkLlEjVqN9E73q0iaVdak,10185
|
@@ -119,8 +119,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
119
119
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
120
120
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
121
121
|
langfun/core/templates/selfplay_test.py,sha256=rBW2Qr8yi-aWYwoTwRR-n1peKyMX9QXPZXURjLgoiRs,2264
|
122
|
-
langfun-0.1.2.
|
123
|
-
langfun-0.1.2.
|
124
|
-
langfun-0.1.2.
|
125
|
-
langfun-0.1.2.
|
126
|
-
langfun-0.1.2.
|
122
|
+
langfun-0.1.2.dev202409080803.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
123
|
+
langfun-0.1.2.dev202409080803.dist-info/METADATA,sha256=owbYfUp61s8ymVIPTduW2Y0BOBn9uw3JqGB_J1_5rmc,8825
|
124
|
+
langfun-0.1.2.dev202409080803.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
125
|
+
langfun-0.1.2.dev202409080803.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
126
|
+
langfun-0.1.2.dev202409080803.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{langfun-0.1.2.dev202409070803.dist-info → langfun-0.1.2.dev202409080803.dist-info}/top_level.txt
RENAMED
File without changes
|