langfun 0.1.2.dev202411110804__py3-none-any.whl → 0.1.2.dev202411150804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44):
  1. langfun/__init__.py +4 -0
  2. langfun/core/agentic/__init__.py +30 -0
  3. langfun/core/agentic/action.py +250 -0
  4. langfun/core/agentic/action_eval.py +150 -0
  5. langfun/core/agentic/action_eval_test.py +109 -0
  6. langfun/core/agentic/action_test.py +84 -0
  7. langfun/core/console.py +10 -2
  8. langfun/core/console_test.py +17 -0
  9. langfun/core/eval/__init__.py +2 -0
  10. langfun/core/eval/v2/__init__.py +38 -0
  11. langfun/core/eval/v2/checkpointing.py +135 -0
  12. langfun/core/eval/v2/checkpointing_test.py +89 -0
  13. langfun/core/eval/v2/evaluation.py +627 -0
  14. langfun/core/eval/v2/evaluation_test.py +156 -0
  15. langfun/core/eval/v2/example.py +295 -0
  16. langfun/core/eval/v2/example_test.py +114 -0
  17. langfun/core/eval/v2/experiment.py +949 -0
  18. langfun/core/eval/v2/experiment_test.py +304 -0
  19. langfun/core/eval/v2/metric_values.py +156 -0
  20. langfun/core/eval/v2/metric_values_test.py +80 -0
  21. langfun/core/eval/v2/metrics.py +357 -0
  22. langfun/core/eval/v2/metrics_test.py +203 -0
  23. langfun/core/eval/v2/progress.py +348 -0
  24. langfun/core/eval/v2/progress_test.py +82 -0
  25. langfun/core/eval/v2/progress_tracking.py +209 -0
  26. langfun/core/eval/v2/progress_tracking_test.py +56 -0
  27. langfun/core/eval/v2/reporting.py +144 -0
  28. langfun/core/eval/v2/reporting_test.py +41 -0
  29. langfun/core/eval/v2/runners.py +417 -0
  30. langfun/core/eval/v2/runners_test.py +311 -0
  31. langfun/core/eval/v2/test_helper.py +80 -0
  32. langfun/core/language_model.py +122 -11
  33. langfun/core/language_model_test.py +97 -4
  34. langfun/core/llms/__init__.py +4 -0
  35. langfun/core/llms/anthropic.py +12 -0
  36. langfun/core/llms/compositional.py +101 -0
  37. langfun/core/llms/compositional_test.py +73 -0
  38. langfun/core/llms/vertexai.py +4 -4
  39. langfun/core/llms/vertexai_test.py +8 -2
  40. {langfun-0.1.2.dev202411110804.dist-info → langfun-0.1.2.dev202411150804.dist-info}/METADATA +1 -1
  41. {langfun-0.1.2.dev202411110804.dist-info → langfun-0.1.2.dev202411150804.dist-info}/RECORD +44 -15
  42. {langfun-0.1.2.dev202411110804.dist-info → langfun-0.1.2.dev202411150804.dist-info}/WHEEL +1 -1
  43. {langfun-0.1.2.dev202411110804.dist-info → langfun-0.1.2.dev202411150804.dist-info}/LICENSE +0 -0
  44. {langfun-0.1.2.dev202411110804.dist-info → langfun-0.1.2.dev202411150804.dist-info}/top_level.txt +0 -0
langfun/__init__.py CHANGED
@@ -53,6 +53,10 @@ PythonFunction = coding.PythonFunction
53
53
  from langfun.core import llms
54
54
  lm_cache = llms.cache.lm_cache
55
55
 
56
+ from langfun.core import agentic
57
+ Action = agentic.Action
58
+ Session = agentic.Session
59
+
56
60
  from langfun.core import memories
57
61
 
58
62
  from langfun.core import modalities
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Langfun agentic framework."""
15
+
16
+ # pylint: disable=g-bad-import-order
17
+ # pylint: disable=g-importing-member
18
+ # pylint: disable=g-import-not-at-top
19
+
20
+ from langfun.core.agentic.action import Action
21
+ from langfun.core.agentic.action import ActionInvocation
22
+ from langfun.core.agentic.action import Session
23
+
24
+ from langfun.core.agentic.action_eval import ActionEval
25
+ from langfun.core.agentic.action_eval import ActionEvalV1
26
+
27
+
28
+ # pylint: enable=g-bad-import-order
29
+ # pylint: enable=g-importing-member
30
+ # pylint: enable=g-import-not-at-top
@@ -0,0 +1,250 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Base classes for agentic actions."""
15
+
16
+ import abc
17
+ from typing import Annotated, Any, Optional, Union
18
+ import langfun.core as lf
19
+ import pyglove as pg
20
+
21
+
22
class Action(pg.Object):
  """Base class for agent actions.

  Subclasses implement `call`. Users execute an action through `__call__`,
  which records the execution (and its result) in a `Session`.
  """

  def _on_bound(self):
    super()._on_bound()
    self._result = None

  @property
  def result(self) -> Any:
    """Returns the result of the most recent execution of the action."""
    return self._result

  def __call__(
      self, session: Optional['Session'] = None, **kwargs) -> Any:
    """Executes the action.

    Args:
      session: The session in which the execution is recorded. A new
        `Session` is created when not provided.
      **kwargs: Keyword arguments passed through to `call`.

    Returns:
      The value returned by `call`.
    """
    if session is None:
      session = Session()
    # Drop any result from a previous execution so that, if `call` raises,
    # `session.end` does not record a stale result for this invocation.
    self._result = None
    # `begin` is outside the `try`: `end` must only run for an invocation
    # that was actually pushed onto the session's stack.
    session.begin(self)
    try:
      self._result = self.call(session=session, **kwargs)
      return self._result
    finally:
      session.end(self)

  @abc.abstractmethod
  def call(self, session: 'Session', **kwargs) -> Any:
    """Subclasses to implement."""
+
49
+
50
class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
  """A class for capturing the invocation of an action."""
  action: Action
  result: Any = None
  execution: Annotated[
      list[Union['ActionInvocation', lf.logging.LogEntry]],
      'Execution sequence of child invocations and log entries.'
  ] = []

  # Allow symbolic assignment without `rebind`.
  allow_symbolic_assignment = True

  @property
  def logs(self) -> list[lf.logging.LogEntry]:
    """Returns logs from execution sequence."""
    return [v for v in self.execution if isinstance(v, lf.logging.LogEntry)]

  @property
  def child_invocations(self) -> list['ActionInvocation']:
    """Returns child action invocations."""
    return [v for v in self.execution if isinstance(v, ActionInvocation)]

  def _html_tree_view_summary(
      self, *, view: pg.views.html.HtmlTreeView, **kwargs
  ):
    """Renders the summary titled by the action; the root has no summary."""
    if isinstance(self.action, RootAction):
      return None
    # The incoming title is replaced by a rendering of the action. Tolerate
    # callers that do not pass one.
    kwargs.pop('title', None)
    return view.summary(
        self,
        title=view.render(
            self.action, name='action', collapse_level=0,
            css_classes='invocation-title',
        ),
        **kwargs
    )

  def _html_tree_view_content(
      self,
      *,
      root_path: pg.KeyPath | None = None,
      collapse_level: int | None = None,
      view: pg.views.html.HtmlTreeView,
      **kwargs
  ):
    """Renders the result followed by the execution, split into phases.

    Items that precede the first child invocation form the "prepare" phase.
    Each child invocation starts a new action phase, which contains that
    invocation and the items following it.
    """
    prepare_phase = []
    current_phase = prepare_phase
    action_phases = []
    for item in self.execution:
      if isinstance(item, ActionInvocation):
        current_phase = []
        action_phases.append(current_phase)
      current_phase.append(item)

    def _render_phase(
        phase: list[ActionInvocation | lf.logging.LogEntry]
    ) -> pg.Html.WritableTypes:
      return pg.Html.element(
          'div',
          [
              view.render(item) for item in phase
          ]
      )

    def _render_action_phases(
        phases: list[list[ActionInvocation | lf.logging.LogEntry]]
    ) -> pg.Html.WritableTypes:
      # Leaf invocations have no child actions; render nothing instead of an
      # empty tab control.
      if not phases:
        return None
      if len(phases) == 1:
        return _render_phase(phases[0])
      return pg.views.html.controls.TabControl(
          [
              pg.views.html.controls.Tab(
                  label=f'Step {i + 1}',
                  content=_render_phase(phase),
              )
              for i, phase in enumerate(phases)
          ],
      )

    result_name = 'final_result' if isinstance(
        self.action, RootAction) else 'result'
    return pg.Html.element(
        'div',
        [
            view.render(
                self.result,
                name=result_name,
                css_classes=[
                    f'invocation-{result_name}'.replace('_', '-')
                ]
            ),
            _render_phase(prepare_phase) if prepare_phase else None,
            _render_action_phases(action_phases)
        ]
    )

  @classmethod
  def _html_tree_view_css_styles(cls) -> list[str]:
    return super()._html_tree_view_css_styles() + [
        """
        details.invocation-title {
          display: inline-block;
          background-color: #b1f0ff;
          border: 1px solid white;
        }
        details.invocation-result {
          border: 1px solid #eee;
        }
        details.invocation-final-result {
          border: 1px solid #eee;
          background-color: #fef78f;
        }
        """
    ]
164
+
165
+
166
class RootAction(Action):
  """A placeholder action for the root of the action tree.

  It only anchors the root invocation of a session and is never executed.
  """

  def call(self, session: 'Session', **kwargs) -> Any:
    del session, kwargs  # Unused: this action must not run.
    raise NotImplementedError('Shall not be called.')
171
+
172
+
173
class Session(pg.Object):
  """Session for performing an agentic task.

  A session records a tree of action invocations rooted at
  `root_invocation`, together with the log entries emitted while the actions
  run.
  """

  root_invocation: ActionInvocation = ActionInvocation(RootAction())

  def _on_bound(self):
    super()._on_bound()
    # Invocations currently executing; the root invocation stays at the bottom.
    self._invocation_stack = [self.root_invocation]

  @property
  def final_result(self) -> Any:
    """Returns the final result of the session."""
    return self.root_invocation.result

  @property
  def current_invocation(self) -> ActionInvocation:
    """Returns the current invocation."""
    assert self._invocation_stack
    return self._invocation_stack[-1]

  def begin(self, action: Action):
    """Signal the beginning of the execution of an action."""
    new_invocation = ActionInvocation(pg.maybe_ref(action))
    with pg.notify_on_change(False):
      self.current_invocation.execution.append(new_invocation)
    self._invocation_stack.append(new_invocation)

  def end(self, action: Action):
    """Signal the end of the execution of an action."""
    # Validate before mutating: there must be a non-root invocation on top of
    # the stack, and it must belong to `action`.
    assert len(self._invocation_stack) > 1, self._invocation_stack
    invocation = self._invocation_stack[-1]
    assert invocation.action is action, (invocation.action, action)

    self._invocation_stack.pop(-1)
    invocation.rebind(
        result=action.result, skip_notification=True, raise_on_no_change=False
    )

    # When a top-level action finishes, its result becomes the session's
    # final result.
    if len(self._invocation_stack) == 1:
      self.root_invocation.rebind(
          result=invocation.result,
          skip_notification=True,
          raise_on_no_change=False
      )

  def _log(self, level: lf.logging.LogLevel, message: str, **kwargs):
    """Appends a log entry to the current invocation's execution."""
    with pg.notify_on_change(False):
      self.current_invocation.execution.append(
          lf.logging.log(
              level, message, indent=len(self._invocation_stack) - 1, **kwargs
          )
      )

  def debug(self, message: str, **kwargs):
    """Logs a debug message to the session."""
    self._log('debug', message, **kwargs)

  def info(self, message: str, **kwargs):
    """Logs an info message to the session."""
    self._log('info', message, **kwargs)

  def warning(self, message: str, **kwargs):
    """Logs a warning message to the session."""
    self._log('warning', message, **kwargs)

  def error(self, message: str, **kwargs):
    """Logs an error message to the session."""
    self._log('error', message, **kwargs)

  def fatal(self, message: str, **kwargs):
    """Logs a fatal message to the session."""
    self._log('fatal', message, **kwargs)

  def as_message(self) -> lf.AIMessage:
    """Returns the session as an AI message carrying the root invocation."""
    return lf.AIMessage(
        'Agentic task session.',
        result=self.root_invocation
    )
@@ -0,0 +1,150 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Evaluation (v1) for Langfun agentic actions."""
15
+
16
+ import io
17
+ import os
18
+ from typing import Annotated, Any
19
+
20
+ import langfun.core as lf
21
+ from langfun.core import eval as lf_eval
22
+ from langfun.core.agentic import action as action_lib
23
+ import pyglove as pg
24
+
25
+
26
class ActionEval(lf.eval.v2.Evaluation):
  """Agent evaluation.

  Each example carries an `action`, which is executed in a fresh session with
  `action_args` as keyword arguments.
  """

  action_args: Annotated[
      dict[str, Any],
      'Arguments to call the action.'
  ] = {}

  def process(self, example: pg.Dict) -> tuple[str, dict[str, Any]]:
    """Runs the example's action; returns its final result and the session."""
    act = example.action
    new_session = action_lib.Session()
    # Run the action under the 'fatal' log level (presumably to silence
    # lower-level logging during evaluation).
    with lf.logging.use_log_level('fatal'):
      act(session=new_session, **self.action_args)
    metadata = dict(session=new_session)
    return new_session.final_result, metadata
40
+
41
+
42
+ #
43
+ # TODO(daiyip): Remove V1 once V2 is fully launched.
44
+ #
45
+
46
+
47
@pg.functor()
def _dummy_schema():
  """Placeholder schema functor; the V1 action evaluation uses no schema."""
  schema = int
  return schema
50
+
51
+
52
class ExampleView(pg.Object):
  """A view of one evaluated example, used for rendering it to HTML."""

  # Index of the example within the evaluation.
  id: int
  # The example input.
  input: Any
  # The produced output.
  output: Any
  # Error text for the example, if any.
  error: str | None = None
57
+
58
+
59
class ActionEvalV1(lf_eval.Matching):
  """Base class for action evaluations.

  The input function should return a list of pg.Dict, with `action` and
  `groundtruth` fields.
  """
  # We override the schema and prompt to dummy values since they are not used.
  schema_fn = _dummy_schema()
  prompt = '<unused>'

  def process(self, example: pg.Dict, **kwargs):
    """Runs the example's action in a new session; returns it as a message."""
    action = example.action
    session = action_lib.Session()
    action(session=session, lm=self.lm, **kwargs)
    return session.as_message()

  def answer(self, output: Any, example: pg.Dict) -> Any:
    return output

  def groundtruth(self, example: Any) -> Any:
    return example.groundtruth

  def audit(
      self,
      example_idx: int,
      example: Any,
      message: lf.Message | None,
      error: Exception | None = None,
      dryrun: bool = False,
  ):
    """Audits an example and, unless dry-running, saves it as an HTML page."""
    super().audit(example_idx, example, message, error, dryrun)
    # Write each example to HTML.
    if not dryrun and self.dir:
      # `ExampleView.error` is declared as `str | None`, so the exception is
      # converted to text before the view is constructed.
      error_text = (
          None if error is None else f'{error.__class__.__name__}: {error}'
      )

      def _save_html():
        ExampleView(
            example_idx,
            example,
            None if message is None else message.result,
            error_text,
        ).to_html(
            collapse_level=None,
            enable_summary_tooltip=False,
        ).save(
            os.path.join(self.dir, f'example_{example_idx}.html')
        )
      # Write HTML in a separate thread to avoid blocking the main thread.
      lf.concurrent.get_executor(
          'background_eval_io', max_workers=16
      ).submit(_save_html)

  def _write_example_links(
      self,
      s: io.StringIO,
      title: str,
      example_ids: list[int],
      empty_message: str,
  ) -> None:
    """Writes a titled list of example links plus an iframe to view them.

    Links target the `example_view` iframe, which initially shows the example
    with the smallest id.
    """
    s.write(f'<h2> {title} </h2>')
    links = [
        (example_idx, os.path.join(self.dir, f'example_{example_idx}.html'))
        for example_idx in sorted(example_ids)
    ]
    for example_idx, url in links:
      s.write(
          f'<a href="{url}" style="margin-right: 10px" target="example_view">'
          f'{example_idx}</a> '
      )
    if links:
      s.write(
          '<iframe style="border:0;width:100%;height:100%" name="example_view" '
          f'src="{links[0][1]}" title="Example View"></iframe>'
      )
    else:
      s.write(empty_message)

  def _render_mismatches(self, s: io.StringIO) -> None:
    self._write_example_links(
        s,
        'Mismatches (Incorrect)',
        [example_idx for example_idx, *_ in self.mismatches],
        'No mismatches found.',
    )

  def _render_matches(self, s: io.StringIO) -> None:
    self._write_example_links(
        s,
        'Matches (correct)',
        [example_idx for example_idx, *_ in self.matches],
        'No matches found.',
    )
@@ -0,0 +1,109 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for action evaluation."""
15
+
16
+ import os
17
+ import tempfile
18
+ import unittest
19
+
20
+ from langfun.core import eval as lf_eval
21
+ from langfun.core import llms as lf_llms
22
+ from langfun.core.agentic import action as action_lib
23
+ from langfun.core.agentic import action_eval
24
+ import pyglove as pg
25
+
26
+
27
class Foo(action_lib.Action):
  """Test action whose result is its own `x` field."""
  x: int

  def call(self, session, **kwargs):
    del kwargs, session  # Unused.
    value = self.x
    return value
33
+
34
+
35
@pg.functor()
def foo_inputs():
  """Returns two examples: Foo(1) matches groundtruth 1, Foo(2) does not."""
  examples = []
  for x in (1, 2):
    examples.append(pg.Dict(action=Foo(x), groundtruth=1))
  return examples
41
+
42
+
43
class ActionEvalTest(unittest.TestCase):

  def test_basics(self):

    class FooEval(action_eval.ActionEval):
      inputs = foo_inputs()
      metrics = [lf_eval.v2.metrics.Match()]
      action_args = {'lm': lf_llms.Echo()}

    evaluation = FooEval()
    evaluation.run(
        os.path.join(tempfile.gettempdir(), 'foo_eval'), plugins=[]
    )
    match_metric = evaluation.metrics[0]
    # One of the two examples matches its groundtruth.
    self.assertEqual(match_metric.matches, 0.5)
    self.assertEqual(match_metric.mismatches, 0.5)
59
+
60
+
61
class ActionEvalV1Test(unittest.TestCase):

  def test_basics(self):

    class FooEval(action_eval.ActionEvalV1):
      lm = lf_llms.Echo()
      inputs = foo_inputs()

    s = FooEval()
    result = s.run(summary=False)
    # The V1 evaluation receives the session message as its answer, so
    # neither example is expected to match its integer groundtruth.
    self.assertEqual(
        result,
        dict(
            experiment_setup=dict(
                id=s.id,
                dir=None,
                model='Echo',
                prompt_template='<unused>',
                method='query',
                schema_fn='_dummy_schema()'
            ),
            cache_stats=dict(
                use_cache=True,
                num_queries=0,
                num_hits=0,
                num_updates=0,
            ),
            metrics=dict(
                total=2,
                failures=0,
                failure_rate=0.0,
                oop_failures=0,
                oop_failure_rate=0.0,
                non_oop_failures=0,
                non_oop_failure_rate=0.0,
                failure_breakdown={},
                num_matches=0,
                match_rate=0.0,
                num_mismatches=2,
                mismatch_rate=1.0
            ),
            usage=None
        )
    )


if __name__ == '__main__':
  unittest.main()
@@ -0,0 +1,84 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for base action."""
15
+
16
+ import unittest
17
+
18
+ import langfun.core as lf
19
+ from langfun.core.agentic import action as action_lib
20
+
21
+
22
class SessionTest(unittest.TestCase):

  def test_basics(self):
    test = self

    class Bar(action_lib.Action):

      def call(self, session, **kwargs):
        test.assertIs(session.current_invocation.action, self)
        session.info('Begin Bar')
        return 2

    class Foo(action_lib.Action):
      x: int

      def call(self, session, **kwargs):
        test.assertIs(session.current_invocation.action, self)
        session.info('Begin Foo', x=1)
        return self.x + Bar()(session)

    session = action_lib.Session()
    root = session.root_invocation
    self.assertIsInstance(root.action, action_lib.RootAction)
    self.assertIs(session.current_invocation, session.root_invocation)
    self.assertEqual(Foo(1)(session), 3)
    self.assertEqual(len(session.root_invocation.child_invocations), 1)
    self.assertEqual(len(session.root_invocation.child_invocations[0].logs), 1)
    self.assertEqual(
        len(session.root_invocation.child_invocations[0].child_invocations),
        1
    )
    self.assertEqual(
        len(session.root_invocation
            .child_invocations[0].child_invocations[0].logs),
        1
    )
    self.assertEqual(
        len(session.root_invocation
            .child_invocations[0].child_invocations[0].child_invocations),
        0
    )
    self.assertIs(session.current_invocation, session.root_invocation)
    # Compare by value: int identity depends on interpreter caching.
    self.assertEqual(session.final_result, 3)
    self.assertIn(
        'invocation-final-result',
        session.to_html().content,
    )

  def test_log(self):
    session = action_lib.Session()
    session.debug('hi', x=1, y=2)
    session.info('hi', x=1, y=2)
    session.warning('hi', x=1, y=2)
    session.error('hi', x=1, y=2)
    session.fatal('hi', x=1, y=2)
    # With no action running, every entry lands on the root invocation.
    self.assertEqual(len(session.root_invocation.logs), 5)

  def test_as_message(self):
    session = action_lib.Session()
    self.assertIsInstance(session.as_message(), lf.AIMessage)


if __name__ == '__main__':
  unittest.main()
langfun/core/console.py CHANGED
@@ -59,12 +59,20 @@ def under_notebook() -> bool:
59
59
  return bool(_notebook)
60
60
 
61
61
 
62
def display(value: Any, clear: bool = False) -> Any:  # pylint: disable=redefined-outer-name
  """Displays object in current notebook cell.

  Args:
    value: The object to display.
    clear: If True, clears the cell output before displaying.

  Returns:
    The return value of the notebook's `display` call, or None when not
    running under a notebook.
  """
  if _notebook is not None:
    if clear:
      _notebook.clear_output()
    return _notebook.display(value)
  return None


def run_script(javascript: str) -> Any:
  """Runs JavaScript in current notebook cell.

  Args:
    javascript: The JavaScript source to run.

  Returns:
    The return value of the notebook's `display` call, or None when not
    running under a notebook.
  """
  if _notebook is not None:
    return _notebook.display(_notebook.Javascript(javascript))
  return None
68
76
 
69
77
 
70
78
  def clear() -> None:
@@ -18,6 +18,7 @@ import io
18
18
  import unittest
19
19
 
20
20
  from langfun.core import console
21
+ import pyglove as pg
21
22
 
22
23
 
23
24
  class ConsoleTest(unittest.TestCase):
@@ -32,6 +33,22 @@ class ConsoleTest(unittest.TestCase):
32
33
 
33
34
  def test_under_notebook(self):
34
35
  self.assertFalse(console.under_notebook())
36
+ console._notebook = True
37
+ self.assertTrue(console.under_notebook())
38
+ console._notebook = None
39
+
40
+ def test_notebook_interaction(self):
41
+ console._notebook = pg.Dict(
42
+ display=lambda x: x, Javascript=lambda x: x, clear_output=lambda: None)
43
+ self.assertEqual(console.display('hi', clear=True), 'hi')
44
+ self.assertEqual(
45
+ console.run_script('console.log("hi")'),
46
+ 'console.log("hi")'
47
+ )
48
+ console.clear()
49
+ console._notebook = None
50
+ self.assertIsNone(console.display('hi'))
51
+ self.assertIsNone(console.run_script('console.log("hi")'))
35
52
 
36
53
 
37
54
  if __name__ == '__main__':
@@ -16,6 +16,8 @@
16
16
  # pylint: disable=g-importing-member
17
17
  # pylint: disable=g-bad-import-order
18
18
 
19
+ from langfun.core.eval import v2
20
+
19
21
  from langfun.core.eval.base import register
20
22
  from langfun.core.eval.base import registered_names
21
23
  from langfun.core.eval.base import get_evaluations