langfun 0.1.2.dev202505020804__py3-none-any.whl → 0.1.2.dev202505030803__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic. Click here for more details.

@@ -27,7 +27,152 @@ import pyglove as pg
27
27
 
28
28
 
29
29
  class Action(pg.Object):
30
- """Base class for agent actions."""
30
+ """Base class for Langfun's agentic actions.
31
+
32
+ # Developing Actions
33
+
34
+ In Langfun, an `Action` is a class representing a task an agent can execute.
35
+ To define custom actions, subclass `lf.agentic.Action` and implement the
36
+ `call` method, which contains the logic for the action's execution.
37
+
38
+ ```python
39
+ class Calculate(lf.agentic.Action):
40
+ expression: str
41
+
42
+ def call(self, session: lf.Session, *, lm: lf.LanguageModel, **kwargs):
43
+ return session.query(self.expression, float, lm=lm)
44
+ ```
45
+
46
+ Key aspects of the `call` method:
47
+
48
+ - `session` (First Argument): An `lf.Session` object required to make queries,
49
+ perform logging, and add metadata to the action. It also tracks the
50
+ execution of the action and its sub-actions.
51
+
52
+ - Use `session.query(...)` to make calls to a Language Model.
53
+ - Use `session.debug(...)`, `session.info(...)`, `session.warning(...)`,
54
+ and `session.error(...)` for adding logs associated with the
55
+ current action.
56
+ - Use `session.add_metadata(...)` to associate custom metadata with
57
+ the current action.
58
+
59
+ - Keyword Arguments (e.g., lm): Arguments required for the action's execution
60
+ (like a language model) should be defined as keyword arguments.
61
+
62
+ - **kwargs: Include **kwargs to allow:
63
+
64
+ - Users to pass additional arguments to child actions.
65
+ - The action to gracefully handle extra arguments passed by parent actions.
66
+
67
+ # Using Actions
68
+
69
+ ## Creating Action objects
70
+ Action objects can be instantiated in two primary ways:
71
+
72
+ - Direct instantiation by Users:
73
+
74
+ ```
75
+ calculate_action = Calculate(expression='1 + 1')
76
+ ```
77
+
78
+ - Generation by Language Models (LLMs): LLMs can generate Action objects when
79
+ provided with an "action space" (a schema defining possible actions). The
80
+ LLM populates the action's attributes. User code can then invoke the
81
+ generated action.
82
+
83
+ ```python
84
+ import pyglove as pg
85
+ import langfun as lf
86
+
87
+ # Define possible actions for the LLM
88
+ class Search(lf.agentic.Action):
89
+ query: str
90
+ def call(self, session: lf.Session, *, lm: lf.LanguageModel, **kwargs):
91
+ # Placeholder for actual search logic
92
+ return f"Results for: {self.query}"
93
+
94
+ class DirectAnswer(lf.agentic.Action):
95
+ answer: str
96
+ def call(self, session: lf.Session, *, lm: lf.LanguageModel, **kwargs):
97
+ return self.answer
98
+
99
+ # Define the schema for the LLM's output
100
+ class NextStep(pg.Object):
101
+ step_by_step_thoughts: list[str]
102
+ next_action: Calculate | Search | DirectAnswer
103
+
104
+ # Query the LLM to determine the next step
105
+ next_step = lf.query(
106
+ 'What is the next step for {{question}}?',
107
+ NextStep,
108
+ question='why is the sky blue?'
109
+ )
110
+ # Execute the action chosen by the LLM
111
+ result = next_step.next_action()
112
+ print(result)
113
+ ```
114
+
115
+ ## Invoking Actions and Managing Sessions:
116
+
117
+ When an action is called, the session argument (the first argument to call)
118
+ is handled as follows:
119
+
120
+ - Implicit Session Management: If no session is explicitly provided when
121
+ calling an action, Langfun automatically creates and passes one.
122
+
123
+ ```python
124
+ calc = Calculate(expression='1 + 1')
125
+
126
+ # A session is implicitly created and passed here.
127
+ result = calc()
128
+ print(result)
129
+
130
+ # Access the implicitly created session.
131
+ # print(calc.session)
132
+ ```
133
+
134
+ - Explicit Session Management: You can create and manage `lf.Session` objects
135
+ explicitly. This is useful for customizing session identifiers or managing
136
+ a shared context for multiple actions.
137
+
138
+ ```python
139
+ calc = Calculate(expression='1 + 1')
140
+
141
+ # Explicitly create and pass a session.
142
+ with lf.Session(id='my_agent_session') as session:
143
+ result = calc(session=session) # Pass the session explicitly
144
+ print(result)
145
+ ```
146
+
147
+ ## Accessing Execution Trajectory:
148
+
149
+ After an action is executed, the Session object holds a record of its
150
+ execution, known as the trajectory. This includes queries made and any
151
+ sub-actions performed.
152
+
153
+ - To access all queries issued directly by the root action:
154
+
155
+ ```python
156
+ print(session.root.execution.queries)
157
+ ```
158
+ - To access all queries issued by the root action and any of its
159
+ sub-actions (recursively):
160
+
161
+ ```python
162
+ print(session.root.execution.all_queries)
163
+ ```
164
+ - To access all child actions issued by the root action:
165
+
166
+ ```python
167
+ print(session.root.execution.actions)
168
+ ```
169
+
170
+ - To access all the actions in the sub-tree issued by the root action:
171
+
172
+ ```python
173
+ print(session.root.execution.all_actions)
174
+ ```
175
+ """
31
176
 
32
177
  def _on_bound(self):
33
178
  super()._on_bound()
@@ -60,6 +205,8 @@ class Action(pg.Object):
60
205
  """Executes the action."""
61
206
  if session is None:
62
207
  session = Session()
208
+ session.start()
209
+
63
210
  if show_progress:
64
211
  lf.console.display(pg.view(session, name='agent_session'))
65
212
 
@@ -107,8 +254,14 @@ class Action(pg.Object):
107
254
  action=self,
108
255
  error=error
109
256
  )
257
+ if self._session is not None:
258
+ self._session.end(result=None, error=error)
110
259
  raise
111
- return result
260
+
261
+ if self._session is not None:
262
+ # Session is created by current action. Stop the session.
263
+ self._session.end(result)
264
+ return result
112
265
 
113
266
  @abc.abstractmethod
114
267
  def call(self, session: 'Session', **kwargs) -> Any:
@@ -229,9 +382,6 @@ class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
229
382
  remove_class=['running'],
230
383
  )
231
384
 
232
- def __len__(self) -> int:
233
- return len(self.items)
234
-
235
385
  @property
236
386
  def has_started(self) -> bool:
237
387
  return self.start_time is not None
@@ -306,6 +456,22 @@ class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
306
456
  for x in branch._iter_subtree(item_cls): # pylint: disable=protected-access
307
457
  yield x
308
458
 
459
+ #
460
+ # Shortcut methods to operate on the execution trace.
461
+ #
462
+
463
+ def __len__(self) -> int:
464
+ return len(self.items)
465
+
466
+ def __iter__(self) -> Iterator[TracedItem]:
467
+ return iter(self.items)
468
+
469
+ def __bool__(self) -> bool:
470
+ return bool(self.items)
471
+
472
+ def __getitem__(self, index: int) -> TracedItem:
473
+ return self.items[index]
474
+
309
475
  def append(self, item: TracedItem) -> None:
310
476
  """Appends an item to the sequence."""
311
477
  with pg.notify_on_change(False):
@@ -935,6 +1101,44 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
935
1101
  skip_notification=True
936
1102
  )
937
1103
 
1104
+ def start(self) -> None:
1105
+ """Starts the session."""
1106
+ self.root.execution.start()
1107
+
1108
+ def end(
1109
+ self,
1110
+ result: Any,
1111
+ error: pg.utils.ErrorInfo | None = None,
1112
+ metadata: dict[str, Any] | None = None,
1113
+ ) -> None:
1114
+ """Ends the session."""
1115
+ self.root.end(result, error, metadata)
1116
+
1117
+ def __enter__(self):
1118
+ """Enters the session."""
1119
+ self.start()
1120
+ return self
1121
+
1122
+ def __exit__(self, exc_type, exc_val, exc_tb):
1123
+ """Exits the session."""
1124
+ # We allow users to explicitly end the session with specified result
1125
+ # and metadata.
1126
+ if self.root.execution.has_stopped:
1127
+ return
1128
+
1129
+ if exc_val is not None:
1130
+ result, metadata = None, None
1131
+ error = pg.utils.ErrorInfo.from_exception(exc_val)
1132
+ else:
1133
+ actions = self.root.actions
1134
+ if actions:
1135
+ result = actions[-1].result
1136
+ error = actions[-1].error
1137
+ metadata = actions[-1].metadata
1138
+ else:
1139
+ result, error, metadata = None, None, None
1140
+ self.end(result, error, metadata)
1141
+
938
1142
  #
939
1143
  # Context-manager for information tracking.
940
1144
  #
@@ -942,8 +1146,12 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
942
1146
  @contextlib.contextmanager
943
1147
  def track_action(self, action: Action) -> Iterator[ActionInvocation]:
944
1148
  """Track the execution of an action."""
945
- if not self._current_execution.has_started:
946
- self._current_execution.start()
1149
+ if not self.root.execution.has_started:
1150
+ raise ValueError(
1151
+ 'Please call `Session.start() / Session.end()` explicitly, '
1152
+ 'or use `with Session(...) as session: ...` context manager to '
1153
+ 'signal the start and end of the session.'
1154
+ )
947
1155
 
948
1156
  invocation = ActionInvocation(pg.maybe_ref(action))
949
1157
  action._invocation = invocation # pylint: disable=protected-access
@@ -960,12 +1168,6 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
960
1168
  finally:
961
1169
  self._current_execution = parent_execution
962
1170
  self._current_action = parent_action
963
- if parent_action is self.root:
964
- parent_action.end(
965
- result=invocation.result,
966
- metadata=invocation.metadata,
967
- error=invocation.error
968
- )
969
1171
 
970
1172
  @contextlib.contextmanager
971
1173
  def track_phase(self, name: str | None) -> Iterator[ExecutionTrace]:
@@ -1255,6 +1457,21 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1255
1457
  """Returns the final result of the session."""
1256
1458
  return self.root.result
1257
1459
 
1460
+ @property
1461
+ def has_started(self) -> bool:
1462
+ """Returns whether the session has started."""
1463
+ return self.root.execution.has_started
1464
+
1465
+ @property
1466
+ def has_stopped(self) -> bool:
1467
+ """Returns whether the session has stopped."""
1468
+ return self.root.execution.has_stopped
1469
+
1470
+ @property
1471
+ def has_error(self) -> bool:
1472
+ """Returns whether the session has an error."""
1473
+ return self.root.has_error
1474
+
1258
1475
  @property
1259
1476
  def current_action(self) -> ActionInvocation:
1260
1477
  """Returns the current invocation."""
@@ -34,17 +34,20 @@ class ActionEval(lf.eval.v2.Evaluation):
34
34
  def process(self, example: lf.eval.v2.Example) -> tuple[str, dict[str, Any]]:
35
35
  example_input = example.input
36
36
  action = example_input.action
37
- session = action_lib.Session(id=f'{self.id}#example-{example.id}')
38
37
 
39
- # NOTE(daiyip): Setting session as metadata before action execution, so we
40
- # could use `Evaluation.state.in_progress_examples` to access the session
41
- # for status reporting from other threads.
42
- example.metadata['session'] = session
38
+ # We explicitly create a session here to use a custom session ID.
39
+ with action_lib.Session(id=f'{self.id}#example-{example.id}') as session:
40
+
41
+ # NOTE(daiyip): Setting session as metadata before action execution, so we
42
+ # could use `Evaluation.state.in_progress_examples` to access the session
43
+ # for status reporting from other threads.
44
+ example.metadata['session'] = session
45
+
46
+ with lf.logging.use_log_level('fatal'):
47
+ kwargs = self.action_args.copy()
48
+ kwargs.update(verbose=True)
49
+ action(session=session, **kwargs)
43
50
 
44
- with lf.logging.use_log_level('fatal'):
45
- kwargs = self.action_args.copy()
46
- kwargs.update(verbose=True)
47
- action(session=session, **kwargs)
48
51
  return session.final_result, dict(session=session)
49
52
 
50
53
  #
@@ -76,8 +79,9 @@ class ActionEvalV1(lf_eval.Matching):
76
79
 
77
80
  def process(self, example: pg.Dict, **kwargs):
78
81
  action = example.action
79
- session = action_lib.Session(id=str(getattr(example, 'id', '<empty>')))
80
- action(session=session, lm=self.lm, **kwargs)
82
+ with action_lib.Session(
83
+ id=str(getattr(example, 'id', '<empty>'))) as session:
84
+ action(session=session, lm=self.lm, **kwargs)
81
85
  return session.as_message()
82
86
 
83
87
  def answer(self, output: Any, example: pg.Dict) -> Any:
@@ -68,7 +68,6 @@ class ActionEvalV1Test(unittest.TestCase):
68
68
 
69
69
  s = FooEval()
70
70
  result = s.run(summary=False)
71
- pg.print(result)
72
71
  self.assertEqual(
73
72
  result,
74
73
  dict(
@@ -98,7 +98,7 @@ class ExecutionTraceTest(unittest.TestCase):
98
98
  self.assertEqual(action_invocation.execution.id, '/a1')
99
99
 
100
100
  root.execution.reset()
101
- self.assertEqual(len(root.execution), 0)
101
+ self.assertFalse(root.execution)
102
102
 
103
103
 
104
104
  class SessionTest(unittest.TestCase):
@@ -112,12 +112,18 @@ class SessionTest(unittest.TestCase):
112
112
 
113
113
  session = action_lib.Session(id='agent@1')
114
114
  self.assertEqual(session.id, 'agent@1')
115
+ self.assertFalse(session.has_started)
116
+ self.assertFalse(session.has_stopped)
115
117
 
116
118
  # Render HTML view to trigger dynamic update during execution.
117
119
  _ = session.to_html()
118
120
 
119
- self.assertEqual(foo(session, lm=lm, verbose=True), 3)
121
+ with session:
122
+ result = foo(session, lm=lm, verbose=True)
120
123
 
124
+ self.assertTrue(session.has_started)
125
+ self.assertTrue(session.has_stopped)
126
+ self.assertEqual(result, 3)
121
127
  self.assertIsNone(foo.session)
122
128
  self.assertEqual(foo.result, 3)
123
129
  self.assertEqual(
@@ -135,8 +141,8 @@ class SessionTest(unittest.TestCase):
135
141
  self.assertIsNone(root.parent_action)
136
142
  self.assertEqual(root.id, 'agent@1:')
137
143
  self.assertEqual(root.execution.id, 'agent@1:')
138
- self.assertEqual(len(root.execution.items), 1)
139
- self.assertIs(root.execution.items[0].action, foo)
144
+ self.assertEqual(len(root.execution), 1)
145
+ self.assertIs(root.execution[0].action, foo)
140
146
 
141
147
  self.assertTrue(root.execution.has_started)
142
148
  self.assertTrue(root.execution.has_stopped)
@@ -160,14 +166,14 @@ class SessionTest(unittest.TestCase):
160
166
  self.assertEqual(root.usage_summary.total.num_requests, 6)
161
167
 
162
168
  # Inspecting the top-level action (Foo)
163
- foo_invocation = root.execution.items[0]
169
+ foo_invocation = root.execution[0]
164
170
  self.assertIs(foo_invocation.parent_action, root)
165
171
  self.assertEqual(foo_invocation.id, 'agent@1:/a1')
166
172
  self.assertEqual(foo_invocation.execution.id, 'agent@1:/a1')
167
173
  self.assertEqual(len(foo_invocation.execution.items), 4)
168
174
 
169
175
  # Prepare phase.
170
- prepare_phase = foo_invocation.execution.items[0]
176
+ prepare_phase = foo_invocation.execution[0]
171
177
  self.assertIsInstance(prepare_phase, action_lib.ExecutionTrace)
172
178
  self.assertEqual(prepare_phase.id, 'agent@1:/a1/prepare')
173
179
  self.assertEqual(len(prepare_phase.items), 2)
@@ -179,7 +185,7 @@ class SessionTest(unittest.TestCase):
179
185
  self.assertEqual(prepare_phase.items[1].id, 'agent@1:/a1/prepare/q1')
180
186
 
181
187
  # Tracked queries.
182
- query_invocation = foo_invocation.execution.items[1]
188
+ query_invocation = foo_invocation.execution[1]
183
189
  self.assertIsInstance(query_invocation, lf_structured.QueryInvocation)
184
190
  self.assertEqual(query_invocation.id, 'agent@1:/a1/q2')
185
191
  self.assertIs(query_invocation.lm, lm)
@@ -197,7 +203,7 @@ class SessionTest(unittest.TestCase):
197
203
  )
198
204
 
199
205
  # Tracked parallel executions.
200
- parallel_executions = foo_invocation.execution.items[2]
206
+ parallel_executions = foo_invocation.execution[2]
201
207
  self.assertEqual(parallel_executions.id, 'agent@1:/a1/p1')
202
208
  self.assertIsInstance(parallel_executions, action_lib.ParallelExecutions)
203
209
  self.assertEqual(len(parallel_executions), 3)
@@ -209,7 +215,7 @@ class SessionTest(unittest.TestCase):
209
215
  self.assertEqual(len(parallel_executions[2].queries), 1)
210
216
 
211
217
  # Invocation to Bar.
212
- bar_invocation = foo_invocation.execution.items[3]
218
+ bar_invocation = foo_invocation.execution[3]
213
219
  self.assertIs(bar_invocation.parent_action, foo_invocation)
214
220
  self.assertEqual(bar_invocation.id, 'agent@1:/a1/a1')
215
221
  self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
@@ -240,10 +246,10 @@ class SessionTest(unittest.TestCase):
240
246
  root = session.root
241
247
  self.assertRegex(root.id, 'agent@.*:')
242
248
  self.assertTrue(root.has_error)
243
- foo_invocation = root.execution.items[0]
249
+ foo_invocation = root.execution[0]
244
250
  self.assertIsInstance(foo_invocation, action_lib.ActionInvocation)
245
251
  self.assertTrue(foo_invocation.has_error)
246
- bar_invocation = foo_invocation.execution.items[3]
252
+ bar_invocation = foo_invocation.execution[3]
247
253
  self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
248
254
  self.assertTrue(bar_invocation.has_error)
249
255
 
@@ -265,11 +271,146 @@ class SessionTest(unittest.TestCase):
265
271
  root = session.root
266
272
  self.assertRegex(root.id, 'agent@.*:')
267
273
  self.assertTrue(root.has_error)
268
- foo_invocation = root.execution.items[0]
274
+ foo_invocation = root.execution[0]
269
275
  self.assertIsInstance(foo_invocation, action_lib.ActionInvocation)
270
276
  self.assertTrue(foo_invocation.has_error)
271
277
  self.assertEqual(len(foo_invocation.execution.items), 2)
272
278
 
279
+ def test_succeeded_with_implicit_session(self):
280
+ lm = fake.StaticResponse('lm response')
281
+ foo = Foo(1)
282
+ foo(lm=lm, verbose=True)
283
+ session = foo.session
284
+ self.assertIsNotNone(session)
285
+ self.assertIsInstance(session.root.action, action_lib.RootAction)
286
+ self.assertIs(session.current_action, session.root)
287
+ self.assertTrue(session.has_started)
288
+ self.assertTrue(session.has_stopped)
289
+ self.assertEqual(session.final_result, 3)
290
+ self.assertFalse(session.root.has_error)
291
+ self.assertEqual(session.root.metadata, {})
292
+
293
+ def test_failed_with_implicit_session(self):
294
+ lm = fake.StaticResponse('lm response')
295
+ foo = Foo(1, simulate_action_error=True)
296
+ with self.assertRaisesRegex(ValueError, 'Bar error'):
297
+ foo(lm=lm)
298
+ session = foo.session
299
+ self.assertIsNotNone(session)
300
+ self.assertIsInstance(session.root.action, action_lib.RootAction)
301
+ self.assertIs(session.current_action, session.root)
302
+ self.assertTrue(session.has_started)
303
+ self.assertTrue(session.has_stopped)
304
+ self.assertTrue(session.has_error)
305
+ self.assertIsInstance(session.root.error, pg.utils.ErrorInfo)
306
+ self.assertIn('Bar error', str(session.root.error))
307
+
308
+ def test_succeeded_with_explicit_session(self):
309
+ lm = fake.StaticResponse('lm response')
310
+ foo = Foo(1)
311
+ self.assertIsNone(foo.session)
312
+ self.assertIsNone(foo.result)
313
+ self.assertIsNone(foo.metadata)
314
+
315
+ session = action_lib.Session(id='agent@1')
316
+ self.assertEqual(session.id, 'agent@1')
317
+ self.assertFalse(session.has_started)
318
+ self.assertFalse(session.has_stopped)
319
+
320
+ with session:
321
+ result = foo(session, lm=lm, verbose=True)
322
+
323
+ self.assertTrue(session.has_started)
324
+ self.assertTrue(session.has_stopped)
325
+ self.assertEqual(result, 3)
326
+ self.assertIsNone(foo.session)
327
+ self.assertEqual(foo.result, 3)
328
+ self.assertEqual(
329
+ foo.metadata, dict(note='foo', subtask_0=0, subtask_1=1, subtask_2=2)
330
+ )
331
+ self.assertIs(session.final_result, foo.result)
332
+ self.assertFalse(session.has_error)
333
+
334
+ def test_succeeded_with_explicit_session_start_end(self):
335
+ lm = fake.StaticResponse('lm response')
336
+ foo = Foo(1)
337
+ self.assertIsNone(foo.session)
338
+ self.assertIsNone(foo.result)
339
+ self.assertIsNone(foo.metadata)
340
+
341
+ session = action_lib.Session(id='agent@1')
342
+ self.assertEqual(session.id, 'agent@1')
343
+ self.assertFalse(session.has_started)
344
+ self.assertFalse(session.has_stopped)
345
+
346
+ session.start()
347
+ result = foo(session, lm=lm, verbose=True)
348
+ session.end(result)
349
+
350
+ self.assertTrue(session.has_started)
351
+ self.assertTrue(session.has_stopped)
352
+ self.assertEqual(result, 3)
353
+ self.assertIsNone(foo.session)
354
+ self.assertEqual(foo.result, 3)
355
+ self.assertEqual(
356
+ foo.metadata, dict(note='foo', subtask_0=0, subtask_1=1, subtask_2=2)
357
+ )
358
+ self.assertIs(session.final_result, foo.result)
359
+ self.assertFalse(session.has_error)
360
+
361
+ def test_failed_with_explicit_session(self):
362
+ lm = fake.StaticResponse('lm response')
363
+ foo = Foo(1, simulate_action_error=True)
364
+ session = action_lib.Session(id='agent@1')
365
+ with self.assertRaisesRegex(ValueError, 'Bar error'):
366
+ with session:
367
+ foo(session, lm=lm, verbose=True)
368
+ self.assertTrue(session.has_started)
369
+ self.assertTrue(session.has_stopped)
370
+ self.assertTrue(session.has_error)
371
+ self.assertIsNone(session.final_result)
372
+ self.assertIsInstance(session.root.error, pg.utils.ErrorInfo)
373
+ self.assertIn('Bar error', str(session.root.error))
374
+
375
+ def test_failed_with_explicit_session_without_start(self):
376
+ lm = fake.StaticResponse('lm response')
377
+ foo = Foo(1, simulate_action_error=True)
378
+ session = action_lib.Session(id='agent@1')
379
+ with self.assertRaisesRegex(ValueError, 'Please call `Session.start'):
380
+ foo(session, lm=lm, verbose=True)
381
+
382
+ def test_succeed_with_multiple_actions(self):
383
+ lm = fake.StaticResponse('lm response')
384
+ with action_lib.Session() as session:
385
+ x = Bar()(session, lm=lm)
386
+ y = Bar()(session, lm=lm)
387
+ self.assertTrue(session.has_started)
388
+ self.assertFalse(session.has_stopped)
389
+ session.add_metadata(note='root metadata')
390
+ session.end(x + y)
391
+
392
+ self.assertTrue(session.has_started)
393
+ self.assertTrue(session.has_stopped)
394
+ self.assertEqual(session.final_result, 2 + 2)
395
+ self.assertEqual(len(session.root.execution), 2)
396
+ self.assertEqual(session.root.metadata, dict(note='root metadata'))
397
+
398
+ def test_failed_with_multiple_actions(self):
399
+ lm = fake.StaticResponse('lm response')
400
+ with self.assertRaisesRegex(ValueError, 'Bar error'):
401
+ with action_lib.Session() as session:
402
+ x = Bar()(session, lm=lm)
403
+ y = Bar(simulate_action_error=True)(session, lm=lm)
404
+ session.end(x + y)
405
+
406
+ self.assertTrue(session.has_started)
407
+ self.assertTrue(session.has_stopped)
408
+ self.assertTrue(session.has_error)
409
+ self.assertIsInstance(session.root.error, pg.utils.ErrorInfo)
410
+ self.assertEqual(len(session.root.execution), 2)
411
+ self.assertFalse(session.root.execution[0].has_error)
412
+ self.assertTrue(session.root.execution[1].has_error)
413
+
273
414
  def test_log(self):
274
415
  session = action_lib.Session()
275
416
  session.debug('hi', x=1, y=2)
@@ -605,13 +605,13 @@ class Gemini(rest.REST):
605
605
  raise lf.ModalityError(f'Unsupported modality: {chunk!r}') from e
606
606
  return chunk
607
607
 
608
- contents = []
609
608
  if system_message := prompt.get('system_message'):
610
609
  assert isinstance(system_message, lf.SystemMessage), type(system_message)
611
- contents.append(
612
- system_message.as_format(
613
- 'gemini', chunk_preprocessor=modality_conversion)
610
+ request['systemInstruction'] = system_message.as_format(
611
+ 'gemini', chunk_preprocessor=modality_conversion
614
612
  )
613
+
614
+ contents = []
615
615
  contents.append(
616
616
  prompt.as_format('gemini', chunk_preprocessor=modality_conversion)
617
617
  )
@@ -647,6 +647,11 @@ class Gemini(rest.REST):
647
647
  + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
648
648
  + pg.to_json_str(json_schema, json_indent=2)
649
649
  )
650
+ if options.max_thinking_tokens is not None:
651
+ config['thinkingConfig'] = {
652
+ 'thinkingBudget': options.max_thinking_tokens
653
+ }
654
+
650
655
  return config
651
656
 
652
657
  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
@@ -659,18 +664,25 @@ class Gemini(rest.REST):
659
664
  # NOTE(daiyip): We saw cases that `candidatesTokenCount` is not present.
660
665
  # Therefore, we use 0 as the default value.
661
666
  output_tokens = usage.get('candidatesTokenCount', 0)
667
+ thinking_tokens = usage.get('thoughtsTokenCount', 0)
668
+ total_tokens = usage.get('totalTokenCount', 0)
662
669
 
663
670
  return lf.LMSamplingResult(
664
671
  [lf.LMSample(message) for message in messages],
665
672
  usage=lf.LMSamplingUsage(
666
673
  prompt_tokens=input_tokens,
667
674
  completion_tokens=output_tokens,
668
- total_tokens=input_tokens + output_tokens,
675
+ total_tokens=total_tokens,
676
+ completion_tokens_details={
677
+ 'thinking_tokens': thinking_tokens,
678
+ },
669
679
  ),
670
680
  )
671
681
 
672
682
  def _error(self, status_code: int, content: str) -> lf.LMError:
673
- if (status_code == 400
674
- and b'exceeds the maximum number of tokens' in content):
683
+ if (
684
+ status_code == 400
685
+ and b'exceeds the maximum number of tokens' in content
686
+ ):
675
687
  return lf.ContextLimitError(f'{status_code}: {content}')
676
688
  return super()._error(status_code, content)
@@ -38,14 +38,21 @@ example_image = (
38
38
  def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
39
39
  del url, kwargs
40
40
  c = pg.Dict(json['generationConfig'])
41
- content = '\n'.join(c['parts'][0]['text'] for c in json['contents'])
41
+ parts = []
42
+ if system_instruction := json.get('systemInstruction'):
43
+ parts.extend([p['text'] for p in system_instruction.get('parts', [])])
44
+
45
+ # Add text from the main contents.
46
+ for c_item in json.get('contents', []):
47
+ for p in c_item.get('parts', []):
48
+ parts.append(p['text'])
49
+ content = '\n'.join(parts)
42
50
  response = requests.Response()
43
51
  response.status_code = 200
44
52
  response._content = pg.to_json_str({
45
53
  'candidates': [
46
54
  {
47
55
  'content': {
48
- 'role': 'model',
49
56
  'parts': [
50
57
  {
51
58
  'text': (
@@ -146,6 +153,30 @@ class GeminiTest(unittest.TestCase):
146
153
  }
147
154
  ),
148
155
  )
156
+
157
+ # Add test for thinkingConfig.
158
+ actual = model._generation_config(
159
+ lf.UserMessage('hi'),
160
+ lf.LMSamplingOptions(
161
+ max_thinking_tokens=100,
162
+ ),
163
+ )
164
+ self.assertEqual(
165
+ actual,
166
+ dict(
167
+ candidateCount=1,
168
+ temperature=None,
169
+ topP=None,
170
+ topK=40,
171
+ maxOutputTokens=None,
172
+ stopSequences=None,
173
+ responseLogprobs=False,
174
+ logprobs=None,
175
+ seed=None,
176
+ thinkingConfig={'thinkingBudget': 100},
177
+ ),
178
+ )
179
+
149
180
  with self.assertRaisesRegex(
150
181
  ValueError, '`json_schema` must be a dict, got'
151
182
  ):
@@ -28,6 +28,7 @@ import pyglove as pg
28
28
  try:
29
29
  # pylint: disable=g-import-not-at-top
30
30
  from google import auth as google_auth
31
+ from google.auth import exceptions as auth_exceptions
31
32
  from google.auth import credentials as credentials_lib
32
33
  from google.auth.transport import requests as auth_requests
33
34
  # pylint: enable=g-import-not-at-top
@@ -35,6 +36,7 @@ try:
35
36
  Credentials = credentials_lib.Credentials
36
37
  except ImportError:
37
38
  google_auth = None
39
+ auth_exceptions = None
38
40
  credentials_lib = None
39
41
  auth_requests = None
40
42
  Credentials = Any
@@ -134,6 +136,16 @@ class VertexAI(rest.REST):
134
136
  assert auth_requests is not None
135
137
  return auth_requests.AuthorizedSession(self._credentials)
136
138
 
139
+ def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
140
+ assert auth_exceptions is not None
141
+ try:
142
+ return super()._sample_single(prompt)
143
+ except (
144
+ auth_exceptions.RefreshError,
145
+ ) as e:
146
+ raise lf.TemporaryLMError(
147
+ f'Failed to refresh Google authentication credentials: {e}'
148
+ ) from e
137
149
 
138
150
  #
139
151
  # Gemini models served by Vertex AI.
@@ -19,6 +19,7 @@ from unittest import mock
19
19
 
20
20
  from google.auth import exceptions
21
21
  import langfun.core as lf
22
+ from langfun.core.llms import rest
22
23
  from langfun.core.llms import vertexai
23
24
  import pyglove as pg
24
25
 
@@ -51,6 +52,22 @@ class VertexAITest(unittest.TestCase):
51
52
  del os.environ['VERTEXAI_PROJECT']
52
53
  del os.environ['VERTEXAI_LOCATION']
53
54
 
55
+ def test_auth_refresh_error(self):
56
+ def _auth_refresh_error(*args, **kwargs):
57
+ del args, kwargs
58
+ raise exceptions.RefreshError('Cannot refresh token')
59
+
60
+ with self.assertRaisesRegex(
61
+ lf.concurrent.RetryError,
62
+ 'Failed to refresh Google authentication credentials'
63
+ ):
64
+ with mock.patch.object(rest.REST, '_sample_single') as mock_sample_single:
65
+ mock_sample_single.side_effect = _auth_refresh_error
66
+ model = vertexai.VertexAIGemini15Pro(
67
+ project='abc', location='us-central1', max_attempts=1
68
+ )
69
+ model('hi')
70
+
54
71
 
55
72
  class VertexAIAnthropicTest(unittest.TestCase):
56
73
  """Tests for VertexAI Anthropic models."""
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langfun
3
- Version: 0.1.2.dev202505020804
3
+ Version: 0.1.2.dev202505030803
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -26,10 +26,10 @@ langfun/core/subscription_test.py,sha256=Y4ZdbZEwm83YNZBxHff0QR4QUa4rdaNXA3_jfIc
26
26
  langfun/core/template.py,sha256=jNhYSrbLIn9kZOa03w5QZbyjgfnzJzE_ZrrMvvWY4t4,24929
27
27
  langfun/core/template_test.py,sha256=AQv_m9qE93WxhEhSlm1xaBgB4hu0UVtA53dljngkUW0,17090
28
28
  langfun/core/agentic/__init__.py,sha256=qR3jlfUO4rhIoYdRDLz-d22YZf3FvU4FW88vsjiGDQQ,1224
29
- langfun/core/agentic/action.py,sha256=9P7xDiZVUV9MvJDfuAfLx-xa7qvS5F0EOGWDQnjAZBw,38931
30
- langfun/core/agentic/action_eval.py,sha256=NwjQ5hR-7YT6mo2q0mbDOgmNCKzTMpEzslYtR3fjXJY,4862
31
- langfun/core/agentic/action_eval_test.py,sha256=tRUkWmOE9p0rpNOq19xAY2oDEnYsEEykjg6sUpAwJk0,2832
32
- langfun/core/agentic/action_test.py,sha256=9EZKgLaBrqTErSRoxtrSlzmCz_cbnwWu0ZqpwKLst-s,10224
29
+ langfun/core/agentic/action.py,sha256=3m2-k07Zz8qrzOdOa7xPl7fRH3I0c3VsMLR86_JBCcU,45359
30
+ langfun/core/agentic/action_eval.py,sha256=JXhS5qEjWu9EZ0chDsjWxCqPAV26PUCBijtUYxiDeO4,4975
31
+ langfun/core/agentic/action_eval_test.py,sha256=7AkOwNbUX-ZgR1R0a7bvUZ5abNTUV7blf_8Mnrwb-II,2811
32
+ langfun/core/agentic/action_test.py,sha256=ezqg3tKlVwgLMnHKUmOdtxpnntuL8YIvhcTCSSdb8oc,15468
33
33
  langfun/core/coding/__init__.py,sha256=5utju_fwEsImaiftx4oXKl9FAM8p281k8-Esdh_-m1w,835
34
34
  langfun/core/coding/python/__init__.py,sha256=4ByknuoNU-mOIHwHKnTtmo6oD64oMFtlqPlYWmA5Wic,1736
35
35
  langfun/core/coding/python/correction.py,sha256=7zBedlhQKMPA4cfchUMxAOFl6Zl5RqCyllRHGWys40s,7092
@@ -92,8 +92,8 @@ langfun/core/llms/deepseek.py,sha256=jvTxdXPr-vH6HNakn_Ootx1heDg8Fen2FUkUW36bpCs
92
92
  langfun/core/llms/deepseek_test.py,sha256=DvROWPlDuow5E1lfoSkhyGt_ELA19JoQoDsTnRgDtTg,1847
93
93
  langfun/core/llms/fake.py,sha256=xmgCkk9y0I4x0IT32SZ9_OT27aLadXH8PRiYNo5VTd4,3265
94
94
  langfun/core/llms/fake_test.py,sha256=2h13qkwEz_JR0mtUDPxdAhQo7MueXaFSwsD2DIRDW9g,7653
95
- langfun/core/llms/gemini.py,sha256=k1uQy1aykkPtCpjnm56I74yMYIoOHN9j1ON7O8LDBJI,24111
96
- langfun/core/llms/gemini_test.py,sha256=d9Pvf3xmHgofv8AKXmbnfndsScxmgR5q_ctSIvEXYrU,6808
95
+ langfun/core/llms/gemini.py,sha256=ZtUo2lQMSByYlzSALWae3KxFiKNtOOwGbkFwZTI1dO0,24472
96
+ langfun/core/llms/gemini_test.py,sha256=Ve9X2Wvwu9wVFHpKZDP-qoM1_hzB4kgt6_HR9wxtNkg,7592
97
97
  langfun/core/llms/google_genai.py,sha256=j8W22WFvkT80Fw-r7Rg-e7MKhcSwljZkmtuufwSEn5s,5051
98
98
  langfun/core/llms/google_genai_test.py,sha256=NKNtpebArQ9ZR7Qsnhd2prFIpMjleojy6o6VMXkJ1zY,1502
99
99
  langfun/core/llms/groq.py,sha256=S9V10kFo3cgX89qPgt_umq-SpRnxEDLTt_hJmpERfbo,12066
@@ -106,8 +106,8 @@ langfun/core/llms/openai_compatible_test.py,sha256=KwOMA7tsmOxFBjezltkBDSU77AvOQ
106
106
  langfun/core/llms/openai_test.py,sha256=gwuO6aoa296iM2welWV9ua4KF8gEVGsEPakgbtkWkFQ,2687
107
107
  langfun/core/llms/rest.py,sha256=MCybcHApJcf49lubLnDzScN9Oc2IWY_JnMHIGdbDOuU,4474
108
108
  langfun/core/llms/rest_test.py,sha256=_zM7nV8DEVyoXNiQOnuwJ917mWjki0614H88rNmDboE,5020
109
- langfun/core/llms/vertexai.py,sha256=jCO1AjB3kBdBDNznygFpeXMZy-a7Lcap0NTe5Y7Wzx4,18205
110
- langfun/core/llms/vertexai_test.py,sha256=dOprP_uLNmXHYxMoX_hMPMsjKR-e_B5nKHjhlMCQoOQ,4252
109
+ langfun/core/llms/vertexai.py,sha256=4t_Noj7cqzLNmESYCYzz9Ndodd_K4I4zxVLmljJ7r3E,18630
110
+ langfun/core/llms/vertexai_test.py,sha256=0M4jsPOXGagdzPfEdJixmyLdhmmERePZWSFfTwnaYCQ,4875
111
111
  langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
112
112
  langfun/core/llms/cache/base.py,sha256=rt3zwmyw0y9jsSGW-ZbV1vAfLxQ7_3AVk0l2EySlse4,3918
113
113
  langfun/core/llms/cache/in_memory.py,sha256=i58oiQL28RDsq37dwqgVpC2mBETJjIEFS20yHiV5MKU,5185
@@ -156,8 +156,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
156
156
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
157
157
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
158
158
  langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
159
- langfun-0.1.2.dev202505020804.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
160
- langfun-0.1.2.dev202505020804.dist-info/METADATA,sha256=AWOuM3J68NYl8LFoEFVYgZyCDxyIDvcxwz79ce1GTSE,8178
161
- langfun-0.1.2.dev202505020804.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
162
- langfun-0.1.2.dev202505020804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
163
- langfun-0.1.2.dev202505020804.dist-info/RECORD,,
159
+ langfun-0.1.2.dev202505030803.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
160
+ langfun-0.1.2.dev202505030803.dist-info/METADATA,sha256=g86bSMVJK0G9ZcJSYdMdBiHudgkbJ4sxiO5KvoHJwt8,8178
161
+ langfun-0.1.2.dev202505030803.dist-info/WHEEL,sha256=7ciDxtlje1X8OhobNuGgi1t-ACdFSelPnSmDPrtlobY,91
162
+ langfun-0.1.2.dev202505030803.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
163
+ langfun-0.1.2.dev202505030803.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.1.0)
2
+ Generator: setuptools (80.2.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5