langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,150 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Evaluation (v1) for Langfun agentic actions."""
15
+
16
+ import io
17
+ import os
18
+ from typing import Annotated, Any
19
+
20
+ import langfun.core as lf
21
+ from langfun.core import eval as lf_eval
22
+ from langfun.core.agentic import action as action_lib
23
+ import pyglove as pg
24
+
25
+
26
+ class ActionEval(lf.eval.v2.Evaluation):
27
+ """Agent evaluation."""
28
+
29
+ action_args: Annotated[
30
+ dict[str, Any],
31
+ 'Arguments to call the action.'
32
+ ] = {}
33
+
34
+ def process(self, example: pg.Dict) -> tuple[str, dict[str, Any]]:
35
+ action = example.action
36
+ session = action_lib.Session()
37
+ with lf.logging.use_log_level('fatal'):
38
+ action(session=session, **self.action_args)
39
+ return session.final_result, dict(session=session)
40
+
41
+
42
+ #
43
+ # TODO(daiyip): Remove V1 once V2 is fully launched.
44
+ #
45
+
46
+
47
+ @pg.functor()
48
+ def _dummy_schema():
49
+ return int
50
+
51
+
52
+ class ExampleView(pg.Object):
53
+ id: int
54
+ input: Any
55
+ output: Any
56
+ error: str | None = None
57
+
58
+
59
+ class ActionEvalV1(lf_eval.Matching):
60
+ """Base class for action evaluations.
61
+
62
+ The input function should return a list of pg.Dict, with `action` and
63
+ `groundtruth` fields.
64
+ """
65
+ # We override the schema and prompt to dummy values since they are not used.
66
+ schema_fn = _dummy_schema()
67
+ prompt = '<unused>'
68
+
69
+ def process(self, example: pg.Dict, **kwargs):
70
+ action = example.action
71
+ session = action_lib.Session()
72
+ action(session=session, lm=self.lm, **kwargs)
73
+ return session.as_message()
74
+
75
+ def answer(self, output: Any, example: pg.Dict) -> Any:
76
+ return output
77
+
78
+ def groundtruth(self, example: Any) -> Any:
79
+ return example.groundtruth
80
+
81
+ def audit(
82
+ self,
83
+ example_idx: int,
84
+ example: Any,
85
+ message: lf.Message | None,
86
+ error: Exception | None = None,
87
+ dryrun: bool = False,
88
+ ):
89
+ super().audit(example_idx, example, message, error, dryrun)
90
+ # Write each example to HTML.
91
+ if not dryrun and self.dir:
92
+ def _save_html():
93
+ ExampleView(
94
+ example_idx,
95
+ example,
96
+ None if message is None else message.result,
97
+ error
98
+ ).to_html(
99
+ collapse_level=None,
100
+ enable_summary_tooltip=False,
101
+ ).save(
102
+ os.path.join(self.dir, f'example_{example_idx}.html')
103
+ )
104
+ # Write HTML in a separate thread to avoid blocking the main thread.
105
+ lf.concurrent.get_executor(
106
+ 'background_eval_io', max_workers=16
107
+ ).submit(_save_html)
108
+
109
+ def _render_mismatches(self, s: io.StringIO) -> None:
110
+ s.write('<h2> Mismatches (Incorrect) </h2>')
111
+ first_url = None
112
+ mismatched_ids = sorted([
113
+ example_idx for example_idx, *_ in self.mismatches
114
+ ])
115
+ for example_idx in mismatched_ids:
116
+ url = os.path.join(self.dir, f'example_{example_idx}.html')
117
+ if first_url is None:
118
+ first_url = url
119
+ s.write(
120
+ f'<a href="{url}" style="margin-right: 10px" target="example_view">'
121
+ f'{example_idx}</a> '
122
+ )
123
+ if first_url:
124
+ s.write(
125
+ '<iframe style="border:0;width:100%;height:100%" name="example_view"'
126
+ f'src="{first_url}" title="Example View"></iframe>'
127
+ )
128
+ else:
129
+ s.write('No mismatches found.')
130
+
131
+ def _render_matches(self, s: io.StringIO) -> None:
132
+ s.write('<h2> Matches (correct) </h2>')
133
+ first_url = None
134
+ matched_ids = sorted([
135
+ example_idx for example_idx, *_ in self.matches
136
+ ])
137
+ for example_idx in matched_ids:
138
+ url = os.path.join(self.dir, f'example_{example_idx}.html')
139
+ if first_url is None:
140
+ first_url = url
141
+ s.write(
142
+ f'<a href="{url}" style="margin-right: 10px">{example_idx}</a> '
143
+ )
144
+ if first_url:
145
+ s.write(
146
+ '<iframe style="border:0;width:100%;height:100%" name="example_view"'
147
+ f'src="{first_url}" title="Example View"></iframe>'
148
+ )
149
+ else:
150
+ s.write('No matches found.')
@@ -0,0 +1,109 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for action evaluation."""
15
+
16
+ import os
17
+ import tempfile
18
+ import unittest
19
+
20
+ from langfun.core import eval as lf_eval
21
+ from langfun.core import llms as lf_llms
22
+ from langfun.core.agentic import action as action_lib
23
+ from langfun.core.agentic import action_eval
24
+ import pyglove as pg
25
+
26
+
27
+ class Foo(action_lib.Action):
28
+ x: int
29
+
30
+ def call(self, session, **kwargs):
31
+ del session, kwargs
32
+ return self.x
33
+
34
+
35
+ @pg.functor()
36
+ def foo_inputs():
37
+ return [
38
+ pg.Dict(action=Foo(1), groundtruth=1),
39
+ pg.Dict(action=Foo(2), groundtruth=1),
40
+ ]
41
+
42
+
43
+ class ActionEvalTest(unittest.TestCase):
44
+
45
+ def test_basics(self):
46
+
47
+ class FooEval(action_eval.ActionEval):
48
+ inputs = foo_inputs()
49
+ metrics = [lf_eval.v2.metrics.Match()]
50
+ action_args = dict(
51
+ lm=lf_llms.Echo()
52
+ )
53
+
54
+ s = FooEval()
55
+ root_dir = os.path.join(tempfile.gettempdir(), 'foo_eval')
56
+ s.run(root_dir, plugins=[])
57
+ self.assertEqual(s.metrics[0].matches, 0.5)
58
+ self.assertEqual(s.metrics[0].mismatches, 0.5)
59
+
60
+
61
+ class ActionEvalV1Test(unittest.TestCase):
62
+
63
+ def test_basics(self):
64
+
65
+ class FooEval(action_eval.ActionEvalV1):
66
+ lm = lf_llms.Echo()
67
+ inputs = foo_inputs()
68
+
69
+ s = FooEval()
70
+ result = s.run(summary=False)
71
+ pg.print(result)
72
+ self.assertEqual(
73
+ result,
74
+ dict(
75
+ experiment_setup=dict(
76
+ id=s.id,
77
+ dir=None,
78
+ model='Echo',
79
+ prompt_template='<unused>',
80
+ method='query',
81
+ schema_fn='_dummy_schema()'
82
+ ),
83
+ cache_stats=dict(
84
+ use_cache=True,
85
+ num_queries=0,
86
+ num_hits=0,
87
+ num_updates=0,
88
+ ),
89
+ metrics=dict(
90
+ total=2,
91
+ failures=0,
92
+ failure_rate=0.0,
93
+ oop_failures=0,
94
+ oop_failure_rate=0.0,
95
+ non_oop_failures=0,
96
+ non_oop_failure_rate=0.0,
97
+ failure_breakdown={},
98
+ num_matches=0,
99
+ match_rate=0.0,
100
+ num_mismatches=2,
101
+ mismatch_rate=1.0
102
+ ),
103
+ usage=None
104
+ )
105
+ )
106
+
107
+
108
+ if __name__ == '__main__':
109
+ unittest.main()
@@ -0,0 +1,136 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for base action."""
15
+
16
+ import unittest
17
+
18
+ import langfun.core as lf
19
+ from langfun.core.agentic import action as action_lib
20
+ from langfun.core.llms import fake
21
+ import langfun.core.structured as lf_structured
22
+ import pyglove as pg
23
+
24
+
25
+ class SessionTest(unittest.TestCase):
26
+
27
+ def test_basics(self):
28
+ test = self
29
+
30
+ class Bar(action_lib.Action):
31
+
32
+ def call(self, session, *, lm, **kwargs):
33
+ test.assertIs(session.current_action.action, self)
34
+ session.info('Begin Bar')
35
+ session.query('bar', lm=lm)
36
+ session.add_metadata(note='bar')
37
+ return 2
38
+
39
+ class Foo(action_lib.Action):
40
+ x: int
41
+
42
+ def call(self, session, *, lm, **kwargs):
43
+ test.assertIs(session.current_action.action, self)
44
+ with session.phase('prepare'):
45
+ session.info('Begin Foo', x=1)
46
+ session.query('foo', lm=lm)
47
+ with session.track_queries():
48
+ self.make_additional_query(lm)
49
+ session.add_metadata(note='foo')
50
+ return self.x + Bar()(session, lm=lm)
51
+
52
+ def make_additional_query(self, lm):
53
+ lf_structured.query('additional query', lm=lm)
54
+
55
+ lm = fake.StaticResponse('lm response')
56
+ foo = Foo(1)
57
+ self.assertEqual(foo(lm=lm), 3)
58
+
59
+ session = foo.session
60
+ self.assertIsNotNone(session)
61
+ self.assertIsInstance(session.root.action, action_lib.RootAction)
62
+ self.assertIs(session.current_action, session.root)
63
+
64
+ #
65
+ # Inspecting the root invocation.
66
+ #
67
+
68
+ root = session.root
69
+ self.assertEqual(len(root.execution.items), 1)
70
+ self.assertIs(root.execution.items[0].action, foo)
71
+
72
+ self.assertTrue(root.execution.has_started)
73
+ self.assertTrue(root.execution.has_stopped)
74
+ self.assertGreater(root.execution.elapse, 0)
75
+ self.assertEqual(root.result, 3)
76
+ self.assertEqual(root.metadata, dict(note='foo'))
77
+
78
+ # The root space should have one action (foo), no queries, and no logs.
79
+ self.assertEqual(len(list(root.actions)), 1)
80
+ self.assertEqual(len(list(root.queries)), 0)
81
+ self.assertEqual(len(list(root.logs)), 0)
82
+ # 1 query from Bar and 2 from Foo.
83
+ self.assertEqual(len(list(root.all_queries)), 3)
84
+ # 1 log from Bar and 1 from Foo.
85
+ self.assertEqual(len(list(root.all_logs)), 2)
86
+ self.assertEqual(root.usage_summary.total.num_requests, 3)
87
+
88
+ # Inspecting the top-level action (Foo)
89
+ foo_invocation = root.execution.items[0]
90
+ self.assertEqual(len(foo_invocation.execution.items), 3)
91
+
92
+ # Prepare phase.
93
+ prepare_phase = foo_invocation.execution.items[0]
94
+ self.assertIsInstance(
95
+ prepare_phase, action_lib.ExecutionTrace
96
+ )
97
+ self.assertEqual(len(prepare_phase.items), 2)
98
+ self.assertTrue(prepare_phase.has_started)
99
+ self.assertTrue(prepare_phase.has_stopped)
100
+ self.assertEqual(prepare_phase.usage_summary.total.num_requests, 1)
101
+
102
+ # Tracked queries.
103
+ query_invocation = foo_invocation.execution.items[1]
104
+ self.assertIsInstance(query_invocation, lf_structured.QueryInvocation)
105
+ self.assertIs(query_invocation.lm, lm)
106
+
107
+ # Invocation to Bar.
108
+ bar_invocation = foo_invocation.execution.items[2]
109
+ self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
110
+ self.assertIsInstance(bar_invocation.action, Bar)
111
+ self.assertEqual(bar_invocation.result, 2)
112
+ self.assertEqual(bar_invocation.metadata, dict(note='bar'))
113
+ self.assertEqual(len(bar_invocation.execution.items), 2)
114
+
115
+ # Save to HTML
116
+ self.assertIn('result', session.to_html().content)
117
+
118
+ # Save session to JSON
119
+ json_str = session.to_json_str(save_ref_value=True)
120
+ self.assertIsInstance(pg.from_json_str(json_str), action_lib.Session)
121
+
122
+ def test_log(self):
123
+ session = action_lib.Session()
124
+ session.debug('hi', x=1, y=2)
125
+ session.info('hi', x=1, y=2)
126
+ session.warning('hi', x=1, y=2)
127
+ session.error('hi', x=1, y=2)
128
+ session.fatal('hi', x=1, y=2)
129
+
130
+ def test_as_message(self):
131
+ session = action_lib.Session()
132
+ self.assertIsInstance(session.as_message(), lf.AIMessage)
133
+
134
+
135
+ if __name__ == '__main__':
136
+ unittest.main()
@@ -16,19 +16,13 @@
16
16
  # pylint: disable=g-bad-import-order
17
17
  # pylint: disable=g-importing-member
18
18
 
19
- from langfun.core.coding.python.errors import CodeError
20
-
21
- from langfun.core.coding.python.permissions import CodePermission
22
- from langfun.core.coding.python.permissions import permission
23
- from langfun.core.coding.python.permissions import get_permission
24
-
25
- from langfun.core.coding.python.parsing import PythonCodeParser
26
-
19
+ # Expose from `lf.coding` as aliases for `pg.coding` for backward compatibility.
20
+ from langfun.core.coding.python.execution import CodeError
21
+ from langfun.core.coding.python.execution import CodePermission
27
22
  from langfun.core.coding.python.execution import context
28
- from langfun.core.coding.python.execution import get_context
23
+
24
+ from langfun.core.coding.python.parsing import clean
29
25
  from langfun.core.coding.python.execution import evaluate
30
- from langfun.core.coding.python.execution import sandbox_call
31
- from langfun.core.coding.python.execution import call
32
26
  from langfun.core.coding.python.execution import run
33
27
 
34
28
  from langfun.core.coding.python.generation import PythonCode
@@ -14,7 +14,6 @@
14
14
  """Python code error correction."""
15
15
  from typing import Any
16
16
  import langfun.core as lf
17
- from langfun.core.coding.python import errors
18
17
  from langfun.core.coding.python import execution
19
18
  import pyglove as pg
20
19
 
@@ -40,6 +39,7 @@ def run_with_correction(
40
39
  sandbox: bool | None = None,
41
40
  timeout: int | None = 5,
42
41
  returns_code: bool = False,
42
+ returns_stdout: bool = False,
43
43
  outputs_intermediate: bool = False,
44
44
  ) -> Any | tuple[Any, str]:
45
45
  """Correct code with a language model via self-play.
@@ -62,6 +62,7 @@ def run_with_correction(
62
62
  timeout. Applicable only when sandbox is set to True.
63
63
  returns_code: If True, the return value is a tuple of (result, final code).
64
64
  Otherwise the return value is the result only.
65
+ returns_stdout: If True, the stdout (a str) will be returned.
65
66
  outputs_intermediate: If True, intermediate output will be outputted as a
66
67
  dict, with the last line's value accessible by key '__result__'. Otherwise
67
68
  the value of the last line will be returned.
@@ -76,28 +77,33 @@ def run_with_correction(
76
77
  # Delay import at runtime to avoid circular dependency.
77
78
  # pylint: disable=g-import-not-at-top
78
79
  # pytype: disable=import-error
79
- from langfun.core.structured import prompting
80
+ from langfun.core.structured import querying
80
81
  # pytype: enable=import-error
81
82
  # pylint: enable=g-import-not-at-top
82
83
 
83
84
  if max_attempts == 0:
84
- result = execution.run(
85
- code,
86
- global_vars=global_vars,
87
- sandbox=sandbox,
88
- timeout=timeout,
89
- outputs_intermediate=outputs_intermediate,
85
+ result = _maybe_custom_validate(
86
+ execution.run(
87
+ code,
88
+ global_vars=global_vars,
89
+ sandbox=sandbox,
90
+ timeout=timeout,
91
+ returns_stdout=returns_stdout,
92
+ outputs_intermediate=outputs_intermediate,
93
+ )
90
94
  )
91
95
  return (result, code) if returns_code else result
92
96
 
93
97
  def result_and_error(code: str) -> tuple[Any, str | None]:
94
98
  try:
95
- result = execution.run(
96
- code,
97
- global_vars=global_vars,
98
- sandbox=sandbox,
99
- timeout=timeout,
100
- outputs_intermediate=outputs_intermediate,
99
+ result = _maybe_custom_validate(
100
+ execution.run(
101
+ code,
102
+ global_vars=global_vars,
103
+ sandbox=sandbox,
104
+ timeout=timeout,
105
+ outputs_intermediate=outputs_intermediate,
106
+ )
101
107
  )
102
108
  return (result, None)
103
109
  except Exception as e: # pylint: disable=broad-exception-caught
@@ -115,10 +121,10 @@ def run_with_correction(
115
121
  # structure.
116
122
  try:
117
123
  # Disable autofix for code correction to avoid recursion.
118
- correction = prompting.query(
124
+ correction = querying.query(
119
125
  CodeWithError(code=code, error=error), CorrectedCode, lm=lm, autofix=0
120
126
  )
121
- except errors.CodeError:
127
+ except pg.coding.CodeError:
122
128
  break
123
129
 
124
130
  code = correction.corrected_code
@@ -126,7 +132,7 @@ def run_with_correction(
126
132
  if error is None:
127
133
  return (result, code) if returns_code else result
128
134
 
129
- raise errors.CodeError(
135
+ raise pg.coding.CodeError(
130
136
  code,
131
137
  RuntimeError(
132
138
  f"Cannot correct code after {num_attempts} attempts. "
@@ -184,9 +190,19 @@ def correct(
184
190
 
185
191
  def _error_feedback_str(error: Exception) -> str:
186
192
  """Returns the error str for feedback."""
187
- if isinstance(error, errors.CodeError):
188
- return lf.text_formatting.decolored(
189
- error.format(include_complete_code=False)
190
- )
193
+ if isinstance(error, pg.coding.CodeError):
194
+ return pg.decolor(error.format(include_complete_code=False))
191
195
  else:
192
196
  return f"Encountered {error.__class__.__name__}: {error}"
197
+
198
+
199
+ def _maybe_custom_validate(result: Any) -> Any:
200
+ """Apply custom validation through the `__validate__` method."""
201
+ if isinstance(result, dict) and "__result__" in result:
202
+ r = result["__result__"]
203
+ else:
204
+ r = result
205
+
206
+ if hasattr(r, "__validate__"):
207
+ r.__validate__()
208
+ return result
@@ -17,8 +17,8 @@ import inspect
17
17
  import unittest
18
18
 
19
19
  from langfun.core.coding.python import correction
20
- from langfun.core.coding.python import errors
21
20
  from langfun.core.llms import fake
21
+ import pyglove as pg
22
22
 
23
23
 
24
24
  class RunWithCorrectionTest(unittest.TestCase):
@@ -45,6 +45,32 @@ class RunWithCorrectionTest(unittest.TestCase):
45
45
  )
46
46
  self.assertEqual(result, 4)
47
47
 
48
+ def test_run_with_correction_upon_custom_validation(self):
49
+
50
+ class Foo(pg.Object):
51
+ x: int
52
+
53
+ def __validate__(self):
54
+ if self.x > 1:
55
+ raise ValueError('value should be less or equal than 1.')
56
+ if self.x < 0:
57
+ self.rebind(x=0, skip_notification=True)
58
+
59
+ result = correction.run_with_correction(
60
+ inspect.cleandoc("""
61
+ Foo(x=2)
62
+ """),
63
+ global_vars=dict(Foo=Foo),
64
+ lm=fake.StaticSequence([
65
+ inspect.cleandoc("""
66
+ CorrectedCode(
67
+ corrected_code='Foo(x=-1)',
68
+ )
69
+ """),
70
+ ]),
71
+ )
72
+ self.assertEqual(result, Foo(0))
73
+
48
74
  def test_run_without_correction(self):
49
75
  result = correction.run_with_correction(
50
76
  inspect.cleandoc("""
@@ -55,7 +81,7 @@ class RunWithCorrectionTest(unittest.TestCase):
55
81
  max_attempts=0,
56
82
  )
57
83
  self.assertEqual(result, 4)
58
- with self.assertRaises(errors.CodeError):
84
+ with self.assertRaises(pg.coding.CodeError):
59
85
  correction.run_with_correction(
60
86
  inspect.cleandoc("""
61
87
  x = 1,
@@ -98,7 +124,7 @@ class CorrectTest(unittest.TestCase):
98
124
 
99
125
  def test_correct_reaching_limit(self):
100
126
  with self.assertRaisesRegex(
101
- errors.CodeError, 'Cannot correct code after 1 attempts'
127
+ pg.coding.CodeError, 'Cannot correct code after 1 attempts'
102
128
  ):
103
129
  correction.correct(
104
130
  inspect.cleandoc("""