langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145) hide show
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,150 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Evaluation (v1) for Langfun agentic actions."""
15
+
16
+ import io
17
+ import os
18
+ from typing import Annotated, Any
19
+
20
+ import langfun.core as lf
21
+ from langfun.core import eval as lf_eval
22
+ from langfun.core.agentic import action as action_lib
23
+ import pyglove as pg
24
+
25
+
26
+ class ActionEval(lf.eval.v2.Evaluation):
27
+ """Agent evaluation."""
28
+
29
+ action_args: Annotated[
30
+ dict[str, Any],
31
+ 'Arguments to call the action.'
32
+ ] = {}
33
+
34
+ def process(self, example: pg.Dict) -> tuple[str, dict[str, Any]]:
35
+ action = example.action
36
+ session = action_lib.Session()
37
+ with lf.logging.use_log_level('fatal'):
38
+ action(session=session, **self.action_args)
39
+ return session.final_result, dict(session=session)
40
+
41
+
42
+ #
43
+ # TODO(daiyip): Remove V1 once V2 is fully launched.
44
+ #
45
+
46
+
47
+ @pg.functor()
48
+ def _dummy_schema():
49
+ return int
50
+
51
+
52
+ class ExampleView(pg.Object):
53
+ id: int
54
+ input: Any
55
+ output: Any
56
+ error: str | None = None
57
+
58
+
59
+ class ActionEvalV1(lf_eval.Matching):
60
+ """Base class for action evaluations.
61
+
62
+ The input function should return a list of pg.Dict, with `action` and
63
+ `groundtruth` fields.
64
+ """
65
+ # We override the schema and prompt to dummy values since they are not used.
66
+ schema_fn = _dummy_schema()
67
+ prompt = '<unused>'
68
+
69
+ def process(self, example: pg.Dict, **kwargs):
70
+ action = example.action
71
+ session = action_lib.Session()
72
+ action(session=session, lm=self.lm, **kwargs)
73
+ return session.as_message()
74
+
75
+ def answer(self, output: Any, example: pg.Dict) -> Any:
76
+ return output
77
+
78
+ def groundtruth(self, example: Any) -> Any:
79
+ return example.groundtruth
80
+
81
+ def audit(
82
+ self,
83
+ example_idx: int,
84
+ example: Any,
85
+ message: lf.Message | None,
86
+ error: Exception | None = None,
87
+ dryrun: bool = False,
88
+ ):
89
+ super().audit(example_idx, example, message, error, dryrun)
90
+ # Write each example to HTML.
91
+ if not dryrun and self.dir:
92
+ def _save_html():
93
+ ExampleView(
94
+ example_idx,
95
+ example,
96
+ None if message is None else message.result,
97
+ error
98
+ ).to_html(
99
+ collapse_level=None,
100
+ enable_summary_tooltip=False,
101
+ ).save(
102
+ os.path.join(self.dir, f'example_{example_idx}.html')
103
+ )
104
+ # Write HTML in a separate thread to avoid blocking the main thread.
105
+ lf.concurrent.get_executor(
106
+ 'background_eval_io', max_workers=16
107
+ ).submit(_save_html)
108
+
109
+ def _render_mismatches(self, s: io.StringIO) -> None:
110
+ s.write('<h2> Mismatches (Incorrect) </h2>')
111
+ first_url = None
112
+ mismatched_ids = sorted([
113
+ example_idx for example_idx, *_ in self.mismatches
114
+ ])
115
+ for example_idx in mismatched_ids:
116
+ url = os.path.join(self.dir, f'example_{example_idx}.html')
117
+ if first_url is None:
118
+ first_url = url
119
+ s.write(
120
+ f'<a href="{url}" style="margin-right: 10px" target="example_view">'
121
+ f'{example_idx}</a> '
122
+ )
123
+ if first_url:
124
+ s.write(
125
+ '<iframe style="border:0;width:100%;height:100%" name="example_view"'
126
+ f'src="{first_url}" title="Example View"></iframe>'
127
+ )
128
+ else:
129
+ s.write('No mismatches found.')
130
+
131
+ def _render_matches(self, s: io.StringIO) -> None:
132
+ s.write('<h2> Matches (correct) </h2>')
133
+ first_url = None
134
+ matched_ids = sorted([
135
+ example_idx for example_idx, *_ in self.matches
136
+ ])
137
+ for example_idx in matched_ids:
138
+ url = os.path.join(self.dir, f'example_{example_idx}.html')
139
+ if first_url is None:
140
+ first_url = url
141
+ s.write(
142
+ f'<a href="{url}" style="margin-right: 10px">{example_idx}</a> '
143
+ )
144
+ if first_url:
145
+ s.write(
146
+ '<iframe style="border:0;width:100%;height:100%" name="example_view"'
147
+ f'src="{first_url}" title="Example View"></iframe>'
148
+ )
149
+ else:
150
+ s.write('No matches found.')
@@ -0,0 +1,109 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for action evaluation."""
15
+
16
+ import os
17
+ import tempfile
18
+ import unittest
19
+
20
+ from langfun.core import eval as lf_eval
21
+ from langfun.core import llms as lf_llms
22
+ from langfun.core.agentic import action as action_lib
23
+ from langfun.core.agentic import action_eval
24
+ import pyglove as pg
25
+
26
+
27
+ class Foo(action_lib.Action):
28
+ x: int
29
+
30
+ def call(self, session, **kwargs):
31
+ del session, kwargs
32
+ return self.x
33
+
34
+
35
+ @pg.functor()
36
+ def foo_inputs():
37
+ return [
38
+ pg.Dict(action=Foo(1), groundtruth=1),
39
+ pg.Dict(action=Foo(2), groundtruth=1),
40
+ ]
41
+
42
+
43
+ class ActionEvalTest(unittest.TestCase):
44
+
45
+ def test_basics(self):
46
+
47
+ class FooEval(action_eval.ActionEval):
48
+ inputs = foo_inputs()
49
+ metrics = [lf_eval.v2.metrics.Match()]
50
+ action_args = dict(
51
+ lm=lf_llms.Echo()
52
+ )
53
+
54
+ s = FooEval()
55
+ root_dir = os.path.join(tempfile.gettempdir(), 'foo_eval')
56
+ s.run(root_dir, plugins=[])
57
+ self.assertEqual(s.metrics[0].matches, 0.5)
58
+ self.assertEqual(s.metrics[0].mismatches, 0.5)
59
+
60
+
61
+ class ActionEvalV1Test(unittest.TestCase):
62
+
63
+ def test_basics(self):
64
+
65
+ class FooEval(action_eval.ActionEvalV1):
66
+ lm = lf_llms.Echo()
67
+ inputs = foo_inputs()
68
+
69
+ s = FooEval()
70
+ result = s.run(summary=False)
71
+ pg.print(result)
72
+ self.assertEqual(
73
+ result,
74
+ dict(
75
+ experiment_setup=dict(
76
+ id=s.id,
77
+ dir=None,
78
+ model='Echo',
79
+ prompt_template='<unused>',
80
+ method='query',
81
+ schema_fn='_dummy_schema()'
82
+ ),
83
+ cache_stats=dict(
84
+ use_cache=True,
85
+ num_queries=0,
86
+ num_hits=0,
87
+ num_updates=0,
88
+ ),
89
+ metrics=dict(
90
+ total=2,
91
+ failures=0,
92
+ failure_rate=0.0,
93
+ oop_failures=0,
94
+ oop_failure_rate=0.0,
95
+ non_oop_failures=0,
96
+ non_oop_failure_rate=0.0,
97
+ failure_breakdown={},
98
+ num_matches=0,
99
+ match_rate=0.0,
100
+ num_mismatches=2,
101
+ mismatch_rate=1.0
102
+ ),
103
+ usage=None
104
+ )
105
+ )
106
+
107
+
108
+ if __name__ == '__main__':
109
+ unittest.main()
@@ -0,0 +1,136 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for base action."""
15
+
16
+ import unittest
17
+
18
+ import langfun.core as lf
19
+ from langfun.core.agentic import action as action_lib
20
+ from langfun.core.llms import fake
21
+ import langfun.core.structured as lf_structured
22
+ import pyglove as pg
23
+
24
+
25
+ class SessionTest(unittest.TestCase):
26
+
27
+ def test_basics(self):
28
+ test = self
29
+
30
+ class Bar(action_lib.Action):
31
+
32
+ def call(self, session, *, lm, **kwargs):
33
+ test.assertIs(session.current_action.action, self)
34
+ session.info('Begin Bar')
35
+ session.query('bar', lm=lm)
36
+ session.add_metadata(note='bar')
37
+ return 2
38
+
39
+ class Foo(action_lib.Action):
40
+ x: int
41
+
42
+ def call(self, session, *, lm, **kwargs):
43
+ test.assertIs(session.current_action.action, self)
44
+ with session.phase('prepare'):
45
+ session.info('Begin Foo', x=1)
46
+ session.query('foo', lm=lm)
47
+ with session.track_queries():
48
+ self.make_additional_query(lm)
49
+ session.add_metadata(note='foo')
50
+ return self.x + Bar()(session, lm=lm)
51
+
52
+ def make_additional_query(self, lm):
53
+ lf_structured.query('additional query', lm=lm)
54
+
55
+ lm = fake.StaticResponse('lm response')
56
+ foo = Foo(1)
57
+ self.assertEqual(foo(lm=lm), 3)
58
+
59
+ session = foo.session
60
+ self.assertIsNotNone(session)
61
+ self.assertIsInstance(session.root.action, action_lib.RootAction)
62
+ self.assertIs(session.current_action, session.root)
63
+
64
+ #
65
+ # Inspecting the root invocation.
66
+ #
67
+
68
+ root = session.root
69
+ self.assertEqual(len(root.execution.items), 1)
70
+ self.assertIs(root.execution.items[0].action, foo)
71
+
72
+ self.assertTrue(root.execution.has_started)
73
+ self.assertTrue(root.execution.has_stopped)
74
+ self.assertGreater(root.execution.elapse, 0)
75
+ self.assertEqual(root.result, 3)
76
+ self.assertEqual(root.metadata, dict(note='foo'))
77
+
78
+ # The root space should have one action (foo), no queries, and no logs.
79
+ self.assertEqual(len(list(root.actions)), 1)
80
+ self.assertEqual(len(list(root.queries)), 0)
81
+ self.assertEqual(len(list(root.logs)), 0)
82
+ # 1 query from Bar and 2 from Foo.
83
+ self.assertEqual(len(list(root.all_queries)), 3)
84
+ # 1 log from Bar and 1 from Foo.
85
+ self.assertEqual(len(list(root.all_logs)), 2)
86
+ self.assertEqual(root.usage_summary.total.num_requests, 3)
87
+
88
+ # Inspecting the top-level action (Foo)
89
+ foo_invocation = root.execution.items[0]
90
+ self.assertEqual(len(foo_invocation.execution.items), 3)
91
+
92
+ # Prepare phase.
93
+ prepare_phase = foo_invocation.execution.items[0]
94
+ self.assertIsInstance(
95
+ prepare_phase, action_lib.ExecutionTrace
96
+ )
97
+ self.assertEqual(len(prepare_phase.items), 2)
98
+ self.assertTrue(prepare_phase.has_started)
99
+ self.assertTrue(prepare_phase.has_stopped)
100
+ self.assertEqual(prepare_phase.usage_summary.total.num_requests, 1)
101
+
102
+ # Tracked queries.
103
+ query_invocation = foo_invocation.execution.items[1]
104
+ self.assertIsInstance(query_invocation, lf_structured.QueryInvocation)
105
+ self.assertIs(query_invocation.lm, lm)
106
+
107
+ # Invocation to Bar.
108
+ bar_invocation = foo_invocation.execution.items[2]
109
+ self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
110
+ self.assertIsInstance(bar_invocation.action, Bar)
111
+ self.assertEqual(bar_invocation.result, 2)
112
+ self.assertEqual(bar_invocation.metadata, dict(note='bar'))
113
+ self.assertEqual(len(bar_invocation.execution.items), 2)
114
+
115
+ # Save to HTML
116
+ self.assertIn('result', session.to_html().content)
117
+
118
+ # Save session to JSON
119
+ json_str = session.to_json_str(save_ref_value=True)
120
+ self.assertIsInstance(pg.from_json_str(json_str), action_lib.Session)
121
+
122
+ def test_log(self):
123
+ session = action_lib.Session()
124
+ session.debug('hi', x=1, y=2)
125
+ session.info('hi', x=1, y=2)
126
+ session.warning('hi', x=1, y=2)
127
+ session.error('hi', x=1, y=2)
128
+ session.fatal('hi', x=1, y=2)
129
+
130
+ def test_as_message(self):
131
+ session = action_lib.Session()
132
+ self.assertIsInstance(session.as_message(), lf.AIMessage)
133
+
134
+
135
+ if __name__ == '__main__':
136
+ unittest.main()
@@ -16,19 +16,13 @@
16
16
  # pylint: disable=g-bad-import-order
17
17
  # pylint: disable=g-importing-member
18
18
 
19
- from langfun.core.coding.python.errors import CodeError
20
-
21
- from langfun.core.coding.python.permissions import CodePermission
22
- from langfun.core.coding.python.permissions import permission
23
- from langfun.core.coding.python.permissions import get_permission
24
-
25
- from langfun.core.coding.python.parsing import PythonCodeParser
26
-
19
+ # Expose from `lf.coding` as aliases for `pg.coding` for backward compatibility.
20
+ from langfun.core.coding.python.execution import CodeError
21
+ from langfun.core.coding.python.execution import CodePermission
27
22
  from langfun.core.coding.python.execution import context
28
- from langfun.core.coding.python.execution import get_context
23
+
24
+ from langfun.core.coding.python.parsing import clean
29
25
  from langfun.core.coding.python.execution import evaluate
30
- from langfun.core.coding.python.execution import sandbox_call
31
- from langfun.core.coding.python.execution import call
32
26
  from langfun.core.coding.python.execution import run
33
27
 
34
28
  from langfun.core.coding.python.generation import PythonCode
@@ -12,10 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  """Python code error correction."""
15
- import re
16
15
  from typing import Any
17
16
  import langfun.core as lf
18
- from langfun.core.coding.python import errors
19
17
  from langfun.core.coding.python import execution
20
18
  import pyglove as pg
21
19
 
@@ -31,11 +29,6 @@ class CorrectedCode(pg.Object):
31
29
  corrected_code: str
32
30
 
33
31
 
34
- def remove_docstrings(code):
35
- pattern = re.compile(r"(def .+?:\s*?)('''|\"\"\")((.|\s)*?)(\2)", re.DOTALL)
36
- return pattern.sub(r"\1", code)
37
-
38
-
39
32
  def run_with_correction(
40
33
  code: str,
41
34
  error: str | None = None,
@@ -46,6 +39,7 @@ def run_with_correction(
46
39
  sandbox: bool | None = None,
47
40
  timeout: int | None = 5,
48
41
  returns_code: bool = False,
42
+ returns_stdout: bool = False,
49
43
  outputs_intermediate: bool = False,
50
44
  ) -> Any | tuple[Any, str]:
51
45
  """Correct code with a language model via self-play.
@@ -68,6 +62,7 @@ def run_with_correction(
68
62
  timeout. Applicable only when sandbox is set to True.
69
63
  returns_code: If True, the return value is a tuple of (result, final code).
70
64
  Otherwise the return value is the result only.
65
+ returns_stdout: If True, the stdout (a str) will be returned.
71
66
  outputs_intermediate: If True, intermediate output will be outputted as a
72
67
  dict, with the last line's value accessible by key '__result__'. Otherwise
73
68
  the value of the last line will be returned.
@@ -82,29 +77,33 @@ def run_with_correction(
82
77
  # Delay import at runtime to avoid circular dependency.
83
78
  # pylint: disable=g-import-not-at-top
84
79
  # pytype: disable=import-error
85
- from langfun.core.structured import prompting
80
+ from langfun.core.structured import querying
86
81
  # pytype: enable=import-error
87
82
  # pylint: enable=g-import-not-at-top
88
83
 
89
- code = remove_docstrings(code)
90
84
  if max_attempts == 0:
91
- result = execution.run(
92
- code,
93
- global_vars=global_vars,
94
- sandbox=sandbox,
95
- timeout=timeout,
96
- outputs_intermediate=outputs_intermediate,
85
+ result = _maybe_custom_validate(
86
+ execution.run(
87
+ code,
88
+ global_vars=global_vars,
89
+ sandbox=sandbox,
90
+ timeout=timeout,
91
+ returns_stdout=returns_stdout,
92
+ outputs_intermediate=outputs_intermediate,
93
+ )
97
94
  )
98
95
  return (result, code) if returns_code else result
99
96
 
100
97
  def result_and_error(code: str) -> tuple[Any, str | None]:
101
98
  try:
102
- result = execution.run(
103
- code,
104
- global_vars=global_vars,
105
- sandbox=sandbox,
106
- timeout=timeout,
107
- outputs_intermediate=outputs_intermediate,
99
+ result = _maybe_custom_validate(
100
+ execution.run(
101
+ code,
102
+ global_vars=global_vars,
103
+ sandbox=sandbox,
104
+ timeout=timeout,
105
+ outputs_intermediate=outputs_intermediate,
106
+ )
108
107
  )
109
108
  return (result, None)
110
109
  except Exception as e: # pylint: disable=broad-exception-caught
@@ -122,10 +121,10 @@ def run_with_correction(
122
121
  # structure.
123
122
  try:
124
123
  # Disable autofix for code correction to avoid recursion.
125
- correction = prompting.query(
124
+ correction = querying.query(
126
125
  CodeWithError(code=code, error=error), CorrectedCode, lm=lm, autofix=0
127
126
  )
128
- except errors.CodeError:
127
+ except pg.coding.CodeError:
129
128
  break
130
129
 
131
130
  code = correction.corrected_code
@@ -133,7 +132,7 @@ def run_with_correction(
133
132
  if error is None:
134
133
  return (result, code) if returns_code else result
135
134
 
136
- raise errors.CodeError(
135
+ raise pg.coding.CodeError(
137
136
  code,
138
137
  RuntimeError(
139
138
  f"Cannot correct code after {num_attempts} attempts. "
@@ -191,9 +190,19 @@ def correct(
191
190
 
192
191
  def _error_feedback_str(error: Exception) -> str:
193
192
  """Returns the error str for feedback."""
194
- if isinstance(error, errors.CodeError):
195
- return lf.text_formatting.decolored(
196
- error.format(include_complete_code=False)
197
- )
193
+ if isinstance(error, pg.coding.CodeError):
194
+ return pg.decolor(error.format(include_complete_code=False))
198
195
  else:
199
196
  return f"Encountered {error.__class__.__name__}: {error}"
197
+
198
+
199
+ def _maybe_custom_validate(result: Any) -> Any:
200
+ """Apply custom validation through the `__validate__` method."""
201
+ if isinstance(result, dict) and "__result__" in result:
202
+ r = result["__result__"]
203
+ else:
204
+ r = result
205
+
206
+ if hasattr(r, "__validate__"):
207
+ r.__validate__()
208
+ return result
@@ -17,8 +17,8 @@ import inspect
17
17
  import unittest
18
18
 
19
19
  from langfun.core.coding.python import correction
20
- from langfun.core.coding.python import errors
21
20
  from langfun.core.llms import fake
21
+ import pyglove as pg
22
22
 
23
23
 
24
24
  class RunWithCorrectionTest(unittest.TestCase):
@@ -45,6 +45,32 @@ class RunWithCorrectionTest(unittest.TestCase):
45
45
  )
46
46
  self.assertEqual(result, 4)
47
47
 
48
+ def test_run_with_correction_upon_custom_validation(self):
49
+
50
+ class Foo(pg.Object):
51
+ x: int
52
+
53
+ def __validate__(self):
54
+ if self.x > 1:
55
+ raise ValueError('value should be less or equal than 1.')
56
+ if self.x < 0:
57
+ self.rebind(x=0, skip_notification=True)
58
+
59
+ result = correction.run_with_correction(
60
+ inspect.cleandoc("""
61
+ Foo(x=2)
62
+ """),
63
+ global_vars=dict(Foo=Foo),
64
+ lm=fake.StaticSequence([
65
+ inspect.cleandoc("""
66
+ CorrectedCode(
67
+ corrected_code='Foo(x=-1)',
68
+ )
69
+ """),
70
+ ]),
71
+ )
72
+ self.assertEqual(result, Foo(0))
73
+
48
74
  def test_run_without_correction(self):
49
75
  result = correction.run_with_correction(
50
76
  inspect.cleandoc("""
@@ -55,7 +81,7 @@ class RunWithCorrectionTest(unittest.TestCase):
55
81
  max_attempts=0,
56
82
  )
57
83
  self.assertEqual(result, 4)
58
- with self.assertRaises(errors.CodeError):
84
+ with self.assertRaises(pg.coding.CodeError):
59
85
  correction.run_with_correction(
60
86
  inspect.cleandoc("""
61
87
  x = 1,
@@ -98,7 +124,7 @@ class CorrectTest(unittest.TestCase):
98
124
 
99
125
  def test_correct_reaching_limit(self):
100
126
  with self.assertRaisesRegex(
101
- errors.CodeError, 'Cannot correct code after 1 attempts'
127
+ pg.coding.CodeError, 'Cannot correct code after 1 attempts'
102
128
  ):
103
129
  correction.correct(
104
130
  inspect.cleandoc("""