langfun 0.1.2.dev202411110804__py3-none-any.whl → 0.1.2.dev202411150804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +4 -0
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +250 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +84 -0
- langfun/core/console.py +10 -2
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +2 -0
- langfun/core/eval/v2/__init__.py +38 -0
- langfun/core/eval/v2/checkpointing.py +135 -0
- langfun/core/eval/v2/checkpointing_test.py +89 -0
- langfun/core/eval/v2/evaluation.py +627 -0
- langfun/core/eval/v2/evaluation_test.py +156 -0
- langfun/core/eval/v2/example.py +295 -0
- langfun/core/eval/v2/example_test.py +114 -0
- langfun/core/eval/v2/experiment.py +949 -0
- langfun/core/eval/v2/experiment_test.py +304 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +209 -0
- langfun/core/eval/v2/progress_tracking_test.py +56 -0
- langfun/core/eval/v2/reporting.py +144 -0
- langfun/core/eval/v2/reporting_test.py +41 -0
- langfun/core/eval/v2/runners.py +417 -0
- langfun/core/eval/v2/runners_test.py +311 -0
- langfun/core/eval/v2/test_helper.py +80 -0
- langfun/core/language_model.py +122 -11
- langfun/core/language_model_test.py +97 -4
- langfun/core/llms/__init__.py +4 -0
- langfun/core/llms/anthropic.py +12 -0
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/vertexai.py +4 -4
- langfun/core/llms/vertexai_test.py +8 -2
- {langfun-0.1.2.dev202411110804.dist-info → langfun-0.1.2.dev202411150804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202411110804.dist-info → langfun-0.1.2.dev202411150804.dist-info}/RECORD +44 -15
- {langfun-0.1.2.dev202411110804.dist-info → langfun-0.1.2.dev202411150804.dist-info}/WHEEL +1 -1
- {langfun-0.1.2.dev202411110804.dist-info → langfun-0.1.2.dev202411150804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202411110804.dist-info → langfun-0.1.2.dev202411150804.dist-info}/top_level.txt +0 -0
langfun/__init__.py
CHANGED
@@ -53,6 +53,10 @@ PythonFunction = coding.PythonFunction
|
|
53
53
|
from langfun.core import llms
|
54
54
|
lm_cache = llms.cache.lm_cache
|
55
55
|
|
56
|
+
from langfun.core import agentic
|
57
|
+
Action = agentic.Action
|
58
|
+
Session = agentic.Session
|
59
|
+
|
56
60
|
from langfun.core import memories
|
57
61
|
|
58
62
|
from langfun.core import modalities
|
@@ -0,0 +1,30 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Langfun agentic framework."""
|
15
|
+
|
16
|
+
# pylint: disable=g-bad-import-order
|
17
|
+
# pylint: disable=g-importing-member
|
18
|
+
# pylint: disable=g-import-not-at-top
|
19
|
+
|
20
|
+
from langfun.core.agentic.action import Action
|
21
|
+
from langfun.core.agentic.action import ActionInvocation
|
22
|
+
from langfun.core.agentic.action import Session
|
23
|
+
|
24
|
+
from langfun.core.agentic.action_eval import ActionEval
|
25
|
+
from langfun.core.agentic.action_eval import ActionEvalV1
|
26
|
+
|
27
|
+
|
28
|
+
# pylint: enable=g-bad-import-order
|
29
|
+
# pylint: enable=g-importing-member
|
30
|
+
# pylint: enable=g-import-not-at-top
|
@@ -0,0 +1,250 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Base classes for agentic actions."""
|
15
|
+
|
16
|
+
import abc
|
17
|
+
from typing import Annotated, Any, Optional, Union
|
18
|
+
import langfun.core as lf
|
19
|
+
import pyglove as pg
|
20
|
+
|
21
|
+
|
22
|
+
class Action(pg.Object):
|
23
|
+
"""Base class for agent actions."""
|
24
|
+
|
25
|
+
def _on_bound(self):
|
26
|
+
super()._on_bound()
|
27
|
+
self._result = None
|
28
|
+
|
29
|
+
@property
|
30
|
+
def result(self) -> Any:
|
31
|
+
"""Returns the result of the action."""
|
32
|
+
return self._result
|
33
|
+
|
34
|
+
def __call__(
|
35
|
+
self, session: Optional['Session'] = None, **kwargs) -> Any:
|
36
|
+
"""Executes the action."""
|
37
|
+
session = session or Session()
|
38
|
+
try:
|
39
|
+
session.begin(self)
|
40
|
+
self._result = self.call(session=session, **kwargs)
|
41
|
+
return self._result
|
42
|
+
finally:
|
43
|
+
session.end(self)
|
44
|
+
|
45
|
+
@abc.abstractmethod
|
46
|
+
def call(self, session: 'Session', **kwargs) -> Any:
|
47
|
+
"""Subclasses to implement."""
|
48
|
+
|
49
|
+
|
50
|
+
class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
51
|
+
"""A class for capturing the invocation of an action."""
|
52
|
+
action: Action
|
53
|
+
result: Any = None
|
54
|
+
execution: Annotated[
|
55
|
+
list[Union['ActionInvocation', lf.logging.LogEntry]],
|
56
|
+
'Execution execution.'
|
57
|
+
] = []
|
58
|
+
|
59
|
+
# Allow symbolic assignment without `rebind`.
|
60
|
+
allow_symbolic_assignment = True
|
61
|
+
|
62
|
+
@property
|
63
|
+
def logs(self) -> list[lf.logging.LogEntry]:
|
64
|
+
"""Returns logs from execution sequence."""
|
65
|
+
return [v for v in self.execution if isinstance(v, lf.logging.LogEntry)]
|
66
|
+
|
67
|
+
@property
|
68
|
+
def child_invocations(self) -> list['ActionInvocation']:
|
69
|
+
"""Returns child action invocations."""
|
70
|
+
return [v for v in self.execution if isinstance(v, ActionInvocation)]
|
71
|
+
|
72
|
+
def _html_tree_view_summary(
|
73
|
+
self, *, view: pg.views.html.HtmlTreeView, **kwargs
|
74
|
+
):
|
75
|
+
if isinstance(self.action, RootAction):
|
76
|
+
return None
|
77
|
+
kwargs.pop('title')
|
78
|
+
return view.summary(
|
79
|
+
self,
|
80
|
+
title=view.render(
|
81
|
+
self.action, name='action', collapse_level=0,
|
82
|
+
css_classes='invocation-title',
|
83
|
+
),
|
84
|
+
**kwargs
|
85
|
+
)
|
86
|
+
|
87
|
+
def _html_tree_view_content(
|
88
|
+
self,
|
89
|
+
*,
|
90
|
+
root_path: pg.KeyPath | None = None,
|
91
|
+
collapse_level: int | None = None,
|
92
|
+
view: pg.views.html.HtmlTreeView,
|
93
|
+
**kwargs
|
94
|
+
):
|
95
|
+
prepare_phase = []
|
96
|
+
current_phase = prepare_phase
|
97
|
+
action_phases = []
|
98
|
+
for item in self.execution:
|
99
|
+
if isinstance(item, ActionInvocation):
|
100
|
+
current_phase = []
|
101
|
+
action_phases.append(current_phase)
|
102
|
+
current_phase.append(item)
|
103
|
+
|
104
|
+
def _render_phase(
|
105
|
+
phase: list[ActionInvocation | lf.logging.LogEntry]
|
106
|
+
) -> pg.Html.WritableTypes:
|
107
|
+
return pg.Html.element(
|
108
|
+
'div',
|
109
|
+
[
|
110
|
+
view.render(item) for item in phase
|
111
|
+
]
|
112
|
+
)
|
113
|
+
|
114
|
+
def _render_action_phases(
|
115
|
+
phases: list[list[ActionInvocation | lf.logging.LogEntry]]
|
116
|
+
) -> pg.Html.WritableTypes:
|
117
|
+
if len(phases) == 1:
|
118
|
+
return _render_phase(phases[0])
|
119
|
+
return pg.views.html.controls.TabControl(
|
120
|
+
[
|
121
|
+
pg.views.html.controls.Tab(
|
122
|
+
label=f'Step {i + 1}',
|
123
|
+
content=_render_phase(phase),
|
124
|
+
)
|
125
|
+
for i, phase in enumerate(phases)
|
126
|
+
],
|
127
|
+
)
|
128
|
+
|
129
|
+
result_name = 'final_result' if isinstance(
|
130
|
+
self.action, RootAction) else 'result'
|
131
|
+
return pg.Html.element(
|
132
|
+
'div',
|
133
|
+
[
|
134
|
+
view.render(
|
135
|
+
self.result,
|
136
|
+
name=result_name,
|
137
|
+
css_classes=[
|
138
|
+
f'invocation-{result_name}'.replace('_', '-')
|
139
|
+
]
|
140
|
+
),
|
141
|
+
_render_phase(prepare_phase) if prepare_phase else None,
|
142
|
+
_render_action_phases(action_phases)
|
143
|
+
]
|
144
|
+
)
|
145
|
+
|
146
|
+
@classmethod
|
147
|
+
def _html_tree_view_css_styles(cls) -> list[str]:
|
148
|
+
return super()._html_tree_view_css_styles() + [
|
149
|
+
"""
|
150
|
+
details.invocation-title {
|
151
|
+
display: inline-block;
|
152
|
+
background-color: #b1f0ff;
|
153
|
+
border: 1px solid white;
|
154
|
+
}
|
155
|
+
details.invocation-result {
|
156
|
+
border: 1px solid #eee;
|
157
|
+
}
|
158
|
+
details.invocation-final-result {
|
159
|
+
border: 1px solid #eee;
|
160
|
+
background-color: #fef78f;
|
161
|
+
}
|
162
|
+
"""
|
163
|
+
]
|
164
|
+
|
165
|
+
|
166
|
+
class RootAction(Action):
|
167
|
+
"""A placeholder action for the root of the action tree."""
|
168
|
+
|
169
|
+
def call(self, session: 'Session', **kwargs) -> Any:
|
170
|
+
raise NotImplementedError('Shall not be called.')
|
171
|
+
|
172
|
+
|
173
|
+
class Session(pg.Object):
|
174
|
+
"""Session for performing an agentic task."""
|
175
|
+
|
176
|
+
root_invocation: ActionInvocation = ActionInvocation(RootAction())
|
177
|
+
|
178
|
+
def _on_bound(self):
|
179
|
+
super()._on_bound()
|
180
|
+
self._invocation_stack = [self.root_invocation]
|
181
|
+
|
182
|
+
@property
|
183
|
+
def final_result(self) -> Any:
|
184
|
+
"""Returns the final result of the session."""
|
185
|
+
return self.root_invocation.result
|
186
|
+
|
187
|
+
@property
|
188
|
+
def current_invocation(self) -> ActionInvocation:
|
189
|
+
"""Returns the current invocation."""
|
190
|
+
assert self._invocation_stack
|
191
|
+
return self._invocation_stack[-1]
|
192
|
+
|
193
|
+
def begin(self, action: Action):
|
194
|
+
"""Signal the beginning of the execution of an action."""
|
195
|
+
new_invocation = ActionInvocation(pg.maybe_ref(action))
|
196
|
+
with pg.notify_on_change(False):
|
197
|
+
self.current_invocation.execution.append(new_invocation)
|
198
|
+
self._invocation_stack.append(new_invocation)
|
199
|
+
|
200
|
+
def end(self, action: Action):
|
201
|
+
"""Signal the end of the execution of an action."""
|
202
|
+
assert self._invocation_stack
|
203
|
+
invocation = self._invocation_stack.pop(-1)
|
204
|
+
invocation.rebind(
|
205
|
+
result=action.result, skip_notification=True, raise_on_no_change=False
|
206
|
+
)
|
207
|
+
assert invocation.action is action, (invocation.action, action)
|
208
|
+
assert self._invocation_stack, self._invocation_stack
|
209
|
+
|
210
|
+
if len(self._invocation_stack) == 1:
|
211
|
+
self.root_invocation.rebind(
|
212
|
+
result=invocation.result,
|
213
|
+
skip_notification=True,
|
214
|
+
raise_on_no_change=False
|
215
|
+
)
|
216
|
+
|
217
|
+
def _log(self, level: lf.logging.LogLevel, message: str, **kwargs):
|
218
|
+
with pg.notify_on_change(False):
|
219
|
+
self.current_invocation.execution.append(
|
220
|
+
lf.logging.log(
|
221
|
+
level, message, indent=len(self._invocation_stack) - 1, **kwargs
|
222
|
+
)
|
223
|
+
)
|
224
|
+
|
225
|
+
def debug(self, message: str, **kwargs):
|
226
|
+
"""Logs a debug message to the session."""
|
227
|
+
self._log('debug', message, **kwargs)
|
228
|
+
|
229
|
+
def info(self, message: str, **kwargs):
|
230
|
+
"""Logs an info message to the session."""
|
231
|
+
self._log('info', message, **kwargs)
|
232
|
+
|
233
|
+
def warning(self, message: str, **kwargs):
|
234
|
+
"""Logs a warning message to the session."""
|
235
|
+
self._log('warning', message, **kwargs)
|
236
|
+
|
237
|
+
def error(self, message: str, **kwargs):
|
238
|
+
"""Logs an error message to the session."""
|
239
|
+
self._log('error', message, **kwargs)
|
240
|
+
|
241
|
+
def fatal(self, message: str, **kwargs):
|
242
|
+
"""Logs a fatal message to the session."""
|
243
|
+
self._log('fatal', message, **kwargs)
|
244
|
+
|
245
|
+
def as_message(self) -> lf.AIMessage:
|
246
|
+
"""Returns the session as a message."""
|
247
|
+
return lf.AIMessage(
|
248
|
+
'Agentic task session.',
|
249
|
+
result=self.root_invocation
|
250
|
+
)
|
@@ -0,0 +1,150 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Evaluation (v1) for Langfun agentic actions."""
|
15
|
+
|
16
|
+
import io
|
17
|
+
import os
|
18
|
+
from typing import Annotated, Any
|
19
|
+
|
20
|
+
import langfun.core as lf
|
21
|
+
from langfun.core import eval as lf_eval
|
22
|
+
from langfun.core.agentic import action as action_lib
|
23
|
+
import pyglove as pg
|
24
|
+
|
25
|
+
|
26
|
+
class ActionEval(lf.eval.v2.Evaluation):
|
27
|
+
"""Agent evaluation."""
|
28
|
+
|
29
|
+
action_args: Annotated[
|
30
|
+
dict[str, Any],
|
31
|
+
'Arguments to call the action.'
|
32
|
+
] = {}
|
33
|
+
|
34
|
+
def process(self, example: pg.Dict) -> tuple[str, dict[str, Any]]:
|
35
|
+
action = example.action
|
36
|
+
session = action_lib.Session()
|
37
|
+
with lf.logging.use_log_level('fatal'):
|
38
|
+
action(session=session, **self.action_args)
|
39
|
+
return session.final_result, dict(session=session)
|
40
|
+
|
41
|
+
|
42
|
+
#
|
43
|
+
# TODO(daiyip): Remove V1 once V2 is fully launched.
|
44
|
+
#
|
45
|
+
|
46
|
+
|
47
|
+
@pg.functor()
|
48
|
+
def _dummy_schema():
|
49
|
+
return int
|
50
|
+
|
51
|
+
|
52
|
+
class ExampleView(pg.Object):
|
53
|
+
id: int
|
54
|
+
input: Any
|
55
|
+
output: Any
|
56
|
+
error: str | None = None
|
57
|
+
|
58
|
+
|
59
|
+
class ActionEvalV1(lf_eval.Matching):
|
60
|
+
"""Base class for action evaluations.
|
61
|
+
|
62
|
+
The input function should return a list of pg.Dict, with `action` and
|
63
|
+
`groundtruth` fields.
|
64
|
+
"""
|
65
|
+
# We override the schema and prompt to dummy values since they are not used.
|
66
|
+
schema_fn = _dummy_schema()
|
67
|
+
prompt = '<unused>'
|
68
|
+
|
69
|
+
def process(self, example: pg.Dict, **kwargs):
|
70
|
+
action = example.action
|
71
|
+
session = action_lib.Session()
|
72
|
+
action(session=session, lm=self.lm, **kwargs)
|
73
|
+
return session.as_message()
|
74
|
+
|
75
|
+
def answer(self, output: Any, example: pg.Dict) -> Any:
|
76
|
+
return output
|
77
|
+
|
78
|
+
def groundtruth(self, example: Any) -> Any:
|
79
|
+
return example.groundtruth
|
80
|
+
|
81
|
+
def audit(
|
82
|
+
self,
|
83
|
+
example_idx: int,
|
84
|
+
example: Any,
|
85
|
+
message: lf.Message | None,
|
86
|
+
error: Exception | None = None,
|
87
|
+
dryrun: bool = False,
|
88
|
+
):
|
89
|
+
super().audit(example_idx, example, message, error, dryrun)
|
90
|
+
# Write each example to HTML.
|
91
|
+
if not dryrun and self.dir:
|
92
|
+
def _save_html():
|
93
|
+
ExampleView(
|
94
|
+
example_idx,
|
95
|
+
example,
|
96
|
+
None if message is None else message.result,
|
97
|
+
error
|
98
|
+
).to_html(
|
99
|
+
collapse_level=None,
|
100
|
+
enable_summary_tooltip=False,
|
101
|
+
).save(
|
102
|
+
os.path.join(self.dir, f'example_{example_idx}.html')
|
103
|
+
)
|
104
|
+
# Write HTML in a separate thread to avoid blocking the main thread.
|
105
|
+
lf.concurrent.get_executor(
|
106
|
+
'background_eval_io', max_workers=16
|
107
|
+
).submit(_save_html)
|
108
|
+
|
109
|
+
def _render_mismatches(self, s: io.StringIO) -> None:
|
110
|
+
s.write('<h2> Mismatches (Incorrect) </h2>')
|
111
|
+
first_url = None
|
112
|
+
mismatched_ids = sorted([
|
113
|
+
example_idx for example_idx, *_ in self.mismatches
|
114
|
+
])
|
115
|
+
for example_idx in mismatched_ids:
|
116
|
+
url = os.path.join(self.dir, f'example_{example_idx}.html')
|
117
|
+
if first_url is None:
|
118
|
+
first_url = url
|
119
|
+
s.write(
|
120
|
+
f'<a href="{url}" style="margin-right: 10px" target="example_view">'
|
121
|
+
f'{example_idx}</a> '
|
122
|
+
)
|
123
|
+
if first_url:
|
124
|
+
s.write(
|
125
|
+
'<iframe style="border:0;width:100%;height:100%" name="example_view"'
|
126
|
+
f'src="{first_url}" title="Example View"></iframe>'
|
127
|
+
)
|
128
|
+
else:
|
129
|
+
s.write('No mismatches found.')
|
130
|
+
|
131
|
+
def _render_matches(self, s: io.StringIO) -> None:
|
132
|
+
s.write('<h2> Matches (correct) </h2>')
|
133
|
+
first_url = None
|
134
|
+
matched_ids = sorted([
|
135
|
+
example_idx for example_idx, *_ in self.matches
|
136
|
+
])
|
137
|
+
for example_idx in matched_ids:
|
138
|
+
url = os.path.join(self.dir, f'example_{example_idx}.html')
|
139
|
+
if first_url is None:
|
140
|
+
first_url = url
|
141
|
+
s.write(
|
142
|
+
f'<a href="{url}" style="margin-right: 10px">{example_idx}</a> '
|
143
|
+
)
|
144
|
+
if first_url:
|
145
|
+
s.write(
|
146
|
+
'<iframe style="border:0;width:100%;height:100%" name="example_view"'
|
147
|
+
f'src="{first_url}" title="Example View"></iframe>'
|
148
|
+
)
|
149
|
+
else:
|
150
|
+
s.write('No matches found.')
|
@@ -0,0 +1,109 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Tests for action evaluation."""
|
15
|
+
|
16
|
+
import os
|
17
|
+
import tempfile
|
18
|
+
import unittest
|
19
|
+
|
20
|
+
from langfun.core import eval as lf_eval
|
21
|
+
from langfun.core import llms as lf_llms
|
22
|
+
from langfun.core.agentic import action as action_lib
|
23
|
+
from langfun.core.agentic import action_eval
|
24
|
+
import pyglove as pg
|
25
|
+
|
26
|
+
|
27
|
+
class Foo(action_lib.Action):
|
28
|
+
x: int
|
29
|
+
|
30
|
+
def call(self, session, **kwargs):
|
31
|
+
del session, kwargs
|
32
|
+
return self.x
|
33
|
+
|
34
|
+
|
35
|
+
@pg.functor()
|
36
|
+
def foo_inputs():
|
37
|
+
return [
|
38
|
+
pg.Dict(action=Foo(1), groundtruth=1),
|
39
|
+
pg.Dict(action=Foo(2), groundtruth=1),
|
40
|
+
]
|
41
|
+
|
42
|
+
|
43
|
+
class ActionEvalTest(unittest.TestCase):
|
44
|
+
|
45
|
+
def test_basics(self):
|
46
|
+
|
47
|
+
class FooEval(action_eval.ActionEval):
|
48
|
+
inputs = foo_inputs()
|
49
|
+
metrics = [lf_eval.v2.metrics.Match()]
|
50
|
+
action_args = dict(
|
51
|
+
lm=lf_llms.Echo()
|
52
|
+
)
|
53
|
+
|
54
|
+
s = FooEval()
|
55
|
+
root_dir = os.path.join(tempfile.gettempdir(), 'foo_eval')
|
56
|
+
s.run(root_dir, plugins=[])
|
57
|
+
self.assertEqual(s.metrics[0].matches, 0.5)
|
58
|
+
self.assertEqual(s.metrics[0].mismatches, 0.5)
|
59
|
+
|
60
|
+
|
61
|
+
class ActionEvalV1Test(unittest.TestCase):
|
62
|
+
|
63
|
+
def test_basics(self):
|
64
|
+
|
65
|
+
class FooEval(action_eval.ActionEvalV1):
|
66
|
+
lm = lf_llms.Echo()
|
67
|
+
inputs = foo_inputs()
|
68
|
+
|
69
|
+
s = FooEval()
|
70
|
+
result = s.run(summary=False)
|
71
|
+
pg.print(result)
|
72
|
+
self.assertEqual(
|
73
|
+
result,
|
74
|
+
dict(
|
75
|
+
experiment_setup=dict(
|
76
|
+
id=s.id,
|
77
|
+
dir=None,
|
78
|
+
model='Echo',
|
79
|
+
prompt_template='<unused>',
|
80
|
+
method='query',
|
81
|
+
schema_fn='_dummy_schema()'
|
82
|
+
),
|
83
|
+
cache_stats=dict(
|
84
|
+
use_cache=True,
|
85
|
+
num_queries=0,
|
86
|
+
num_hits=0,
|
87
|
+
num_updates=0,
|
88
|
+
),
|
89
|
+
metrics=dict(
|
90
|
+
total=2,
|
91
|
+
failures=0,
|
92
|
+
failure_rate=0.0,
|
93
|
+
oop_failures=0,
|
94
|
+
oop_failure_rate=0.0,
|
95
|
+
non_oop_failures=0,
|
96
|
+
non_oop_failure_rate=0.0,
|
97
|
+
failure_breakdown={},
|
98
|
+
num_matches=0,
|
99
|
+
match_rate=0.0,
|
100
|
+
num_mismatches=2,
|
101
|
+
mismatch_rate=1.0
|
102
|
+
),
|
103
|
+
usage=None
|
104
|
+
)
|
105
|
+
)
|
106
|
+
|
107
|
+
|
108
|
+
if __name__ == '__main__':
|
109
|
+
unittest.main()
|
@@ -0,0 +1,84 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Tests for base action."""
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import langfun.core as lf
|
19
|
+
from langfun.core.agentic import action as action_lib
|
20
|
+
|
21
|
+
|
22
|
+
class SessionTest(unittest.TestCase):
|
23
|
+
|
24
|
+
def test_basics(self):
|
25
|
+
test = self
|
26
|
+
|
27
|
+
class Bar(action_lib.Action):
|
28
|
+
|
29
|
+
def call(self, session, **kwargs):
|
30
|
+
test.assertIs(session.current_invocation.action, self)
|
31
|
+
session.info('Begin Bar')
|
32
|
+
return 2
|
33
|
+
|
34
|
+
class Foo(action_lib.Action):
|
35
|
+
x: int
|
36
|
+
|
37
|
+
def call(self, session, **kwargs):
|
38
|
+
test.assertIs(session.current_invocation.action, self)
|
39
|
+
session.info('Begin Foo', x=1)
|
40
|
+
return self.x + Bar()(session)
|
41
|
+
|
42
|
+
session = action_lib.Session()
|
43
|
+
root = session.root_invocation
|
44
|
+
self.assertIsInstance(root.action, action_lib.RootAction)
|
45
|
+
self.assertIs(session.current_invocation, session.root_invocation)
|
46
|
+
self.assertEqual(Foo(1)(session), 3)
|
47
|
+
self.assertEqual(len(session.root_invocation.child_invocations), 1)
|
48
|
+
self.assertEqual(len(session.root_invocation.child_invocations[0].logs), 1)
|
49
|
+
self.assertEqual(
|
50
|
+
len(session.root_invocation.child_invocations[0].child_invocations),
|
51
|
+
1
|
52
|
+
)
|
53
|
+
self.assertEqual(
|
54
|
+
len(session.root_invocation
|
55
|
+
.child_invocations[0].child_invocations[0].logs),
|
56
|
+
1
|
57
|
+
)
|
58
|
+
self.assertEqual(
|
59
|
+
len(session.root_invocation
|
60
|
+
.child_invocations[0].child_invocations[0].child_invocations),
|
61
|
+
0
|
62
|
+
)
|
63
|
+
self.assertIs(session.current_invocation, session.root_invocation)
|
64
|
+
self.assertIs(session.final_result, 3)
|
65
|
+
self.assertIn(
|
66
|
+
'invocation-final-result',
|
67
|
+
session.to_html().content,
|
68
|
+
)
|
69
|
+
|
70
|
+
def test_log(self):
|
71
|
+
session = action_lib.Session()
|
72
|
+
session.debug('hi', x=1, y=2)
|
73
|
+
session.info('hi', x=1, y=2)
|
74
|
+
session.warning('hi', x=1, y=2)
|
75
|
+
session.error('hi', x=1, y=2)
|
76
|
+
session.fatal('hi', x=1, y=2)
|
77
|
+
|
78
|
+
def test_as_message(self):
|
79
|
+
session = action_lib.Session()
|
80
|
+
self.assertIsInstance(session.as_message(), lf.AIMessage)
|
81
|
+
|
82
|
+
|
83
|
+
if __name__ == '__main__':
|
84
|
+
unittest.main()
|
langfun/core/console.py
CHANGED
@@ -59,12 +59,20 @@ def under_notebook() -> bool:
|
|
59
59
|
return bool(_notebook)
|
60
60
|
|
61
61
|
|
62
|
-
def display(value: Any, clear: bool = False) ->
|
62
|
+
def display(value: Any, clear: bool = False) -> Any: # pylint: disable=redefined-outer-name
|
63
63
|
"""Displays object in current notebook cell."""
|
64
64
|
if _notebook is not None:
|
65
65
|
if clear:
|
66
66
|
_notebook.clear_output()
|
67
|
-
_notebook.display(value)
|
67
|
+
return _notebook.display(value)
|
68
|
+
return None
|
69
|
+
|
70
|
+
|
71
|
+
def run_script(javascript: str) -> Any:
|
72
|
+
"""Runs JavaScript in current notebook cell."""
|
73
|
+
if _notebook is not None:
|
74
|
+
return _notebook.display(_notebook.Javascript(javascript))
|
75
|
+
return
|
68
76
|
|
69
77
|
|
70
78
|
def clear() -> None:
|
langfun/core/console_test.py
CHANGED
@@ -18,6 +18,7 @@ import io
|
|
18
18
|
import unittest
|
19
19
|
|
20
20
|
from langfun.core import console
|
21
|
+
import pyglove as pg
|
21
22
|
|
22
23
|
|
23
24
|
class ConsoleTest(unittest.TestCase):
|
@@ -32,6 +33,22 @@ class ConsoleTest(unittest.TestCase):
|
|
32
33
|
|
33
34
|
def test_under_notebook(self):
|
34
35
|
self.assertFalse(console.under_notebook())
|
36
|
+
console._notebook = True
|
37
|
+
self.assertTrue(console.under_notebook())
|
38
|
+
console._notebook = None
|
39
|
+
|
40
|
+
def test_notebook_interaction(self):
|
41
|
+
console._notebook = pg.Dict(
|
42
|
+
display=lambda x: x, Javascript=lambda x: x, clear_output=lambda: None)
|
43
|
+
self.assertEqual(console.display('hi', clear=True), 'hi')
|
44
|
+
self.assertEqual(
|
45
|
+
console.run_script('console.log("hi")'),
|
46
|
+
'console.log("hi")'
|
47
|
+
)
|
48
|
+
console.clear()
|
49
|
+
console._notebook = None
|
50
|
+
self.assertIsNone(console.display('hi'))
|
51
|
+
self.assertIsNone(console.run_script('console.log("hi")'))
|
35
52
|
|
36
53
|
|
37
54
|
if __name__ == '__main__':
|
langfun/core/eval/__init__.py
CHANGED
@@ -16,6 +16,8 @@
|
|
16
16
|
# pylint: disable=g-importing-member
|
17
17
|
# pylint: disable=g-bad-import-order
|
18
18
|
|
19
|
+
from langfun.core.eval import v2
|
20
|
+
|
19
21
|
from langfun.core.eval.base import register
|
20
22
|
from langfun.core.eval.base import registered_names
|
21
23
|
from langfun.core.eval.base import get_evaluations
|