langfun 0.1.2.dev202505020804__py3-none-any.whl → 0.1.2.dev202505040804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/core/agentic/action.py +230 -13
- langfun/core/agentic/action_eval.py +15 -11
- langfun/core/agentic/action_eval_test.py +0 -1
- langfun/core/agentic/action_test.py +153 -12
- langfun/core/llms/gemini.py +19 -7
- langfun/core/llms/gemini_test.py +33 -2
- langfun/core/llms/vertexai.py +12 -0
- langfun/core/llms/vertexai_test.py +17 -0
- {langfun-0.1.2.dev202505020804.dist-info → langfun-0.1.2.dev202505040804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202505020804.dist-info → langfun-0.1.2.dev202505040804.dist-info}/RECORD +13 -13
- {langfun-0.1.2.dev202505020804.dist-info → langfun-0.1.2.dev202505040804.dist-info}/WHEEL +1 -1
- {langfun-0.1.2.dev202505020804.dist-info → langfun-0.1.2.dev202505040804.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202505020804.dist-info → langfun-0.1.2.dev202505040804.dist-info}/top_level.txt +0 -0
langfun/core/agentic/action.py
CHANGED
@@ -27,7 +27,152 @@ import pyglove as pg
|
|
27
27
|
|
28
28
|
|
29
29
|
class Action(pg.Object):
|
30
|
-
"""Base class for
|
30
|
+
"""Base class for Langfun's agentic actions.
|
31
|
+
|
32
|
+
# Developing Actions
|
33
|
+
|
34
|
+
In Langfun, an `Action` is a class representing a task an agent can execute.
|
35
|
+
To define custom actions, subclass `lf.agentic.Action` and implement the
|
36
|
+
`call` method, which contains the logic for the action's execution.
|
37
|
+
|
38
|
+
```python
|
39
|
+
class Calculate(lf.agentic.Action):
|
40
|
+
expression: str
|
41
|
+
|
42
|
+
def call(self, session: lf.Session, *, lm: lf.LanguageModel, **kwargs):
|
43
|
+
return session.query(self.expression, float, lm=lm)
|
44
|
+
```
|
45
|
+
|
46
|
+
Key aspects of the `call` method:
|
47
|
+
|
48
|
+
- `session` (First Argument): An `lf.Session` object required to make queries,
|
49
|
+
perform logging, and add metadata to the action. It also tracks the
|
50
|
+
execution of the action and its sub-actions.
|
51
|
+
|
52
|
+
- Use `session.query(...)` to make calls to a Language Model.
|
53
|
+
- Use `session.debug(...)`, `session.info(...)`, `session.warning(...)`,
|
54
|
+
and `session.error(...)` for adding logs associated with the
|
55
|
+
current action.
|
56
|
+
- Use `session.add_metadata(...)` to associate custom metadata with
|
57
|
+
the current action.
|
58
|
+
|
59
|
+
- Keyword Arguments (e.g., `lm`): Arguments required for the action's execution
|
60
|
+
(like a language model) should be defined as keyword arguments.
|
61
|
+
|
62
|
+
- `**kwargs`: Include `**kwargs` to allow:
|
63
|
+
|
64
|
+
- Users to pass additional arguments to child actions.
|
65
|
+
- The action to gracefully handle extra arguments passed by parent actions.
|
66
|
+
|
67
|
+
# Using Actions
|
68
|
+
|
69
|
+
## Creating Action objects
|
70
|
+
Action objects can be instantiated in two primary ways:
|
71
|
+
|
72
|
+
- Direct instantiation by Users:
|
73
|
+
|
74
|
+
```
|
75
|
+
calculate_action = Calculate(expression='1 + 1')
|
76
|
+
```
|
77
|
+
|
78
|
+
- Generation by Language Models (LLMs): LLMs can generate Action objects when
|
79
|
+
provided with an "action space" (a schema defining possible actions). The
|
80
|
+
LLM populates the action's attributes. User code can then invoke the
|
81
|
+
generated action.
|
82
|
+
|
83
|
+
```python
|
84
|
+
import pyglove as pg
|
85
|
+
import langfun as lf
|
86
|
+
|
87
|
+
# Define possible actions for the LLM
|
88
|
+
class Search(lf.agentic.Action):
|
89
|
+
query: str
|
90
|
+
def call(self, session: lf.Session, *, lm: lf.LanguageModel, **kwargs):
|
91
|
+
# Placeholder for actual search logic
|
92
|
+
return f"Results for: {self.query}"
|
93
|
+
|
94
|
+
class DirectAnswer(lf.agentic.Action):
|
95
|
+
answer: str
|
96
|
+
def call(self, session: lf.Session, *, lm: lf.LanguageModel, **kwargs):
|
97
|
+
return self.answer
|
98
|
+
|
99
|
+
# Define the schema for the LLM's output
|
100
|
+
class NextStep(pg.Object):
|
101
|
+
step_by_step_thoughts: list[str]
|
102
|
+
next_action: Calculate | Search | DirectAnswer
|
103
|
+
|
104
|
+
# Query the LLM to determine the next step
|
105
|
+
next_step = lf.query(
|
106
|
+
'What is the next step for {{question}}?',
|
107
|
+
NextStep,
|
108
|
+
question='why is the sky blue?'
|
109
|
+
)
|
110
|
+
# Execute the action chosen by the LLM
|
111
|
+
result = next_step.next_action()
|
112
|
+
print(result)
|
113
|
+
```
|
114
|
+
|
115
|
+
## Invoking Actions and Managing Sessions:
|
116
|
+
|
117
|
+
When an action is called, the session argument (the first argument to call)
|
118
|
+
is handled as follows:
|
119
|
+
|
120
|
+
- Implicit Session Management: If no session is explicitly provided when
|
121
|
+
calling an action, Langfun automatically creates and passes one.
|
122
|
+
|
123
|
+
```python
|
124
|
+
calc = Calculate(expression='1 + 1')
|
125
|
+
|
126
|
+
# A session is implicitly created and passed here.
|
127
|
+
result = calc()
|
128
|
+
print(result)
|
129
|
+
|
130
|
+
# Access the implicitly created session.
|
131
|
+
# print(calc.session)
|
132
|
+
```
|
133
|
+
|
134
|
+
- Explicit Session Management: You can create and manage `lf.Session` objects
|
135
|
+
explicitly. This is useful for customizing session identifiers or managing
|
136
|
+
a shared context for multiple actions.
|
137
|
+
|
138
|
+
```python
|
139
|
+
calc = Calculate(expression='1 + 1')
|
140
|
+
|
141
|
+
# Explicitly create and pass a session.
|
142
|
+
with lf.Session(id='my_agent_session') as session:
|
143
|
+
result = calc(session=session) # Pass the session explicitly
|
144
|
+
print(result)
|
145
|
+
```
|
146
|
+
|
147
|
+
## Accessing Execution Trajectory:
|
148
|
+
|
149
|
+
After an action is executed, the Session object holds a record of its
|
150
|
+
execution, known as the trajectory. This includes queries made and any
|
151
|
+
sub-actions performed.
|
152
|
+
|
153
|
+
- To access all queries issued directly by the root action:
|
154
|
+
|
155
|
+
```python
|
156
|
+
print(session.root.execution.queries)
|
157
|
+
```
|
158
|
+
- To access all queries issued by the root action and any of its
|
159
|
+
sub-actions (recursively):
|
160
|
+
|
161
|
+
```python
|
162
|
+
print(session.root.execution.all_queries)
|
163
|
+
```
|
164
|
+
- To access all child actions issued by the root action:
|
165
|
+
|
166
|
+
```python
|
167
|
+
print(session.root.execution.actions)
|
168
|
+
```
|
169
|
+
|
170
|
+
- To access all the actions in the sub-tree issued by the root action:
|
171
|
+
|
172
|
+
```python
|
173
|
+
print(session.root.execution.all_actions)
|
174
|
+
```
|
175
|
+
"""
|
31
176
|
|
32
177
|
def _on_bound(self):
|
33
178
|
super()._on_bound()
|
@@ -60,6 +205,8 @@ class Action(pg.Object):
|
|
60
205
|
"""Executes the action."""
|
61
206
|
if session is None:
|
62
207
|
session = Session()
|
208
|
+
session.start()
|
209
|
+
|
63
210
|
if show_progress:
|
64
211
|
lf.console.display(pg.view(session, name='agent_session'))
|
65
212
|
|
@@ -107,8 +254,14 @@ class Action(pg.Object):
|
|
107
254
|
action=self,
|
108
255
|
error=error
|
109
256
|
)
|
257
|
+
if self._session is not None:
|
258
|
+
self._session.end(result=None, error=error)
|
110
259
|
raise
|
111
|
-
|
260
|
+
|
261
|
+
if self._session is not None:
|
262
|
+
# Session is created by current action. Stop the session.
|
263
|
+
self._session.end(result)
|
264
|
+
return result
|
112
265
|
|
113
266
|
@abc.abstractmethod
|
114
267
|
def call(self, session: 'Session', **kwargs) -> Any:
|
@@ -229,9 +382,6 @@ class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
229
382
|
remove_class=['running'],
|
230
383
|
)
|
231
384
|
|
232
|
-
def __len__(self) -> int:
|
233
|
-
return len(self.items)
|
234
|
-
|
235
385
|
@property
|
236
386
|
def has_started(self) -> bool:
|
237
387
|
return self.start_time is not None
|
@@ -306,6 +456,22 @@ class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
306
456
|
for x in branch._iter_subtree(item_cls): # pylint: disable=protected-access
|
307
457
|
yield x
|
308
458
|
|
459
|
+
#
|
460
|
+
# Shortcut methods to operate on the execution trace.
|
461
|
+
#
|
462
|
+
|
463
|
+
def __len__(self) -> int:
|
464
|
+
return len(self.items)
|
465
|
+
|
466
|
+
def __iter__(self) -> Iterator[TracedItem]:
|
467
|
+
return iter(self.items)
|
468
|
+
|
469
|
+
def __bool__(self) -> bool:
|
470
|
+
return bool(self.items)
|
471
|
+
|
472
|
+
def __getitem__(self, index: int) -> TracedItem:
|
473
|
+
return self.items[index]
|
474
|
+
|
309
475
|
def append(self, item: TracedItem) -> None:
|
310
476
|
"""Appends an item to the sequence."""
|
311
477
|
with pg.notify_on_change(False):
|
@@ -935,6 +1101,44 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
935
1101
|
skip_notification=True
|
936
1102
|
)
|
937
1103
|
|
1104
|
+
def start(self) -> None:
|
1105
|
+
"""Starts the session."""
|
1106
|
+
self.root.execution.start()
|
1107
|
+
|
1108
|
+
def end(
|
1109
|
+
self,
|
1110
|
+
result: Any,
|
1111
|
+
error: pg.utils.ErrorInfo | None = None,
|
1112
|
+
metadata: dict[str, Any] | None = None,
|
1113
|
+
) -> None:
|
1114
|
+
"""Ends the session."""
|
1115
|
+
self.root.end(result, error, metadata)
|
1116
|
+
|
1117
|
+
def __enter__(self):
|
1118
|
+
"""Enters the session."""
|
1119
|
+
self.start()
|
1120
|
+
return self
|
1121
|
+
|
1122
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
1123
|
+
"""Exits the session."""
|
1124
|
+
# We allow users to explicitly end the session with specified result
|
1125
|
+
# and metadata.
|
1126
|
+
if self.root.execution.has_stopped:
|
1127
|
+
return
|
1128
|
+
|
1129
|
+
if exc_val is not None:
|
1130
|
+
result, metadata = None, None
|
1131
|
+
error = pg.utils.ErrorInfo.from_exception(exc_val)
|
1132
|
+
else:
|
1133
|
+
actions = self.root.actions
|
1134
|
+
if actions:
|
1135
|
+
result = actions[-1].result
|
1136
|
+
error = actions[-1].error
|
1137
|
+
metadata = actions[-1].metadata
|
1138
|
+
else:
|
1139
|
+
result, error, metadata = None, None, None
|
1140
|
+
self.end(result, error, metadata)
|
1141
|
+
|
938
1142
|
#
|
939
1143
|
# Context-manager for information tracking.
|
940
1144
|
#
|
@@ -942,8 +1146,12 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
942
1146
|
@contextlib.contextmanager
|
943
1147
|
def track_action(self, action: Action) -> Iterator[ActionInvocation]:
|
944
1148
|
"""Track the execution of an action."""
|
945
|
-
if not self.
|
946
|
-
|
1149
|
+
if not self.root.execution.has_started:
|
1150
|
+
raise ValueError(
|
1151
|
+
'Please call `Session.start() / Session.end()` explicitly, '
|
1152
|
+
'or use `with Session(...) as session: ...` context manager to '
|
1153
|
+
'signal the start and end of the session.'
|
1154
|
+
)
|
947
1155
|
|
948
1156
|
invocation = ActionInvocation(pg.maybe_ref(action))
|
949
1157
|
action._invocation = invocation # pylint: disable=protected-access
|
@@ -960,12 +1168,6 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
960
1168
|
finally:
|
961
1169
|
self._current_execution = parent_execution
|
962
1170
|
self._current_action = parent_action
|
963
|
-
if parent_action is self.root:
|
964
|
-
parent_action.end(
|
965
|
-
result=invocation.result,
|
966
|
-
metadata=invocation.metadata,
|
967
|
-
error=invocation.error
|
968
|
-
)
|
969
1171
|
|
970
1172
|
@contextlib.contextmanager
|
971
1173
|
def track_phase(self, name: str | None) -> Iterator[ExecutionTrace]:
|
@@ -1255,6 +1457,21 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
1255
1457
|
"""Returns the final result of the session."""
|
1256
1458
|
return self.root.result
|
1257
1459
|
|
1460
|
+
@property
|
1461
|
+
def has_started(self) -> bool:
|
1462
|
+
"""Returns whether the session has started."""
|
1463
|
+
return self.root.execution.has_started
|
1464
|
+
|
1465
|
+
@property
|
1466
|
+
def has_stopped(self) -> bool:
|
1467
|
+
"""Returns whether the session has stopped."""
|
1468
|
+
return self.root.execution.has_stopped
|
1469
|
+
|
1470
|
+
@property
|
1471
|
+
def has_error(self) -> bool:
|
1472
|
+
"""Returns whether the session has an error."""
|
1473
|
+
return self.root.has_error
|
1474
|
+
|
1258
1475
|
@property
|
1259
1476
|
def current_action(self) -> ActionInvocation:
|
1260
1477
|
"""Returns the current invocation."""
|
@@ -34,17 +34,20 @@ class ActionEval(lf.eval.v2.Evaluation):
|
|
34
34
|
def process(self, example: lf.eval.v2.Example) -> tuple[str, dict[str, Any]]:
|
35
35
|
example_input = example.input
|
36
36
|
action = example_input.action
|
37
|
-
session = action_lib.Session(id=f'{self.id}#example-{example.id}')
|
38
37
|
|
39
|
-
#
|
40
|
-
|
41
|
-
|
42
|
-
|
38
|
+
# We explicitly create a session here to use a custom session ID.
|
39
|
+
with action_lib.Session(id=f'{self.id}#example-{example.id}') as session:
|
40
|
+
|
41
|
+
# NOTE(daiyip): Setting session as metadata before action execution, so we
|
42
|
+
# could use `Evaluation.state.in_progress_examples` to access the session
|
43
|
+
# for status reporting from other threads.
|
44
|
+
example.metadata['session'] = session
|
45
|
+
|
46
|
+
with lf.logging.use_log_level('fatal'):
|
47
|
+
kwargs = self.action_args.copy()
|
48
|
+
kwargs.update(verbose=True)
|
49
|
+
action(session=session, **kwargs)
|
43
50
|
|
44
|
-
with lf.logging.use_log_level('fatal'):
|
45
|
-
kwargs = self.action_args.copy()
|
46
|
-
kwargs.update(verbose=True)
|
47
|
-
action(session=session, **kwargs)
|
48
51
|
return session.final_result, dict(session=session)
|
49
52
|
|
50
53
|
#
|
@@ -76,8 +79,9 @@ class ActionEvalV1(lf_eval.Matching):
|
|
76
79
|
|
77
80
|
def process(self, example: pg.Dict, **kwargs):
|
78
81
|
action = example.action
|
79
|
-
|
80
|
-
|
82
|
+
with action_lib.Session(
|
83
|
+
id=str(getattr(example, 'id', '<empty>'))) as session:
|
84
|
+
action(session=session, lm=self.lm, **kwargs)
|
81
85
|
return session.as_message()
|
82
86
|
|
83
87
|
def answer(self, output: Any, example: pg.Dict) -> Any:
|
@@ -98,7 +98,7 @@ class ExecutionTraceTest(unittest.TestCase):
|
|
98
98
|
self.assertEqual(action_invocation.execution.id, '/a1')
|
99
99
|
|
100
100
|
root.execution.reset()
|
101
|
-
self.
|
101
|
+
self.assertFalse(root.execution)
|
102
102
|
|
103
103
|
|
104
104
|
class SessionTest(unittest.TestCase):
|
@@ -112,12 +112,18 @@ class SessionTest(unittest.TestCase):
|
|
112
112
|
|
113
113
|
session = action_lib.Session(id='agent@1')
|
114
114
|
self.assertEqual(session.id, 'agent@1')
|
115
|
+
self.assertFalse(session.has_started)
|
116
|
+
self.assertFalse(session.has_stopped)
|
115
117
|
|
116
118
|
# Render HTML view to trigger dynamic update during execution.
|
117
119
|
_ = session.to_html()
|
118
120
|
|
119
|
-
|
121
|
+
with session:
|
122
|
+
result = foo(session, lm=lm, verbose=True)
|
120
123
|
|
124
|
+
self.assertTrue(session.has_started)
|
125
|
+
self.assertTrue(session.has_stopped)
|
126
|
+
self.assertEqual(result, 3)
|
121
127
|
self.assertIsNone(foo.session)
|
122
128
|
self.assertEqual(foo.result, 3)
|
123
129
|
self.assertEqual(
|
@@ -135,8 +141,8 @@ class SessionTest(unittest.TestCase):
|
|
135
141
|
self.assertIsNone(root.parent_action)
|
136
142
|
self.assertEqual(root.id, 'agent@1:')
|
137
143
|
self.assertEqual(root.execution.id, 'agent@1:')
|
138
|
-
self.assertEqual(len(root.execution
|
139
|
-
self.assertIs(root.execution
|
144
|
+
self.assertEqual(len(root.execution), 1)
|
145
|
+
self.assertIs(root.execution[0].action, foo)
|
140
146
|
|
141
147
|
self.assertTrue(root.execution.has_started)
|
142
148
|
self.assertTrue(root.execution.has_stopped)
|
@@ -160,14 +166,14 @@ class SessionTest(unittest.TestCase):
|
|
160
166
|
self.assertEqual(root.usage_summary.total.num_requests, 6)
|
161
167
|
|
162
168
|
# Inspecting the top-level action (Foo)
|
163
|
-
foo_invocation = root.execution
|
169
|
+
foo_invocation = root.execution[0]
|
164
170
|
self.assertIs(foo_invocation.parent_action, root)
|
165
171
|
self.assertEqual(foo_invocation.id, 'agent@1:/a1')
|
166
172
|
self.assertEqual(foo_invocation.execution.id, 'agent@1:/a1')
|
167
173
|
self.assertEqual(len(foo_invocation.execution.items), 4)
|
168
174
|
|
169
175
|
# Prepare phase.
|
170
|
-
prepare_phase = foo_invocation.execution
|
176
|
+
prepare_phase = foo_invocation.execution[0]
|
171
177
|
self.assertIsInstance(prepare_phase, action_lib.ExecutionTrace)
|
172
178
|
self.assertEqual(prepare_phase.id, 'agent@1:/a1/prepare')
|
173
179
|
self.assertEqual(len(prepare_phase.items), 2)
|
@@ -179,7 +185,7 @@ class SessionTest(unittest.TestCase):
|
|
179
185
|
self.assertEqual(prepare_phase.items[1].id, 'agent@1:/a1/prepare/q1')
|
180
186
|
|
181
187
|
# Tracked queries.
|
182
|
-
query_invocation = foo_invocation.execution
|
188
|
+
query_invocation = foo_invocation.execution[1]
|
183
189
|
self.assertIsInstance(query_invocation, lf_structured.QueryInvocation)
|
184
190
|
self.assertEqual(query_invocation.id, 'agent@1:/a1/q2')
|
185
191
|
self.assertIs(query_invocation.lm, lm)
|
@@ -197,7 +203,7 @@ class SessionTest(unittest.TestCase):
|
|
197
203
|
)
|
198
204
|
|
199
205
|
# Tracked parallel executions.
|
200
|
-
parallel_executions = foo_invocation.execution
|
206
|
+
parallel_executions = foo_invocation.execution[2]
|
201
207
|
self.assertEqual(parallel_executions.id, 'agent@1:/a1/p1')
|
202
208
|
self.assertIsInstance(parallel_executions, action_lib.ParallelExecutions)
|
203
209
|
self.assertEqual(len(parallel_executions), 3)
|
@@ -209,7 +215,7 @@ class SessionTest(unittest.TestCase):
|
|
209
215
|
self.assertEqual(len(parallel_executions[2].queries), 1)
|
210
216
|
|
211
217
|
# Invocation to Bar.
|
212
|
-
bar_invocation = foo_invocation.execution
|
218
|
+
bar_invocation = foo_invocation.execution[3]
|
213
219
|
self.assertIs(bar_invocation.parent_action, foo_invocation)
|
214
220
|
self.assertEqual(bar_invocation.id, 'agent@1:/a1/a1')
|
215
221
|
self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
|
@@ -240,10 +246,10 @@ class SessionTest(unittest.TestCase):
|
|
240
246
|
root = session.root
|
241
247
|
self.assertRegex(root.id, 'agent@.*:')
|
242
248
|
self.assertTrue(root.has_error)
|
243
|
-
foo_invocation = root.execution
|
249
|
+
foo_invocation = root.execution[0]
|
244
250
|
self.assertIsInstance(foo_invocation, action_lib.ActionInvocation)
|
245
251
|
self.assertTrue(foo_invocation.has_error)
|
246
|
-
bar_invocation = foo_invocation.execution
|
252
|
+
bar_invocation = foo_invocation.execution[3]
|
247
253
|
self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
|
248
254
|
self.assertTrue(bar_invocation.has_error)
|
249
255
|
|
@@ -265,11 +271,146 @@ class SessionTest(unittest.TestCase):
|
|
265
271
|
root = session.root
|
266
272
|
self.assertRegex(root.id, 'agent@.*:')
|
267
273
|
self.assertTrue(root.has_error)
|
268
|
-
foo_invocation = root.execution
|
274
|
+
foo_invocation = root.execution[0]
|
269
275
|
self.assertIsInstance(foo_invocation, action_lib.ActionInvocation)
|
270
276
|
self.assertTrue(foo_invocation.has_error)
|
271
277
|
self.assertEqual(len(foo_invocation.execution.items), 2)
|
272
278
|
|
279
|
+
def test_succeeded_with_implicit_session(self):
|
280
|
+
lm = fake.StaticResponse('lm response')
|
281
|
+
foo = Foo(1)
|
282
|
+
foo(lm=lm, verbose=True)
|
283
|
+
session = foo.session
|
284
|
+
self.assertIsNotNone(session)
|
285
|
+
self.assertIsInstance(session.root.action, action_lib.RootAction)
|
286
|
+
self.assertIs(session.current_action, session.root)
|
287
|
+
self.assertTrue(session.has_started)
|
288
|
+
self.assertTrue(session.has_stopped)
|
289
|
+
self.assertEqual(session.final_result, 3)
|
290
|
+
self.assertFalse(session.root.has_error)
|
291
|
+
self.assertEqual(session.root.metadata, {})
|
292
|
+
|
293
|
+
def test_failed_with_implicit_session(self):
|
294
|
+
lm = fake.StaticResponse('lm response')
|
295
|
+
foo = Foo(1, simulate_action_error=True)
|
296
|
+
with self.assertRaisesRegex(ValueError, 'Bar error'):
|
297
|
+
foo(lm=lm)
|
298
|
+
session = foo.session
|
299
|
+
self.assertIsNotNone(session)
|
300
|
+
self.assertIsInstance(session.root.action, action_lib.RootAction)
|
301
|
+
self.assertIs(session.current_action, session.root)
|
302
|
+
self.assertTrue(session.has_started)
|
303
|
+
self.assertTrue(session.has_stopped)
|
304
|
+
self.assertTrue(session.has_error)
|
305
|
+
self.assertIsInstance(session.root.error, pg.utils.ErrorInfo)
|
306
|
+
self.assertIn('Bar error', str(session.root.error))
|
307
|
+
|
308
|
+
def test_succeeded_with_explicit_session(self):
|
309
|
+
lm = fake.StaticResponse('lm response')
|
310
|
+
foo = Foo(1)
|
311
|
+
self.assertIsNone(foo.session)
|
312
|
+
self.assertIsNone(foo.result)
|
313
|
+
self.assertIsNone(foo.metadata)
|
314
|
+
|
315
|
+
session = action_lib.Session(id='agent@1')
|
316
|
+
self.assertEqual(session.id, 'agent@1')
|
317
|
+
self.assertFalse(session.has_started)
|
318
|
+
self.assertFalse(session.has_stopped)
|
319
|
+
|
320
|
+
with session:
|
321
|
+
result = foo(session, lm=lm, verbose=True)
|
322
|
+
|
323
|
+
self.assertTrue(session.has_started)
|
324
|
+
self.assertTrue(session.has_stopped)
|
325
|
+
self.assertEqual(result, 3)
|
326
|
+
self.assertIsNone(foo.session)
|
327
|
+
self.assertEqual(foo.result, 3)
|
328
|
+
self.assertEqual(
|
329
|
+
foo.metadata, dict(note='foo', subtask_0=0, subtask_1=1, subtask_2=2)
|
330
|
+
)
|
331
|
+
self.assertIs(session.final_result, foo.result)
|
332
|
+
self.assertFalse(session.has_error)
|
333
|
+
|
334
|
+
def test_succeeded_with_explicit_session_start_end(self):
|
335
|
+
lm = fake.StaticResponse('lm response')
|
336
|
+
foo = Foo(1)
|
337
|
+
self.assertIsNone(foo.session)
|
338
|
+
self.assertIsNone(foo.result)
|
339
|
+
self.assertIsNone(foo.metadata)
|
340
|
+
|
341
|
+
session = action_lib.Session(id='agent@1')
|
342
|
+
self.assertEqual(session.id, 'agent@1')
|
343
|
+
self.assertFalse(session.has_started)
|
344
|
+
self.assertFalse(session.has_stopped)
|
345
|
+
|
346
|
+
session.start()
|
347
|
+
result = foo(session, lm=lm, verbose=True)
|
348
|
+
session.end(result)
|
349
|
+
|
350
|
+
self.assertTrue(session.has_started)
|
351
|
+
self.assertTrue(session.has_stopped)
|
352
|
+
self.assertEqual(result, 3)
|
353
|
+
self.assertIsNone(foo.session)
|
354
|
+
self.assertEqual(foo.result, 3)
|
355
|
+
self.assertEqual(
|
356
|
+
foo.metadata, dict(note='foo', subtask_0=0, subtask_1=1, subtask_2=2)
|
357
|
+
)
|
358
|
+
self.assertIs(session.final_result, foo.result)
|
359
|
+
self.assertFalse(session.has_error)
|
360
|
+
|
361
|
+
def test_failed_with_explicit_session(self):
|
362
|
+
lm = fake.StaticResponse('lm response')
|
363
|
+
foo = Foo(1, simulate_action_error=True)
|
364
|
+
session = action_lib.Session(id='agent@1')
|
365
|
+
with self.assertRaisesRegex(ValueError, 'Bar error'):
|
366
|
+
with session:
|
367
|
+
foo(session, lm=lm, verbose=True)
|
368
|
+
self.assertTrue(session.has_started)
|
369
|
+
self.assertTrue(session.has_stopped)
|
370
|
+
self.assertTrue(session.has_error)
|
371
|
+
self.assertIsNone(session.final_result)
|
372
|
+
self.assertIsInstance(session.root.error, pg.utils.ErrorInfo)
|
373
|
+
self.assertIn('Bar error', str(session.root.error))
|
374
|
+
|
375
|
+
def test_failed_with_explicit_session_without_start(self):
|
376
|
+
lm = fake.StaticResponse('lm response')
|
377
|
+
foo = Foo(1, simulate_action_error=True)
|
378
|
+
session = action_lib.Session(id='agent@1')
|
379
|
+
with self.assertRaisesRegex(ValueError, 'Please call `Session.start'):
|
380
|
+
foo(session, lm=lm, verbose=True)
|
381
|
+
|
382
|
+
def test_succeed_with_multiple_actions(self):
|
383
|
+
lm = fake.StaticResponse('lm response')
|
384
|
+
with action_lib.Session() as session:
|
385
|
+
x = Bar()(session, lm=lm)
|
386
|
+
y = Bar()(session, lm=lm)
|
387
|
+
self.assertTrue(session.has_started)
|
388
|
+
self.assertFalse(session.has_stopped)
|
389
|
+
session.add_metadata(note='root metadata')
|
390
|
+
session.end(x + y)
|
391
|
+
|
392
|
+
self.assertTrue(session.has_started)
|
393
|
+
self.assertTrue(session.has_stopped)
|
394
|
+
self.assertEqual(session.final_result, 2 + 2)
|
395
|
+
self.assertEqual(len(session.root.execution), 2)
|
396
|
+
self.assertEqual(session.root.metadata, dict(note='root metadata'))
|
397
|
+
|
398
|
+
def test_failed_with_multiple_actions(self):
|
399
|
+
lm = fake.StaticResponse('lm response')
|
400
|
+
with self.assertRaisesRegex(ValueError, 'Bar error'):
|
401
|
+
with action_lib.Session() as session:
|
402
|
+
x = Bar()(session, lm=lm)
|
403
|
+
y = Bar(simulate_action_error=True)(session, lm=lm)
|
404
|
+
session.end(x + y)
|
405
|
+
|
406
|
+
self.assertTrue(session.has_started)
|
407
|
+
self.assertTrue(session.has_stopped)
|
408
|
+
self.assertTrue(session.has_error)
|
409
|
+
self.assertIsInstance(session.root.error, pg.utils.ErrorInfo)
|
410
|
+
self.assertEqual(len(session.root.execution), 2)
|
411
|
+
self.assertFalse(session.root.execution[0].has_error)
|
412
|
+
self.assertTrue(session.root.execution[1].has_error)
|
413
|
+
|
273
414
|
def test_log(self):
|
274
415
|
session = action_lib.Session()
|
275
416
|
session.debug('hi', x=1, y=2)
|
langfun/core/llms/gemini.py
CHANGED
@@ -605,13 +605,13 @@ class Gemini(rest.REST):
|
|
605
605
|
raise lf.ModalityError(f'Unsupported modality: {chunk!r}') from e
|
606
606
|
return chunk
|
607
607
|
|
608
|
-
contents = []
|
609
608
|
if system_message := prompt.get('system_message'):
|
610
609
|
assert isinstance(system_message, lf.SystemMessage), type(system_message)
|
611
|
-
|
612
|
-
|
613
|
-
'gemini', chunk_preprocessor=modality_conversion)
|
610
|
+
request['systemInstruction'] = system_message.as_format(
|
611
|
+
'gemini', chunk_preprocessor=modality_conversion
|
614
612
|
)
|
613
|
+
|
614
|
+
contents = []
|
615
615
|
contents.append(
|
616
616
|
prompt.as_format('gemini', chunk_preprocessor=modality_conversion)
|
617
617
|
)
|
@@ -647,6 +647,11 @@ class Gemini(rest.REST):
|
|
647
647
|
+ '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
|
648
648
|
+ pg.to_json_str(json_schema, json_indent=2)
|
649
649
|
)
|
650
|
+
if options.max_thinking_tokens is not None:
|
651
|
+
config['thinkingConfig'] = {
|
652
|
+
'thinkingBudget': options.max_thinking_tokens
|
653
|
+
}
|
654
|
+
|
650
655
|
return config
|
651
656
|
|
652
657
|
def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
|
@@ -659,18 +664,25 @@ class Gemini(rest.REST):
|
|
659
664
|
# NOTE(daiyip): We saw cases that `candidatesTokenCount` is not present.
|
660
665
|
# Therefore, we use 0 as the default value.
|
661
666
|
output_tokens = usage.get('candidatesTokenCount', 0)
|
667
|
+
thinking_tokens = usage.get('thoughtsTokenCount', 0)
|
668
|
+
total_tokens = usage.get('totalTokenCount', 0)
|
662
669
|
|
663
670
|
return lf.LMSamplingResult(
|
664
671
|
[lf.LMSample(message) for message in messages],
|
665
672
|
usage=lf.LMSamplingUsage(
|
666
673
|
prompt_tokens=input_tokens,
|
667
674
|
completion_tokens=output_tokens,
|
668
|
-
total_tokens=
|
675
|
+
total_tokens=total_tokens,
|
676
|
+
completion_tokens_details={
|
677
|
+
'thinking_tokens': thinking_tokens,
|
678
|
+
},
|
669
679
|
),
|
670
680
|
)
|
671
681
|
|
672
682
|
def _error(self, status_code: int, content: str) -> lf.LMError:
|
673
|
-
if (
|
674
|
-
|
683
|
+
if (
|
684
|
+
status_code == 400
|
685
|
+
and b'exceeds the maximum number of tokens' in content
|
686
|
+
):
|
675
687
|
return lf.ContextLimitError(f'{status_code}: {content}')
|
676
688
|
return super()._error(status_code, content)
|
langfun/core/llms/gemini_test.py
CHANGED
@@ -38,14 +38,21 @@ example_image = (
|
|
38
38
|
def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
|
39
39
|
del url, kwargs
|
40
40
|
c = pg.Dict(json['generationConfig'])
|
41
|
-
|
41
|
+
parts = []
|
42
|
+
if system_instruction := json.get('systemInstruction'):
|
43
|
+
parts.extend([p['text'] for p in system_instruction.get('parts', [])])
|
44
|
+
|
45
|
+
# Add text from the main contents.
|
46
|
+
for c_item in json.get('contents', []):
|
47
|
+
for p in c_item.get('parts', []):
|
48
|
+
parts.append(p['text'])
|
49
|
+
content = '\n'.join(parts)
|
42
50
|
response = requests.Response()
|
43
51
|
response.status_code = 200
|
44
52
|
response._content = pg.to_json_str({
|
45
53
|
'candidates': [
|
46
54
|
{
|
47
55
|
'content': {
|
48
|
-
'role': 'model',
|
49
56
|
'parts': [
|
50
57
|
{
|
51
58
|
'text': (
|
@@ -146,6 +153,30 @@ class GeminiTest(unittest.TestCase):
|
|
146
153
|
}
|
147
154
|
),
|
148
155
|
)
|
156
|
+
|
157
|
+
# Add test for thinkingConfig.
|
158
|
+
actual = model._generation_config(
|
159
|
+
lf.UserMessage('hi'),
|
160
|
+
lf.LMSamplingOptions(
|
161
|
+
max_thinking_tokens=100,
|
162
|
+
),
|
163
|
+
)
|
164
|
+
self.assertEqual(
|
165
|
+
actual,
|
166
|
+
dict(
|
167
|
+
candidateCount=1,
|
168
|
+
temperature=None,
|
169
|
+
topP=None,
|
170
|
+
topK=40,
|
171
|
+
maxOutputTokens=None,
|
172
|
+
stopSequences=None,
|
173
|
+
responseLogprobs=False,
|
174
|
+
logprobs=None,
|
175
|
+
seed=None,
|
176
|
+
thinkingConfig={'thinkingBudget': 100},
|
177
|
+
),
|
178
|
+
)
|
179
|
+
|
149
180
|
with self.assertRaisesRegex(
|
150
181
|
ValueError, '`json_schema` must be a dict, got'
|
151
182
|
):
|
langfun/core/llms/vertexai.py
CHANGED
@@ -28,6 +28,7 @@ import pyglove as pg
|
|
28
28
|
try:
|
29
29
|
# pylint: disable=g-import-not-at-top
|
30
30
|
from google import auth as google_auth
|
31
|
+
from google.auth import exceptions as auth_exceptions
|
31
32
|
from google.auth import credentials as credentials_lib
|
32
33
|
from google.auth.transport import requests as auth_requests
|
33
34
|
# pylint: enable=g-import-not-at-top
|
@@ -35,6 +36,7 @@ try:
|
|
35
36
|
Credentials = credentials_lib.Credentials
|
36
37
|
except ImportError:
|
37
38
|
google_auth = None
|
39
|
+
auth_exceptions = None
|
38
40
|
credentials_lib = None
|
39
41
|
auth_requests = None
|
40
42
|
Credentials = Any
|
@@ -134,6 +136,16 @@ class VertexAI(rest.REST):
|
|
134
136
|
assert auth_requests is not None
|
135
137
|
return auth_requests.AuthorizedSession(self._credentials)
|
136
138
|
|
139
|
+
def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
|
140
|
+
assert auth_exceptions is not None
|
141
|
+
try:
|
142
|
+
return super()._sample_single(prompt)
|
143
|
+
except (
|
144
|
+
auth_exceptions.RefreshError,
|
145
|
+
) as e:
|
146
|
+
raise lf.TemporaryLMError(
|
147
|
+
f'Failed to refresh Google authentication credentials: {e}'
|
148
|
+
) from e
|
137
149
|
|
138
150
|
#
|
139
151
|
# Gemini models served by Vertex AI.
|
@@ -19,6 +19,7 @@ from unittest import mock
|
|
19
19
|
|
20
20
|
from google.auth import exceptions
|
21
21
|
import langfun.core as lf
|
22
|
+
from langfun.core.llms import rest
|
22
23
|
from langfun.core.llms import vertexai
|
23
24
|
import pyglove as pg
|
24
25
|
|
@@ -51,6 +52,22 @@ class VertexAITest(unittest.TestCase):
|
|
51
52
|
del os.environ['VERTEXAI_PROJECT']
|
52
53
|
del os.environ['VERTEXAI_LOCATION']
|
53
54
|
|
55
|
+
def test_auth_refresh_error(self):
|
56
|
+
def _auth_refresh_error(*args, **kwargs):
|
57
|
+
del args, kwargs
|
58
|
+
raise exceptions.RefreshError('Cannot refresh token')
|
59
|
+
|
60
|
+
with self.assertRaisesRegex(
|
61
|
+
lf.concurrent.RetryError,
|
62
|
+
'Failed to refresh Google authentication credentials'
|
63
|
+
):
|
64
|
+
with mock.patch.object(rest.REST, '_sample_single') as mock_sample_single:
|
65
|
+
mock_sample_single.side_effect = _auth_refresh_error
|
66
|
+
model = vertexai.VertexAIGemini15Pro(
|
67
|
+
project='abc', location='us-central1', max_attempts=1
|
68
|
+
)
|
69
|
+
model('hi')
|
70
|
+
|
54
71
|
|
55
72
|
class VertexAIAnthropicTest(unittest.TestCase):
|
56
73
|
"""Tests for VertexAI Anthropic models."""
|
@@ -26,10 +26,10 @@ langfun/core/subscription_test.py,sha256=Y4ZdbZEwm83YNZBxHff0QR4QUa4rdaNXA3_jfIc
|
|
26
26
|
langfun/core/template.py,sha256=jNhYSrbLIn9kZOa03w5QZbyjgfnzJzE_ZrrMvvWY4t4,24929
|
27
27
|
langfun/core/template_test.py,sha256=AQv_m9qE93WxhEhSlm1xaBgB4hu0UVtA53dljngkUW0,17090
|
28
28
|
langfun/core/agentic/__init__.py,sha256=qR3jlfUO4rhIoYdRDLz-d22YZf3FvU4FW88vsjiGDQQ,1224
|
29
|
-
langfun/core/agentic/action.py,sha256=
|
30
|
-
langfun/core/agentic/action_eval.py,sha256=
|
31
|
-
langfun/core/agentic/action_eval_test.py,sha256=
|
32
|
-
langfun/core/agentic/action_test.py,sha256=
|
29
|
+
langfun/core/agentic/action.py,sha256=3m2-k07Zz8qrzOdOa7xPl7fRH3I0c3VsMLR86_JBCcU,45359
|
30
|
+
langfun/core/agentic/action_eval.py,sha256=JXhS5qEjWu9EZ0chDsjWxCqPAV26PUCBijtUYxiDeO4,4975
|
31
|
+
langfun/core/agentic/action_eval_test.py,sha256=7AkOwNbUX-ZgR1R0a7bvUZ5abNTUV7blf_8Mnrwb-II,2811
|
32
|
+
langfun/core/agentic/action_test.py,sha256=ezqg3tKlVwgLMnHKUmOdtxpnntuL8YIvhcTCSSdb8oc,15468
|
33
33
|
langfun/core/coding/__init__.py,sha256=5utju_fwEsImaiftx4oXKl9FAM8p281k8-Esdh_-m1w,835
|
34
34
|
langfun/core/coding/python/__init__.py,sha256=4ByknuoNU-mOIHwHKnTtmo6oD64oMFtlqPlYWmA5Wic,1736
|
35
35
|
langfun/core/coding/python/correction.py,sha256=7zBedlhQKMPA4cfchUMxAOFl6Zl5RqCyllRHGWys40s,7092
|
@@ -92,8 +92,8 @@ langfun/core/llms/deepseek.py,sha256=jvTxdXPr-vH6HNakn_Ootx1heDg8Fen2FUkUW36bpCs
|
|
92
92
|
langfun/core/llms/deepseek_test.py,sha256=DvROWPlDuow5E1lfoSkhyGt_ELA19JoQoDsTnRgDtTg,1847
|
93
93
|
langfun/core/llms/fake.py,sha256=xmgCkk9y0I4x0IT32SZ9_OT27aLadXH8PRiYNo5VTd4,3265
|
94
94
|
langfun/core/llms/fake_test.py,sha256=2h13qkwEz_JR0mtUDPxdAhQo7MueXaFSwsD2DIRDW9g,7653
|
95
|
-
langfun/core/llms/gemini.py,sha256=
|
96
|
-
langfun/core/llms/gemini_test.py,sha256=
|
95
|
+
langfun/core/llms/gemini.py,sha256=ZtUo2lQMSByYlzSALWae3KxFiKNtOOwGbkFwZTI1dO0,24472
|
96
|
+
langfun/core/llms/gemini_test.py,sha256=Ve9X2Wvwu9wVFHpKZDP-qoM1_hzB4kgt6_HR9wxtNkg,7592
|
97
97
|
langfun/core/llms/google_genai.py,sha256=j8W22WFvkT80Fw-r7Rg-e7MKhcSwljZkmtuufwSEn5s,5051
|
98
98
|
langfun/core/llms/google_genai_test.py,sha256=NKNtpebArQ9ZR7Qsnhd2prFIpMjleojy6o6VMXkJ1zY,1502
|
99
99
|
langfun/core/llms/groq.py,sha256=S9V10kFo3cgX89qPgt_umq-SpRnxEDLTt_hJmpERfbo,12066
|
@@ -106,8 +106,8 @@ langfun/core/llms/openai_compatible_test.py,sha256=KwOMA7tsmOxFBjezltkBDSU77AvOQ
|
|
106
106
|
langfun/core/llms/openai_test.py,sha256=gwuO6aoa296iM2welWV9ua4KF8gEVGsEPakgbtkWkFQ,2687
|
107
107
|
langfun/core/llms/rest.py,sha256=MCybcHApJcf49lubLnDzScN9Oc2IWY_JnMHIGdbDOuU,4474
|
108
108
|
langfun/core/llms/rest_test.py,sha256=_zM7nV8DEVyoXNiQOnuwJ917mWjki0614H88rNmDboE,5020
|
109
|
-
langfun/core/llms/vertexai.py,sha256=
|
110
|
-
langfun/core/llms/vertexai_test.py,sha256=
|
109
|
+
langfun/core/llms/vertexai.py,sha256=4t_Noj7cqzLNmESYCYzz9Ndodd_K4I4zxVLmljJ7r3E,18630
|
110
|
+
langfun/core/llms/vertexai_test.py,sha256=0M4jsPOXGagdzPfEdJixmyLdhmmERePZWSFfTwnaYCQ,4875
|
111
111
|
langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
|
112
112
|
langfun/core/llms/cache/base.py,sha256=rt3zwmyw0y9jsSGW-ZbV1vAfLxQ7_3AVk0l2EySlse4,3918
|
113
113
|
langfun/core/llms/cache/in_memory.py,sha256=i58oiQL28RDsq37dwqgVpC2mBETJjIEFS20yHiV5MKU,5185
|
@@ -156,8 +156,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
156
156
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
157
157
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
158
158
|
langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
|
159
|
-
langfun-0.1.2.
|
160
|
-
langfun-0.1.2.
|
161
|
-
langfun-0.1.2.
|
162
|
-
langfun-0.1.2.
|
163
|
-
langfun-0.1.2.
|
159
|
+
langfun-0.1.2.dev202505040804.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
160
|
+
langfun-0.1.2.dev202505040804.dist-info/METADATA,sha256=5IYkqanozH8dVb3Nr0bH2jmybe-FPKrvx7-nAEPNCeM,8178
|
161
|
+
langfun-0.1.2.dev202505040804.dist-info/WHEEL,sha256=GHB6lJx2juba1wDgXDNlMTyM13ckjBMKf-OnwgKOCtA,91
|
162
|
+
langfun-0.1.2.dev202505040804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
163
|
+
langfun-0.1.2.dev202505040804.dist-info/RECORD,,
|
{langfun-0.1.2.dev202505020804.dist-info → langfun-0.1.2.dev202505040804.dist-info}/licenses/LICENSE
RENAMED
File without changes
|
{langfun-0.1.2.dev202505020804.dist-info → langfun-0.1.2.dev202505040804.dist-info}/top_level.txt
RENAMED
File without changes
|