langfun 0.1.2.dev202412020805__py3-none-any.whl → 0.1.2.dev202412050804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +2 -0
- langfun/core/agentic/action.py +74 -24
- langfun/core/agentic/action_test.py +20 -4
- langfun/core/eval/v2/runners.py +3 -0
- langfun/core/llms/__init__.py +1 -7
- langfun/core/llms/openai.py +142 -207
- langfun/core/llms/openai_test.py +160 -224
- langfun/core/llms/vertexai.py +23 -422
- langfun/core/llms/vertexai_test.py +21 -335
- langfun/core/structured/__init__.py +2 -0
- langfun/core/structured/prompting.py +148 -47
- langfun/core/structured/prompting_test.py +84 -1
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412050804.dist-info}/METADATA +1 -12
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412050804.dist-info}/RECORD +17 -17
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412050804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412050804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412050804.dist-info}/top_level.txt +0 -0
langfun/__init__.py
CHANGED
@@ -35,6 +35,8 @@ complete = structured.complete
|
|
35
35
|
score = structured.score
|
36
36
|
generate_class = structured.generate_class
|
37
37
|
|
38
|
+
track_queries = structured.track_queries
|
39
|
+
|
38
40
|
# Helper functions for input/output transformations based on
|
39
41
|
# `lf.query` (e.g. jax-on-beam could use these for batch processing)
|
40
42
|
query_prompt = structured.query_prompt
|
langfun/core/agentic/action.py
CHANGED
@@ -14,8 +14,10 @@
|
|
14
14
|
"""Base classes for agentic actions."""
|
15
15
|
|
16
16
|
import abc
|
17
|
-
|
17
|
+
import contextlib
|
18
|
+
from typing import Annotated, Any, Iterable, Iterator, Optional, Type, Union
|
18
19
|
import langfun.core as lf
|
20
|
+
from langfun.core import structured as lf_structured
|
19
21
|
import pyglove as pg
|
20
22
|
|
21
23
|
|
@@ -35,12 +37,9 @@ class Action(pg.Object):
|
|
35
37
|
self, session: Optional['Session'] = None, **kwargs) -> Any:
|
36
38
|
"""Executes the action."""
|
37
39
|
session = session or Session()
|
38
|
-
|
39
|
-
session.begin(self)
|
40
|
+
with session.track(self):
|
40
41
|
self._result = self.call(session=session, **kwargs)
|
41
42
|
return self._result
|
42
|
-
finally:
|
43
|
-
session.end(self)
|
44
43
|
|
45
44
|
@abc.abstractmethod
|
46
45
|
def call(self, session: 'Session', **kwargs) -> Any:
|
@@ -50,9 +49,20 @@ class Action(pg.Object):
|
|
50
49
|
class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
51
50
|
"""A class for capturing the invocation of an action."""
|
52
51
|
action: Action
|
53
|
-
|
52
|
+
|
53
|
+
result: Annotated[
|
54
|
+
Any,
|
55
|
+
'The result of the action.'
|
56
|
+
] = None
|
57
|
+
|
54
58
|
execution: Annotated[
|
55
|
-
list[
|
59
|
+
list[
|
60
|
+
Union[
|
61
|
+
lf_structured.QueryInvocation,
|
62
|
+
'ActionInvocation',
|
63
|
+
lf.logging.LogEntry
|
64
|
+
]
|
65
|
+
],
|
56
66
|
'Execution execution.'
|
57
67
|
] = []
|
58
68
|
|
@@ -69,6 +79,18 @@ class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
69
79
|
"""Returns child action invocations."""
|
70
80
|
return [v for v in self.execution if isinstance(v, ActionInvocation)]
|
71
81
|
|
82
|
+
def queries(
|
83
|
+
self,
|
84
|
+
include_children: bool = False
|
85
|
+
) -> Iterable[lf_structured.QueryInvocation]:
|
86
|
+
"""Iterates over queries from the current invocation."""
|
87
|
+
for v in self.execution:
|
88
|
+
if isinstance(v, lf_structured.QueryInvocation):
|
89
|
+
yield v
|
90
|
+
elif isinstance(v, ActionInvocation):
|
91
|
+
if include_children:
|
92
|
+
yield from v.queries(include_children=True)
|
93
|
+
|
72
94
|
def _html_tree_view_summary(
|
73
95
|
self, *, view: pg.views.html.HtmlTreeView, **kwargs
|
74
96
|
):
|
@@ -190,29 +212,57 @@ class Session(pg.Object):
|
|
190
212
|
assert self._invocation_stack
|
191
213
|
return self._invocation_stack[-1]
|
192
214
|
|
193
|
-
|
194
|
-
|
215
|
+
@contextlib.contextmanager
|
216
|
+
def track(self, action: Action) -> Iterator[ActionInvocation]:
|
217
|
+
"""Track the execution of an action."""
|
195
218
|
new_invocation = ActionInvocation(pg.maybe_ref(action))
|
196
219
|
with pg.notify_on_change(False):
|
197
220
|
self.current_invocation.execution.append(new_invocation)
|
198
221
|
self._invocation_stack.append(new_invocation)
|
199
222
|
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
assert invocation.action is action, (invocation.action, action)
|
208
|
-
assert self._invocation_stack, self._invocation_stack
|
209
|
-
|
210
|
-
if len(self._invocation_stack) == 1:
|
211
|
-
self.root_invocation.rebind(
|
212
|
-
result=invocation.result,
|
213
|
-
skip_notification=True,
|
214
|
-
raise_on_no_change=False
|
223
|
+
try:
|
224
|
+
yield new_invocation
|
225
|
+
finally:
|
226
|
+
assert self._invocation_stack
|
227
|
+
invocation = self._invocation_stack.pop(-1)
|
228
|
+
invocation.rebind(
|
229
|
+
result=action.result, skip_notification=True, raise_on_no_change=False
|
215
230
|
)
|
231
|
+
assert invocation.action is action, (invocation.action, action)
|
232
|
+
assert self._invocation_stack, self._invocation_stack
|
233
|
+
|
234
|
+
if len(self._invocation_stack) == 1:
|
235
|
+
self.root_invocation.rebind(
|
236
|
+
result=invocation.result,
|
237
|
+
skip_notification=True,
|
238
|
+
raise_on_no_change=False
|
239
|
+
)
|
240
|
+
|
241
|
+
def query(
|
242
|
+
self,
|
243
|
+
prompt: Union[str, lf.Template, Any],
|
244
|
+
schema: Union[
|
245
|
+
lf_structured.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
|
246
|
+
] = None,
|
247
|
+
default: Any = lf.RAISE_IF_HAS_ERROR,
|
248
|
+
*,
|
249
|
+
lm: lf.LanguageModel | None = None,
|
250
|
+
examples: list[lf_structured.MappingExample] | None = None,
|
251
|
+
**kwargs
|
252
|
+
) -> Any:
|
253
|
+
"""Calls `lf.query` and associates it with the current invocation."""
|
254
|
+
with lf_structured.track_queries() as queries:
|
255
|
+
output = lf_structured.query(
|
256
|
+
prompt,
|
257
|
+
schema=schema,
|
258
|
+
default=default,
|
259
|
+
lm=lm,
|
260
|
+
examples=examples,
|
261
|
+
**kwargs
|
262
|
+
)
|
263
|
+
with pg.notify_on_change(False):
|
264
|
+
self.current_invocation.execution.extend(queries)
|
265
|
+
return output
|
216
266
|
|
217
267
|
def _log(self, level: lf.logging.LogLevel, message: str, **kwargs):
|
218
268
|
with pg.notify_on_change(False):
|
@@ -17,6 +17,7 @@ import unittest
|
|
17
17
|
|
18
18
|
import langfun.core as lf
|
19
19
|
from langfun.core.agentic import action as action_lib
|
20
|
+
from langfun.core.llms import fake
|
20
21
|
|
21
22
|
|
22
23
|
class SessionTest(unittest.TestCase):
|
@@ -26,25 +27,35 @@ class SessionTest(unittest.TestCase):
|
|
26
27
|
|
27
28
|
class Bar(action_lib.Action):
|
28
29
|
|
29
|
-
def call(self, session, **kwargs):
|
30
|
+
def call(self, session, *, lm, **kwargs):
|
30
31
|
test.assertIs(session.current_invocation.action, self)
|
31
32
|
session.info('Begin Bar')
|
33
|
+
session.query('bar', lm=lm)
|
32
34
|
return 2
|
33
35
|
|
34
36
|
class Foo(action_lib.Action):
|
35
37
|
x: int
|
36
38
|
|
37
|
-
def call(self, session, **kwargs):
|
39
|
+
def call(self, session, *, lm, **kwargs):
|
38
40
|
test.assertIs(session.current_invocation.action, self)
|
39
41
|
session.info('Begin Foo', x=1)
|
40
|
-
|
42
|
+
session.query('foo', lm=lm)
|
43
|
+
return self.x + Bar()(session, lm=lm)
|
41
44
|
|
45
|
+
lm = fake.StaticResponse('lm response')
|
42
46
|
session = action_lib.Session()
|
43
47
|
root = session.root_invocation
|
44
48
|
self.assertIsInstance(root.action, action_lib.RootAction)
|
45
49
|
self.assertIs(session.current_invocation, session.root_invocation)
|
46
|
-
self.assertEqual(Foo(1)(session), 3)
|
50
|
+
self.assertEqual(Foo(1)(session, lm=lm), 3)
|
47
51
|
self.assertEqual(len(session.root_invocation.child_invocations), 1)
|
52
|
+
self.assertEqual(len(list(session.root_invocation.queries())), 0)
|
53
|
+
self.assertEqual(
|
54
|
+
len(list(session.root_invocation.queries(include_children=True))), 2
|
55
|
+
)
|
56
|
+
self.assertEqual(
|
57
|
+
len(list(session.root_invocation.child_invocations[0].queries())), 1
|
58
|
+
)
|
48
59
|
self.assertEqual(len(session.root_invocation.child_invocations[0].logs), 1)
|
49
60
|
self.assertEqual(
|
50
61
|
len(session.root_invocation.child_invocations[0].child_invocations),
|
@@ -55,6 +66,11 @@ class SessionTest(unittest.TestCase):
|
|
55
66
|
.child_invocations[0].child_invocations[0].logs),
|
56
67
|
1
|
57
68
|
)
|
69
|
+
self.assertEqual(
|
70
|
+
len(list(session.root_invocation
|
71
|
+
.child_invocations[0].child_invocations[0].queries())),
|
72
|
+
1
|
73
|
+
)
|
58
74
|
self.assertEqual(
|
59
75
|
len(session.root_invocation
|
60
76
|
.child_invocations[0].child_invocations[0].child_invocations),
|
langfun/core/eval/v2/runners.py
CHANGED
@@ -261,6 +261,9 @@ class RunnerBase(Runner):
|
|
261
261
|
if cache is not None:
|
262
262
|
self.background_run(cache.save)
|
263
263
|
|
264
|
+
# Wait for the background tasks to finish.
|
265
|
+
self._io_pool.shutdown(wait=True)
|
266
|
+
|
264
267
|
@abc.abstractmethod
|
265
268
|
def _run(self, evaluations: list[Evaluation]) -> None:
|
266
269
|
"""Runs multiple evaluations."""
|
langfun/core/llms/__init__.py
CHANGED
@@ -120,25 +120,19 @@ from langfun.core.llms.groq import GroqWhisper_Large_v3
|
|
120
120
|
from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo
|
121
121
|
|
122
122
|
from langfun.core.llms.vertexai import VertexAI
|
123
|
-
from langfun.core.llms.vertexai import VertexAIRest
|
124
|
-
from langfun.core.llms.vertexai import VertexAIRestGemini1_5
|
125
123
|
from langfun.core.llms.vertexai import VertexAIGemini1_5
|
126
124
|
from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
|
127
|
-
from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_Latest
|
128
125
|
from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001
|
129
126
|
from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002
|
130
127
|
from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0514
|
131
128
|
from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0409
|
132
|
-
from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_Latest
|
133
129
|
from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5
|
134
130
|
from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
|
135
131
|
from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
|
136
132
|
from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_0514
|
137
133
|
from langfun.core.llms.vertexai import VertexAIGeminiPro1
|
138
134
|
from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision
|
139
|
-
from langfun.core.llms.vertexai import
|
140
|
-
from langfun.core.llms.vertexai import VertexAIPalm2_32K
|
141
|
-
from langfun.core.llms.vertexai import VertexAICustom
|
135
|
+
from langfun.core.llms.vertexai import VertexAIEndpoint
|
142
136
|
|
143
137
|
|
144
138
|
# LLaMA C++ models.
|