langfun 0.1.2.dev202412020805__py3-none-any.whl → 0.1.2.dev202412050804__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
langfun/__init__.py CHANGED
@@ -35,6 +35,8 @@ complete = structured.complete
35
35
  score = structured.score
36
36
  generate_class = structured.generate_class
37
37
 
38
+ track_queries = structured.track_queries
39
+
38
40
  # Helper functions for input/output transformations based on
39
41
  # `lf.query` (e.g. jax-on-beam could use these for batch processing)
40
42
  query_prompt = structured.query_prompt
@@ -14,8 +14,10 @@
14
14
  """Base classes for agentic actions."""
15
15
 
16
16
  import abc
17
- from typing import Annotated, Any, Optional, Union
17
+ import contextlib
18
+ from typing import Annotated, Any, Iterable, Iterator, Optional, Type, Union
18
19
  import langfun.core as lf
20
+ from langfun.core import structured as lf_structured
19
21
  import pyglove as pg
20
22
 
21
23
 
@@ -35,12 +37,9 @@ class Action(pg.Object):
35
37
  self, session: Optional['Session'] = None, **kwargs) -> Any:
36
38
  """Executes the action."""
37
39
  session = session or Session()
38
- try:
39
- session.begin(self)
40
+ with session.track(self):
40
41
  self._result = self.call(session=session, **kwargs)
41
42
  return self._result
42
- finally:
43
- session.end(self)
44
43
 
45
44
  @abc.abstractmethod
46
45
  def call(self, session: 'Session', **kwargs) -> Any:
@@ -50,9 +49,20 @@ class Action(pg.Object):
50
49
  class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
51
50
  """A class for capturing the invocation of an action."""
52
51
  action: Action
53
- result: Any = None
52
+
53
+ result: Annotated[
54
+ Any,
55
+ 'The result of the action.'
56
+ ] = None
57
+
54
58
  execution: Annotated[
55
- list[Union['ActionInvocation', lf.logging.LogEntry]],
59
+ list[
60
+ Union[
61
+ lf_structured.QueryInvocation,
62
+ 'ActionInvocation',
63
+ lf.logging.LogEntry
64
+ ]
65
+ ],
56
66
  'Execution execution.'
57
67
  ] = []
58
68
 
@@ -69,6 +79,18 @@ class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
69
79
  """Returns child action invocations."""
70
80
  return [v for v in self.execution if isinstance(v, ActionInvocation)]
71
81
 
82
+ def queries(
83
+ self,
84
+ include_children: bool = False
85
+ ) -> Iterable[lf_structured.QueryInvocation]:
86
+ """Iterates over queries from the current invocation."""
87
+ for v in self.execution:
88
+ if isinstance(v, lf_structured.QueryInvocation):
89
+ yield v
90
+ elif isinstance(v, ActionInvocation):
91
+ if include_children:
92
+ yield from v.queries(include_children=True)
93
+
72
94
  def _html_tree_view_summary(
73
95
  self, *, view: pg.views.html.HtmlTreeView, **kwargs
74
96
  ):
@@ -190,29 +212,57 @@ class Session(pg.Object):
190
212
  assert self._invocation_stack
191
213
  return self._invocation_stack[-1]
192
214
 
193
- def begin(self, action: Action):
194
- """Signal the beginning of the execution of an action."""
215
+ @contextlib.contextmanager
216
+ def track(self, action: Action) -> Iterator[ActionInvocation]:
217
+ """Track the execution of an action."""
195
218
  new_invocation = ActionInvocation(pg.maybe_ref(action))
196
219
  with pg.notify_on_change(False):
197
220
  self.current_invocation.execution.append(new_invocation)
198
221
  self._invocation_stack.append(new_invocation)
199
222
 
200
- def end(self, action: Action):
201
- """Signal the end of the execution of an action."""
202
- assert self._invocation_stack
203
- invocation = self._invocation_stack.pop(-1)
204
- invocation.rebind(
205
- result=action.result, skip_notification=True, raise_on_no_change=False
206
- )
207
- assert invocation.action is action, (invocation.action, action)
208
- assert self._invocation_stack, self._invocation_stack
209
-
210
- if len(self._invocation_stack) == 1:
211
- self.root_invocation.rebind(
212
- result=invocation.result,
213
- skip_notification=True,
214
- raise_on_no_change=False
223
+ try:
224
+ yield new_invocation
225
+ finally:
226
+ assert self._invocation_stack
227
+ invocation = self._invocation_stack.pop(-1)
228
+ invocation.rebind(
229
+ result=action.result, skip_notification=True, raise_on_no_change=False
215
230
  )
231
+ assert invocation.action is action, (invocation.action, action)
232
+ assert self._invocation_stack, self._invocation_stack
233
+
234
+ if len(self._invocation_stack) == 1:
235
+ self.root_invocation.rebind(
236
+ result=invocation.result,
237
+ skip_notification=True,
238
+ raise_on_no_change=False
239
+ )
240
+
241
+ def query(
242
+ self,
243
+ prompt: Union[str, lf.Template, Any],
244
+ schema: Union[
245
+ lf_structured.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
246
+ ] = None,
247
+ default: Any = lf.RAISE_IF_HAS_ERROR,
248
+ *,
249
+ lm: lf.LanguageModel | None = None,
250
+ examples: list[lf_structured.MappingExample] | None = None,
251
+ **kwargs
252
+ ) -> Any:
253
+ """Calls `lf.query` and associates it with the current invocation."""
254
+ with lf_structured.track_queries() as queries:
255
+ output = lf_structured.query(
256
+ prompt,
257
+ schema=schema,
258
+ default=default,
259
+ lm=lm,
260
+ examples=examples,
261
+ **kwargs
262
+ )
263
+ with pg.notify_on_change(False):
264
+ self.current_invocation.execution.extend(queries)
265
+ return output
216
266
 
217
267
  def _log(self, level: lf.logging.LogLevel, message: str, **kwargs):
218
268
  with pg.notify_on_change(False):
@@ -17,6 +17,7 @@ import unittest
17
17
 
18
18
  import langfun.core as lf
19
19
  from langfun.core.agentic import action as action_lib
20
+ from langfun.core.llms import fake
20
21
 
21
22
 
22
23
  class SessionTest(unittest.TestCase):
@@ -26,25 +27,35 @@ class SessionTest(unittest.TestCase):
26
27
 
27
28
  class Bar(action_lib.Action):
28
29
 
29
- def call(self, session, **kwargs):
30
+ def call(self, session, *, lm, **kwargs):
30
31
  test.assertIs(session.current_invocation.action, self)
31
32
  session.info('Begin Bar')
33
+ session.query('bar', lm=lm)
32
34
  return 2
33
35
 
34
36
  class Foo(action_lib.Action):
35
37
  x: int
36
38
 
37
- def call(self, session, **kwargs):
39
+ def call(self, session, *, lm, **kwargs):
38
40
  test.assertIs(session.current_invocation.action, self)
39
41
  session.info('Begin Foo', x=1)
40
- return self.x + Bar()(session)
42
+ session.query('foo', lm=lm)
43
+ return self.x + Bar()(session, lm=lm)
41
44
 
45
+ lm = fake.StaticResponse('lm response')
42
46
  session = action_lib.Session()
43
47
  root = session.root_invocation
44
48
  self.assertIsInstance(root.action, action_lib.RootAction)
45
49
  self.assertIs(session.current_invocation, session.root_invocation)
46
- self.assertEqual(Foo(1)(session), 3)
50
+ self.assertEqual(Foo(1)(session, lm=lm), 3)
47
51
  self.assertEqual(len(session.root_invocation.child_invocations), 1)
52
+ self.assertEqual(len(list(session.root_invocation.queries())), 0)
53
+ self.assertEqual(
54
+ len(list(session.root_invocation.queries(include_children=True))), 2
55
+ )
56
+ self.assertEqual(
57
+ len(list(session.root_invocation.child_invocations[0].queries())), 1
58
+ )
48
59
  self.assertEqual(len(session.root_invocation.child_invocations[0].logs), 1)
49
60
  self.assertEqual(
50
61
  len(session.root_invocation.child_invocations[0].child_invocations),
@@ -55,6 +66,11 @@ class SessionTest(unittest.TestCase):
55
66
  .child_invocations[0].child_invocations[0].logs),
56
67
  1
57
68
  )
69
+ self.assertEqual(
70
+ len(list(session.root_invocation
71
+ .child_invocations[0].child_invocations[0].queries())),
72
+ 1
73
+ )
58
74
  self.assertEqual(
59
75
  len(session.root_invocation
60
76
  .child_invocations[0].child_invocations[0].child_invocations),
@@ -261,6 +261,9 @@ class RunnerBase(Runner):
261
261
  if cache is not None:
262
262
  self.background_run(cache.save)
263
263
 
264
+ # Wait for the background tasks to finish.
265
+ self._io_pool.shutdown(wait=True)
266
+
264
267
  @abc.abstractmethod
265
268
  def _run(self, evaluations: list[Evaluation]) -> None:
266
269
  """Runs multiple evaluations."""
@@ -120,25 +120,19 @@ from langfun.core.llms.groq import GroqWhisper_Large_v3
120
120
  from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo
121
121
 
122
122
  from langfun.core.llms.vertexai import VertexAI
123
- from langfun.core.llms.vertexai import VertexAIRest
124
- from langfun.core.llms.vertexai import VertexAIRestGemini1_5
125
123
  from langfun.core.llms.vertexai import VertexAIGemini1_5
126
124
  from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
127
- from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_Latest
128
125
  from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001
129
126
  from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002
130
127
  from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0514
131
128
  from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0409
132
- from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_Latest
133
129
  from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5
134
130
  from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
135
131
  from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
136
132
  from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_0514
137
133
  from langfun.core.llms.vertexai import VertexAIGeminiPro1
138
134
  from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision
139
- from langfun.core.llms.vertexai import VertexAIPalm2
140
- from langfun.core.llms.vertexai import VertexAIPalm2_32K
141
- from langfun.core.llms.vertexai import VertexAICustom
135
+ from langfun.core.llms.vertexai import VertexAIEndpoint
142
136
 
143
137
 
144
138
  # LLaMA C++ models.