langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +22 -2
- langfun/core/__init__.py +17 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -28
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +69 -2
- langfun/core/component_test.py +54 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +17 -0
- langfun/core/eval/base.py +767 -140
- langfun/core/eval/base_test.py +238 -53
- langfun/core/eval/matching.py +80 -76
- langfun/core/eval/matching_test.py +19 -9
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +37 -28
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +3 -21
- langfun/core/langfunc_test.py +26 -8
- langfun/core/language_model.py +686 -48
- langfun/core/language_model_test.py +681 -44
- langfun/core/llms/__init__.py +100 -12
- langfun/core/llms/anthropic.py +488 -0
- langfun/core/llms/anthropic_test.py +235 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +88 -28
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +39 -26
- langfun/core/llms/fake_test.py +136 -11
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -197
- langfun/core/llms/groq.py +276 -0
- langfun/core/llms/groq_test.py +64 -0
- langfun/core/llms/llama_cpp.py +15 -40
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +436 -226
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +35 -174
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -23
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +15 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +9 -8
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +278 -0
- langfun/core/structured/function_generation_test.py +399 -0
- langfun/core/structured/mapping.py +150 -46
- langfun/core/structured/mapping_test.py +105 -0
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +71 -22
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
- langfun/core/structured/schema.py +208 -99
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +133 -34
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +240 -11
- langfun/core/template_test.py +146 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +14 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -217
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,150 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Evaluation (v1) for Langfun agentic actions."""
|
15
|
+
|
16
|
+
import io
|
17
|
+
import os
|
18
|
+
from typing import Annotated, Any
|
19
|
+
|
20
|
+
import langfun.core as lf
|
21
|
+
from langfun.core import eval as lf_eval
|
22
|
+
from langfun.core.agentic import action as action_lib
|
23
|
+
import pyglove as pg
|
24
|
+
|
25
|
+
|
26
|
+
class ActionEval(lf.eval.v2.Evaluation):
  """Agent evaluation.

  Each example is expected to expose an `action` attribute; `process` calls
  that action inside a fresh `Session` and reports the session's final result.
  """

  action_args: Annotated[
      dict[str, Any],
      'Arguments to call the action.'
  ] = {}

  def process(self, example: pg.Dict) -> tuple[Any, dict[str, Any]]:
    """Runs the example's action in a new session.

    Args:
      example: The example to evaluate. Its `action` attribute is called with
        the new session and `self.action_args` as keyword arguments.

    Returns:
      A tuple of (`session.final_result`, a dict holding the session under the
      key `session`). The final result is whatever the action produced, so it
      is not restricted to `str`.
    """
    action = example.action
    session = action_lib.Session()
    # Run the action with the log level set to 'fatal'.
    with lf.logging.use_log_level('fatal'):
      action(session=session, **self.action_args)
    return session.final_result, dict(session=session)
|
40
|
+
|
41
|
+
|
42
|
+
#
|
43
|
+
# TODO(daiyip): Remove V1 once V2 is fully launched.
|
44
|
+
#
|
45
|
+
|
46
|
+
|
47
|
+
@pg.functor()
def _dummy_schema():
  """Returns `int` as a placeholder schema (see `ActionEvalV1.schema_fn`)."""
  return int
|
50
|
+
|
51
|
+
|
52
|
+
class ExampleView(pg.Object):
  """Record of a single example, rendered to HTML by `ActionEvalV1.audit`."""

  # Index of the example.
  id: int

  # The example that was evaluated.
  input: Any

  # Result carried by the output message; None when there is no message.
  output: Any

  # Error text for the example, if any.
  error: str | None = None
|
57
|
+
|
58
|
+
|
59
|
+
class ActionEvalV1(lf_eval.Matching):
  """Base class for action evaluations.

  The input function should return a list of pg.Dict, with `action` and
  `groundtruth` fields.
  """
  # We override the schema and prompt to dummy values since they are not used.
  schema_fn = _dummy_schema()
  prompt = '<unused>'

  def process(self, example: pg.Dict, **kwargs):
    """Runs the example's action and returns the session as a message.

    Args:
      example: Example whose `action` attribute is invoked.
      **kwargs: Extra keyword arguments forwarded to the action call.

    Returns:
      The value of `session.as_message()` after the action has run.
    """
    action = example.action
    session = action_lib.Session()
    action(session=session, lm=self.lm, **kwargs)
    return session.as_message()

  def answer(self, output: Any, example: pg.Dict) -> Any:
    """Returns `output` as is; `example` is not consulted."""
    return output

  def groundtruth(self, example: Any) -> Any:
    """Returns the `groundtruth` attribute of `example`."""
    return example.groundtruth

  def audit(
      self,
      example_idx: int,
      example: Any,
      message: lf.Message | None,
      error: Exception | None = None,
      dryrun: bool = False,
  ):
    """Audits the example and saves an HTML view of it under `self.dir`.

    The HTML file is only written when this is not a dry run and `self.dir` is
    set. Writing is submitted to the 'background_eval_io' executor.

    Args:
      example_idx: Index of the example; also used in the HTML file name.
      example: The example being audited.
      message: The output message, or None when no message is available.
      error: The error for this example, if any.
      dryrun: Whether this is a dry run.
    """
    super().audit(example_idx, example, message, error, dryrun)
    # Write each example to HTML.
    if not dryrun and self.dir:
      def _save_html():
        ExampleView(
            example_idx,
            example,
            None if message is None else message.result,
            # `ExampleView.error` is declared as `str | None`, so pass the
            # error's text rather than the exception object itself.
            None if error is None else str(error),
        ).to_html(
            collapse_level=None,
            enable_summary_tooltip=False,
        ).save(
            self._example_html_path(example_idx)
        )
      # Write HTML in a separate thread to avoid blocking the main thread.
      lf.concurrent.get_executor(
          'background_eval_io', max_workers=16
      ).submit(_save_html)

  def _example_html_path(self, example_idx: int) -> str:
    """Returns the path of the HTML view for the example at `example_idx`."""
    return os.path.join(self.dir, f'example_{example_idx}.html')

  def _write_example_links(
      self,
      s: io.StringIO,
      examples: Any,
      empty_message: str,
      link_target: str | None = None,
  ) -> None:
    """Writes links to per-example HTML views, followed by a preview iframe.

    Args:
      s: The stream to write HTML to.
      examples: Iterable of tuples whose first item is the example index.
      empty_message: Text written instead of the links when `examples` is
        empty.
      link_target: Optional value for the `target` attribute of each link.
    """
    example_ids = sorted(example_idx for example_idx, *_ in examples)
    if not example_ids:
      s.write(empty_message)
      return

    target_attr = f' target="{link_target}"' if link_target else ''
    urls = [self._example_html_path(example_idx) for example_idx in example_ids]
    for example_idx, url in zip(example_ids, urls):
      s.write(
          f'<a href="{url}" style="margin-right: 10px"{target_attr}>'
          f'{example_idx}</a> '
      )
    # The iframe initially shows the first (lowest-index) example. Note the
    # trailing space after the `name` attribute: attributes must be separated
    # by whitespace.
    s.write(
        '<iframe style="border:0;width:100%;height:100%" name="example_view" '
        f'src="{urls[0]}" title="Example View"></iframe>'
    )

  def _render_mismatches(self, s: io.StringIO) -> None:
    """Writes the section listing mismatched examples to `s`."""
    s.write('<h2> Mismatches (Incorrect) </h2>')
    self._write_example_links(
        s,
        self.mismatches,
        empty_message='No mismatches found.',
        link_target='example_view',
    )

  def _render_matches(self, s: io.StringIO) -> None:
    """Writes the section listing matched examples to `s`."""
    s.write('<h2> Matches (correct) </h2>')
    self._write_example_links(
        s,
        self.matches,
        empty_message='No matches found.',
    )
|
@@ -0,0 +1,109 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Tests for action evaluation."""
|
15
|
+
|
16
|
+
import os
|
17
|
+
import tempfile
|
18
|
+
import unittest
|
19
|
+
|
20
|
+
from langfun.core import eval as lf_eval
|
21
|
+
from langfun.core import llms as lf_llms
|
22
|
+
from langfun.core.agentic import action as action_lib
|
23
|
+
from langfun.core.agentic import action_eval
|
24
|
+
import pyglove as pg
|
25
|
+
|
26
|
+
|
27
|
+
class Foo(action_lib.Action):
  """Test action whose result is its own `x`."""

  x: int

  def call(self, session, **kwargs):
    """Returns `self.x`; the session and extra arguments are unused."""
    del kwargs
    del session
    return self.x
|
33
|
+
|
34
|
+
|
35
|
+
@pg.functor()
def foo_inputs():
  """Returns examples for `Foo(1)` and `Foo(2)`, both with groundtruth 1."""
  return [pg.Dict(action=Foo(x), groundtruth=1) for x in (1, 2)]
|
41
|
+
|
42
|
+
|
43
|
+
class ActionEvalTest(unittest.TestCase):
  """Tests for `action_eval.ActionEval`."""

  def test_basics(self):
    """One of the two inputs matches its groundtruth, the other does not."""

    class FooEval(action_eval.ActionEval):
      inputs = foo_inputs()
      metrics = [lf_eval.v2.metrics.Match()]
      action_args = dict(
          lm=lf_llms.Echo()
      )

    s = FooEval()
    # Use a directory private to this test run instead of a fixed path under
    # the system temp dir, so that earlier or concurrent runs cannot interfere.
    with tempfile.TemporaryDirectory() as tmp_dir:
      root_dir = os.path.join(tmp_dir, 'foo_eval')
      s.run(root_dir, plugins=[])
    self.assertEqual(s.metrics[0].matches, 0.5)
    self.assertEqual(s.metrics[0].mismatches, 0.5)
|
59
|
+
|
60
|
+
|
61
|
+
class ActionEvalV1Test(unittest.TestCase):
  """Tests for `action_eval.ActionEvalV1`."""

  def test_basics(self):
    """Checks the full result dict of a run over `foo_inputs`."""

    class FooEval(action_eval.ActionEvalV1):
      lm = lf_llms.Echo()
      inputs = foo_inputs()

    s = FooEval()
    result = s.run(summary=False)
    pg.print(result)

    expected_setup = dict(
        id=s.id,
        dir=None,
        model='Echo',
        prompt_template='<unused>',
        method='query',
        schema_fn='_dummy_schema()',
    )
    expected_cache_stats = dict(
        use_cache=True,
        num_queries=0,
        num_hits=0,
        num_updates=0,
    )
    expected_metrics = dict(
        total=2,
        failures=0,
        failure_rate=0.0,
        oop_failures=0,
        oop_failure_rate=0.0,
        non_oop_failures=0,
        non_oop_failure_rate=0.0,
        failure_breakdown={},
        num_matches=0,
        match_rate=0.0,
        num_mismatches=2,
        mismatch_rate=1.0,
    )
    self.assertEqual(
        result,
        dict(
            experiment_setup=expected_setup,
            cache_stats=expected_cache_stats,
            metrics=expected_metrics,
            usage=None,
        ),
    )
|
106
|
+
|
107
|
+
|
108
|
+
if __name__ == '__main__':
|
109
|
+
unittest.main()
|
@@ -0,0 +1,136 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Tests for base action."""
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import langfun.core as lf
|
19
|
+
from langfun.core.agentic import action as action_lib
|
20
|
+
from langfun.core.llms import fake
|
21
|
+
import langfun.core.structured as lf_structured
|
22
|
+
import pyglove as pg
|
23
|
+
|
24
|
+
|
25
|
+
class SessionTest(unittest.TestCase):
  """Tests for `action_lib.Session`."""

  def test_basics(self):
    """Runs a nested action (Foo calling Bar) and inspects the session."""
    testcase = self

    class Bar(action_lib.Action):

      def call(self, session, *, lm, **kwargs):
        testcase.assertIs(session.current_action.action, self)
        session.info('Begin Bar')
        session.query('bar', lm=lm)
        session.add_metadata(note='bar')
        return 2

    class Foo(action_lib.Action):
      x: int

      def call(self, session, *, lm, **kwargs):
        testcase.assertIs(session.current_action.action, self)
        with session.phase('prepare'):
          session.info('Begin Foo', x=1)
          session.query('foo', lm=lm)
        with session.track_queries():
          self.make_additional_query(lm)
        session.add_metadata(note='foo')
        return self.x + Bar()(session, lm=lm)

      def make_additional_query(self, lm):
        lf_structured.query('additional query', lm=lm)

    static_lm = fake.StaticResponse('lm response')
    foo = Foo(1)
    self.assertEqual(foo(lm=static_lm), 3)

    session = foo.session
    self.assertIsNotNone(session)
    self.assertIsInstance(session.root.action, action_lib.RootAction)
    self.assertIs(session.current_action, session.root)

    #
    # Inspecting the root invocation.
    #

    root = session.root
    root_execution = root.execution
    self.assertEqual(len(root_execution.items), 1)
    self.assertIs(root_execution.items[0].action, foo)

    self.assertTrue(root.execution.has_started)
    self.assertTrue(root.execution.has_stopped)
    self.assertGreater(root.execution.elapse, 0)
    self.assertEqual(root.result, 3)
    self.assertEqual(root.metadata, {'note': 'foo'})

    # The root space should have one action (foo), no queries, and no logs.
    self.assertEqual(len(list(root.actions)), 1)
    self.assertEqual(len(list(root.queries)), 0)
    self.assertEqual(len(list(root.logs)), 0)
    # 1 query from Bar and 2 from Foo.
    self.assertEqual(len(list(root.all_queries)), 3)
    # 1 log from Bar and 1 from Foo.
    self.assertEqual(len(list(root.all_logs)), 2)
    self.assertEqual(root.usage_summary.total.num_requests, 3)

    # Inspecting the top-level action (Foo)
    foo_call = root.execution.items[0]
    self.assertEqual(len(foo_call.execution.items), 3)

    # Prepare phase.
    prepare = foo_call.execution.items[0]
    self.assertIsInstance(prepare, action_lib.ExecutionTrace)
    self.assertEqual(len(prepare.items), 2)
    self.assertTrue(prepare.has_started)
    self.assertTrue(prepare.has_stopped)
    self.assertEqual(prepare.usage_summary.total.num_requests, 1)

    # Tracked queries.
    tracked_query = foo_call.execution.items[1]
    self.assertIsInstance(tracked_query, lf_structured.QueryInvocation)
    self.assertIs(tracked_query.lm, static_lm)

    # Invocation to Bar.
    bar_call = foo_call.execution.items[2]
    self.assertIsInstance(bar_call, action_lib.ActionInvocation)
    self.assertIsInstance(bar_call.action, Bar)
    self.assertEqual(bar_call.result, 2)
    self.assertEqual(bar_call.metadata, {'note': 'bar'})
    self.assertEqual(len(bar_call.execution.items), 2)

    # Save to HTML
    html = session.to_html()
    self.assertIn('result', html.content)

    # Save session to JSON
    serialized = session.to_json_str(save_ref_value=True)
    restored = pg.from_json_str(serialized)
    self.assertIsInstance(restored, action_lib.Session)

  def test_log(self):
    """Calls every logging method of a session once."""
    session = action_lib.Session()
    session.debug('hi', x=1, y=2)
    session.info('hi', x=1, y=2)
    session.warning('hi', x=1, y=2)
    session.error('hi', x=1, y=2)
    session.fatal('hi', x=1, y=2)

  def test_as_message(self):
    """`as_message` on a new session returns an `lf.AIMessage`."""
    session = action_lib.Session()
    message = session.as_message()
    self.assertIsInstance(message, lf.AIMessage)
|
133
|
+
|
134
|
+
|
135
|
+
if __name__ == '__main__':
|
136
|
+
unittest.main()
|
@@ -16,19 +16,13 @@
|
|
16
16
|
# pylint: disable=g-bad-import-order
|
17
17
|
# pylint: disable=g-importing-member
|
18
18
|
|
19
|
-
from
|
20
|
-
|
21
|
-
from langfun.core.coding.python.
|
22
|
-
from langfun.core.coding.python.permissions import permission
|
23
|
-
from langfun.core.coding.python.permissions import get_permission
|
24
|
-
|
25
|
-
from langfun.core.coding.python.parsing import PythonCodeParser
|
26
|
-
|
19
|
+
# Expose from `lf.coding` as aliases for `pg.coding` for backward compatibility.
|
20
|
+
from langfun.core.coding.python.execution import CodeError
|
21
|
+
from langfun.core.coding.python.execution import CodePermission
|
27
22
|
from langfun.core.coding.python.execution import context
|
28
|
-
|
23
|
+
|
24
|
+
from langfun.core.coding.python.parsing import clean
|
29
25
|
from langfun.core.coding.python.execution import evaluate
|
30
|
-
from langfun.core.coding.python.execution import sandbox_call
|
31
|
-
from langfun.core.coding.python.execution import call
|
32
26
|
from langfun.core.coding.python.execution import run
|
33
27
|
|
34
28
|
from langfun.core.coding.python.generation import PythonCode
|
@@ -12,10 +12,8 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
"""Python code error correction."""
|
15
|
-
import re
|
16
15
|
from typing import Any
|
17
16
|
import langfun.core as lf
|
18
|
-
from langfun.core.coding.python import errors
|
19
17
|
from langfun.core.coding.python import execution
|
20
18
|
import pyglove as pg
|
21
19
|
|
@@ -31,11 +29,6 @@ class CorrectedCode(pg.Object):
|
|
31
29
|
corrected_code: str
|
32
30
|
|
33
31
|
|
34
|
-
def remove_docstrings(code):
|
35
|
-
pattern = re.compile(r"(def .+?:\s*?)('''|\"\"\")((.|\s)*?)(\2)", re.DOTALL)
|
36
|
-
return pattern.sub(r"\1", code)
|
37
|
-
|
38
|
-
|
39
32
|
def run_with_correction(
|
40
33
|
code: str,
|
41
34
|
error: str | None = None,
|
@@ -46,6 +39,7 @@ def run_with_correction(
|
|
46
39
|
sandbox: bool | None = None,
|
47
40
|
timeout: int | None = 5,
|
48
41
|
returns_code: bool = False,
|
42
|
+
returns_stdout: bool = False,
|
49
43
|
outputs_intermediate: bool = False,
|
50
44
|
) -> Any | tuple[Any, str]:
|
51
45
|
"""Correct code with a language model via self-play.
|
@@ -68,6 +62,7 @@ def run_with_correction(
|
|
68
62
|
timeout. Applicable only when sandbox is set to True.
|
69
63
|
returns_code: If True, the return value is a tuple of (result, final code).
|
70
64
|
Otherwise the return value is the result only.
|
65
|
+
returns_stdout: If True, the stdout (a str) will be returned.
|
71
66
|
outputs_intermediate: If True, intermediate output will be outputted as a
|
72
67
|
dict, with the last line's value accessible by key '__result__'. Otherwise
|
73
68
|
the value of the last line will be returned.
|
@@ -82,29 +77,33 @@ def run_with_correction(
|
|
82
77
|
# Delay import at runtime to avoid circular dependency.
|
83
78
|
# pylint: disable=g-import-not-at-top
|
84
79
|
# pytype: disable=import-error
|
85
|
-
from langfun.core.structured import
|
80
|
+
from langfun.core.structured import querying
|
86
81
|
# pytype: enable=import-error
|
87
82
|
# pylint: enable=g-import-not-at-top
|
88
83
|
|
89
|
-
code = remove_docstrings(code)
|
90
84
|
if max_attempts == 0:
|
91
|
-
result =
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
85
|
+
result = _maybe_custom_validate(
|
86
|
+
execution.run(
|
87
|
+
code,
|
88
|
+
global_vars=global_vars,
|
89
|
+
sandbox=sandbox,
|
90
|
+
timeout=timeout,
|
91
|
+
returns_stdout=returns_stdout,
|
92
|
+
outputs_intermediate=outputs_intermediate,
|
93
|
+
)
|
97
94
|
)
|
98
95
|
return (result, code) if returns_code else result
|
99
96
|
|
100
97
|
def result_and_error(code: str) -> tuple[Any, str | None]:
|
101
98
|
try:
|
102
|
-
result =
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
99
|
+
result = _maybe_custom_validate(
|
100
|
+
execution.run(
|
101
|
+
code,
|
102
|
+
global_vars=global_vars,
|
103
|
+
sandbox=sandbox,
|
104
|
+
timeout=timeout,
|
105
|
+
outputs_intermediate=outputs_intermediate,
|
106
|
+
)
|
108
107
|
)
|
109
108
|
return (result, None)
|
110
109
|
except Exception as e: # pylint: disable=broad-exception-caught
|
@@ -122,10 +121,10 @@ def run_with_correction(
|
|
122
121
|
# structure.
|
123
122
|
try:
|
124
123
|
# Disable autofix for code correction to avoid recursion.
|
125
|
-
correction =
|
124
|
+
correction = querying.query(
|
126
125
|
CodeWithError(code=code, error=error), CorrectedCode, lm=lm, autofix=0
|
127
126
|
)
|
128
|
-
except
|
127
|
+
except pg.coding.CodeError:
|
129
128
|
break
|
130
129
|
|
131
130
|
code = correction.corrected_code
|
@@ -133,7 +132,7 @@ def run_with_correction(
|
|
133
132
|
if error is None:
|
134
133
|
return (result, code) if returns_code else result
|
135
134
|
|
136
|
-
raise
|
135
|
+
raise pg.coding.CodeError(
|
137
136
|
code,
|
138
137
|
RuntimeError(
|
139
138
|
f"Cannot correct code after {num_attempts} attempts. "
|
@@ -191,9 +190,19 @@ def correct(
|
|
191
190
|
|
192
191
|
def _error_feedback_str(error: Exception) -> str:
|
193
192
|
"""Returns the error str for feedback."""
|
194
|
-
if isinstance(error,
|
195
|
-
return
|
196
|
-
error.format(include_complete_code=False)
|
197
|
-
)
|
193
|
+
if isinstance(error, pg.coding.CodeError):
|
194
|
+
return pg.decolor(error.format(include_complete_code=False))
|
198
195
|
else:
|
199
196
|
return f"Encountered {error.__class__.__name__}: {error}"
|
197
|
+
|
198
|
+
|
199
|
+
def _maybe_custom_validate(result: Any) -> Any:
  """Applies custom validation through the `__validate__` method.

  When `result` is a dict containing the key "__result__", the value stored
  under that key is the object that gets validated; otherwise `result` itself
  is. Objects without a `__validate__` attribute are left alone.

  Args:
    result: The value to validate.

  Returns:
    `result`, the same object that was passed in.

  Raises:
    Any exception raised by the object's `__validate__` method.
  """
  target = result
  if isinstance(result, dict) and "__result__" in result:
    target = result["__result__"]

  if hasattr(target, "__validate__"):
    target.__validate__()
  return result
|
@@ -17,8 +17,8 @@ import inspect
|
|
17
17
|
import unittest
|
18
18
|
|
19
19
|
from langfun.core.coding.python import correction
|
20
|
-
from langfun.core.coding.python import errors
|
21
20
|
from langfun.core.llms import fake
|
21
|
+
import pyglove as pg
|
22
22
|
|
23
23
|
|
24
24
|
class RunWithCorrectionTest(unittest.TestCase):
|
@@ -45,6 +45,32 @@ class RunWithCorrectionTest(unittest.TestCase):
|
|
45
45
|
)
|
46
46
|
self.assertEqual(result, 4)
|
47
47
|
|
48
|
+
def test_run_with_correction_upon_custom_validation(self):
  """A failing `__validate__` leads to a corrected result."""

  class Foo(pg.Object):
    x: int

    def __validate__(self):
      if self.x > 1:
        raise ValueError('value should be less or equal than 1.')
      if self.x < 0:
        self.rebind(x=0, skip_notification=True)

  initial_code = inspect.cleandoc("""
      Foo(x=2)
      """)
  lm_response = inspect.cleandoc("""
      CorrectedCode(
          corrected_code='Foo(x=-1)',
      )
      """)
  result = correction.run_with_correction(
      initial_code,
      global_vars=dict(Foo=Foo),
      lm=fake.StaticSequence([lm_response]),
  )
  self.assertEqual(result, Foo(0))
|
73
|
+
|
48
74
|
def test_run_without_correction(self):
|
49
75
|
result = correction.run_with_correction(
|
50
76
|
inspect.cleandoc("""
|
@@ -55,7 +81,7 @@ class RunWithCorrectionTest(unittest.TestCase):
|
|
55
81
|
max_attempts=0,
|
56
82
|
)
|
57
83
|
self.assertEqual(result, 4)
|
58
|
-
with self.assertRaises(
|
84
|
+
with self.assertRaises(pg.coding.CodeError):
|
59
85
|
correction.run_with_correction(
|
60
86
|
inspect.cleandoc("""
|
61
87
|
x = 1,
|
@@ -98,7 +124,7 @@ class CorrectTest(unittest.TestCase):
|
|
98
124
|
|
99
125
|
def test_correct_reaching_limit(self):
|
100
126
|
with self.assertRaisesRegex(
|
101
|
-
|
127
|
+
pg.coding.CodeError, 'Cannot correct code after 1 attempts'
|
102
128
|
):
|
103
129
|
correction.correct(
|
104
130
|
inspect.cleandoc("""
|