langfun 0.1.2.dev202508210804__py3-none-any.whl → 0.1.2.dev202508230803__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic. Click here for more details.

@@ -17,6 +17,9 @@
17
17
  # pylint: disable=g-importing-member
18
18
  # pylint: disable=g-import-not-at-top
19
19
 
20
+ from langfun.core.agentic.action import ActionError
21
+ from langfun.core.agentic.action import ActionTimeoutError
22
+
20
23
  from langfun.core.agentic.action import Action
21
24
  from langfun.core.agentic.action import ActionInvocation
22
25
  from langfun.core.agentic.action import ExecutionTrace
@@ -26,6 +26,14 @@ from langfun.core import structured as lf_structured
26
26
  import pyglove as pg
27
27
 
28
28
 
29
class ActionError(Exception):  # pylint: disable=g-bad-exception-name
  """Root of the exception hierarchy for errors raised by agentic actions."""


class ActionTimeoutError(ActionError):
  """Signals that an action ran past its allowed `max_execution_time`."""
35
+
36
+
29
37
  class Action(pg.Object):
30
38
  """Base class for Langfun's agentic actions.
31
39
 
@@ -203,6 +211,7 @@ class Action(pg.Object):
203
211
  self,
204
212
  session: Optional['Session'] = None,
205
213
  *,
214
+ max_execution_time: float | None = None,
206
215
  show_progress: bool = True,
207
216
  verbose: bool = False,
208
217
  **kwargs
@@ -212,6 +221,7 @@ class Action(pg.Object):
212
221
  return await lf.invoke_async(
213
222
  self.__call__,
214
223
  session,
224
+ max_execution_time=max_execution_time,
215
225
  show_progress=show_progress,
216
226
  verbose=verbose,
217
227
  **kwargs
@@ -221,11 +231,29 @@ class Action(pg.Object):
221
231
  self,
222
232
  session: Optional['Session'] = None,
223
233
  *,
234
+ max_execution_time: float | None = None,
224
235
  show_progress: bool = True,
225
236
  verbose: bool = False,
226
237
  **kwargs
227
238
  ) -> Any:
228
- """Executes the action."""
239
+ """Executes the action.
240
+
241
+ Args:
242
+ session: The session to use for the action.
243
+ max_execution_time: The max allowed execution time in seconds for the
244
+ action. The effective `max_execution_time` is the smaller of the
245
+ remaining time of the parent invocation and the `max_execution_time`
246
+ for the current invocation. Since we are running action within the
247
+ current thread, there is no guarantee that the action will be executed
248
+ within this time limit, but we will try to stop the action as soon as
249
+ the next operation on the session is called.
250
+ show_progress: Whether to show the progress of the action.
251
+ verbose: Whether to log the action execution.
252
+ **kwargs: Additional keyword arguments to pass to the action.
253
+
254
+ Returns:
255
+ The result of the action.
256
+ """
229
257
  if session is None:
230
258
  session = Session(verbose=verbose)
231
259
  session.start()
@@ -238,7 +266,7 @@ class Action(pg.Object):
238
266
  else:
239
267
  self._session = None
240
268
 
241
- with session.track_action(self):
269
+ with session.track_action(self, max_execution_time=max_execution_time):
242
270
  try:
243
271
  result = self.call(session=session, **kwargs)
244
272
  self._invocation.end(result)
@@ -841,6 +869,21 @@ class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
841
869
  'Error from the action if failed.'
842
870
  ] = None
843
871
 
872
+ max_execution_time: Annotated[
873
+ float | None,
874
+ (
875
+ 'The maximum allowed execution time for the action. '
876
+ 'It is set when the action or any of its parent actions has been '
877
+ 'called with a `max_execution_time` argument. '
878
+ 'The value is the remaining time of the parent invocation or '
879
+ 'the `max_execution_time` for the current invocation, '
880
+ 'whichever is smaller. Since we are running action within the '
881
+ 'current thread, there is no guarantee that the action will be '
882
+ 'executed within this time limit, but we will try to stop the action '
883
+ 'as soon as the next operation on the session is called.'
884
+ )
885
+ ] = None
886
+
844
887
  execution: Annotated[
845
888
  ExecutionTrace,
846
889
  'The execution sequence of the action.'
@@ -863,6 +906,13 @@ class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
863
906
  """Returns the parent action invocation."""
864
907
  return self.sym_ancestor(lambda x: isinstance(x, ActionInvocation))
865
908
 
909
+ @property
910
+ def max_remaining_execution_time(self) -> float | None:
911
+ """Returns the remaining execution time for the action."""
912
+ if self.max_execution_time is None:
913
+ return None
914
+ return max(0, self.max_execution_time - self.execution.elapse)
915
+
866
916
  @functools.cached_property
867
917
  def id(self) -> str:
868
918
  """Returns the id of the action invocation."""
@@ -1224,6 +1274,23 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1224
1274
  )
1225
1275
  self.root.end(result, error, metadata)
1226
1276
 
1277
def check_execution_time(self) -> None:
  """Raises `ActionTimeoutError` if the current action has run out of time.

  The check is cooperative. It only fires when a session operation calls it
  (e.g. starting a child action, starting a query, `concurrent_map`, or
  logging), so an action that never calls back into the session is not
  interrupted.

  When the current action's remaining time is 0, the error is attributed to
  the outermost ancestor whose own remaining time is also 0. An ancestor that
  still has time left is not blamed, even if it declares a time limit.

  Raises:
    ActionTimeoutError: If the current action has no remaining execution
      time.
  """
  action = self._current_action
  if action is None or action.max_remaining_execution_time != 0:
    return

  # Walk up only while the parent has ALSO exhausted its budget. The
  # previous logic climbed past any parent that merely had a limit. With a
  # 10s parent and a 0.5s child, that reported "parent exceeded 10 seconds",
  # which is false. When a child's budget was inherited from its parent, the
  # parent is necessarily exhausted too, so it is still reported.
  timed_out = action
  parent = timed_out.parent_action
  while parent is not None and parent.max_remaining_execution_time == 0:
    timed_out = parent
    parent = timed_out.parent_action
  raise ActionTimeoutError(
      f'Action {timed_out.id} '
      f'({timed_out.action.__class__.__name__}) has exceeded its '
      f'maximum execution time of {timed_out.max_execution_time} '
      'seconds.'
  )
1293
+
1227
1294
  def __enter__(self):
1228
1295
  """Enters the session."""
1229
1296
  self.start()
@@ -1254,7 +1321,11 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1254
1321
  #
1255
1322
 
1256
1323
  @contextlib.contextmanager
1257
- def track_action(self, action: Action) -> Iterator[ActionInvocation]:
1324
+ def track_action(
1325
+ self,
1326
+ action: Action,
1327
+ max_execution_time: float | None = None
1328
+ ) -> Iterator[ActionInvocation]:
1258
1329
  """Track the execution of an action."""
1259
1330
  if not self.root.execution.has_started:
1260
1331
  raise ValueError(
@@ -1263,7 +1334,13 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1263
1334
  'signal the start and end of the session.'
1264
1335
  )
1265
1336
 
1266
- invocation = ActionInvocation(pg.maybe_ref(action))
1337
+ # Early terminate the action if the execution time is exceeded.
1338
+ self.check_execution_time()
1339
+
1340
+ invocation = ActionInvocation(
1341
+ pg.maybe_ref(action),
1342
+ max_execution_time=self._child_max_execution_time(max_execution_time)
1343
+ )
1267
1344
  parent_action = self._current_action
1268
1345
  parent_execution = self._current_execution
1269
1346
  parent_execution.append(invocation)
@@ -1339,6 +1416,9 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1339
1416
  call.
1340
1417
  """
1341
1418
  def _query_start(invocation: lf_structured.QueryInvocation):
1419
+ # Early terminate the action if the execution time is exceeded.
1420
+ self.check_execution_time()
1421
+
1342
1422
  execution = self._current_execution
1343
1423
  invocation.rebind(
1344
1424
  id=f'{execution.id}/q{len(execution.queries) + 1}',
@@ -1469,6 +1549,7 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1469
1549
  ] = Exception
1470
1550
  ) -> Iterator[Any]:
1471
1551
  """Starts and tracks parallel execution with `lf.concurrent_map`."""
1552
+ self.check_execution_time()
1472
1553
  parallel_inputs = list(parallel_inputs)
1473
1554
  parallel_execution = ParallelExecutions(name=phase)
1474
1555
  self._current_execution.append(parallel_execution)
@@ -1492,90 +1573,23 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1492
1573
  _map_single,
1493
1574
  parallel_inputs,
1494
1575
  max_workers=max_workers,
1495
- timeout=timeout,
1576
+ timeout=self._child_max_execution_time(timeout),
1496
1577
  silence_on_errors=silence_on_errors
1497
1578
  ):
1498
1579
  yield input_value, result, error
1499
1580
 
1500
- # NOTE(daiyip): Clean up `query_prompt` and `query_output` once TS
1501
- # code migration is done.
1502
- def query_prompt(
1503
- self,
1504
- prompt: Union[str, lf.Template, Any],
1505
- schema: Union[
1506
- lf_structured.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
1507
- ] = None,
1508
- **kwargs,
1509
- ) -> Any:
1510
- """Calls `lf.query_prompt` and associates it with the current invocation.
1511
-
1512
- The following code are equivalent:
1513
-
1514
- Code 1:
1515
- ```
1516
- session.query_prompt(...)
1517
- ```
1518
-
1519
- Code 2:
1520
- ```
1521
- with session.track_queries() as queries:
1522
- output = lf.query_prompt(...)
1523
- ```
1524
- The former is preferred when `lf.query_prompt` is directly called by the
1525
- action.
1526
- If `lf.query_prompt` is called by a function that does not have access to
1527
- the
1528
- session, the latter should be used.
1529
-
1530
- Args:
1531
- prompt: The prompt to query.
1532
- schema: The schema to use for the query.
1533
- **kwargs: Additional keyword arguments to pass to `lf.query_prompt`.
1534
-
1535
- Returns:
1536
- The result of the query.
1537
- """
1538
- with self.track_queries():
1539
- return lf_structured.query_prompt(prompt, schema=schema, **kwargs)
1540
-
1541
- def query_output(
1542
- self,
1543
- response: Union[str, lf.Template, Any],
1544
- schema: Union[
1545
- lf_structured.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
1546
- ] = None,
1547
- **kwargs,
1548
- ) -> Any:
1549
- """Calls `lf.query_output` and associates it with the current invocation.
1550
-
1551
- The following code are equivalent:
1552
-
1553
- Code 1:
1554
- ```
1555
- session.query_output(...)
1556
- ```
1557
-
1558
- Code 2:
1559
- ```
1560
- with session.track_queries() as queries:
1561
- output = lf.query_output(...)
1562
- ```
1563
- The former is preferred when `lf.query_output` is directly called by the
1564
- action.
1565
- If `lf.query_output` is called by a function that does not have access to
1566
- the
1567
- session, the latter should be used.
1568
-
1569
- Args:
1570
- response: The response to query.
1571
- schema: The schema to use for the query.
1572
- **kwargs: Additional keyword arguments to pass to `lf.query_prompt`.
1573
-
1574
- Returns:
1575
- The result of the query.
1576
- """
1577
- with self.track_queries():
1578
- return lf_structured.query_output(response, schema=schema, **kwargs)
1581
+ def _child_max_execution_time(
1582
+ self, max_execution_time: float | None
1583
+ ) -> float | None:
1584
+ """Returns the max execution time for the child action."""
1585
+ max_remaining = None
1586
+ if self._current_action is not None:
1587
+ max_remaining = self._current_action.max_remaining_execution_time
1588
+ if max_remaining is None:
1589
+ return max_execution_time
1590
+ if max_execution_time is None:
1591
+ return max_remaining
1592
+ return min(max_remaining, max_execution_time)
1579
1593
 
1580
1594
  def _log(
1581
1595
  self,
@@ -1597,6 +1611,9 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1597
1611
  **kwargs: Additional keyword arguments to pass to `lf.logging.log` as
1598
1612
  metadata to show.
1599
1613
  """
1614
+ # Early terminate the action if the execution time is exceeded.
1615
+ self.check_execution_time()
1616
+
1600
1617
  execution = self._current_execution
1601
1618
  if for_action is None:
1602
1619
  for_action = self._current_action
@@ -14,6 +14,7 @@
14
14
  """Tests for base action."""
15
15
 
16
16
  import asyncio
17
+ import time
17
18
  import unittest
18
19
 
19
20
  import langfun.core as lf
@@ -25,10 +26,12 @@ import pyglove as pg
25
26
 
26
27
  class Bar(action_lib.Action):
27
28
  simulate_action_error: bool = False
29
+ simulate_execution_time: float = 0
28
30
 
29
31
  def call(self, session, *, lm, **kwargs):
30
32
  assert session.current_action.action is self
31
33
  session.info('Begin Bar')
34
+ time.sleep(self.simulate_execution_time)
32
35
  session.query('bar', lm=lm)
33
36
  session.add_metadata(note='bar')
34
37
  if self.simulate_action_error:
@@ -40,22 +43,27 @@ class Foo(action_lib.Action):
40
43
  x: int
41
44
  simulate_action_error: bool = False
42
45
  simulate_query_error: bool = False
46
+ simulate_execution_time: list[float] = [0, 0, 0, 0]
47
+ max_bar_execution_time: float | None = None
43
48
 
44
49
  def call(self, session, *, lm, **kwargs):
45
50
  assert session.current_action.action is self
46
51
  with session.track_phase('prepare'):
47
52
  session.info('Begin Foo', x=1)
53
+ time.sleep(self.simulate_execution_time[0])
48
54
  session.query(
49
55
  'foo',
50
56
  schema=int if self.simulate_query_error else None,
51
57
  lm=lm
52
58
  )
53
59
  with session.track_queries():
60
+ time.sleep(self.simulate_execution_time[1])
54
61
  self.make_additional_query(lm)
55
62
  session.add_metadata(note='foo')
56
63
 
57
64
  def _sub_task(i):
58
65
  session.add_metadata(**{f'subtask_{i}': i})
66
+ time.sleep(self.simulate_execution_time[2])
59
67
  return lf_structured.query(f'subtask_{i}', lm=lm)
60
68
 
61
69
  for i, output, error in session.concurrent_map(
@@ -65,8 +73,9 @@ class Foo(action_lib.Action):
65
73
  assert isinstance(output, str), output
66
74
  assert error is None, error
67
75
  return self.x + Bar(
68
- simulate_action_error=self.simulate_action_error
69
- )(session, lm=lm)
76
+ simulate_action_error=self.simulate_action_error,
77
+ simulate_execution_time=self.simulate_execution_time[3]
78
+ )(session, lm=lm, max_execution_time=self.max_bar_execution_time)
70
79
 
71
80
  def make_additional_query(self, lm):
72
81
  lf_structured.query('additional query', lm=lm)
@@ -426,6 +435,51 @@ class SessionTest(unittest.TestCase):
426
435
  self.assertFalse(session.root.execution[0].has_error)
427
436
  self.assertTrue(session.root.execution[1].has_error)
428
437
 
438
def test_max_execution_time(self):
  lm = fake.StaticResponse('lm response')

  # Each case is (action, extra call kwargs, expected error pattern). The
  # actions sleep via `simulate_execution_time`, so together these cases take
  # several seconds of wall-clock time.
  cases = [
      # A top-level action that outlives its own limit.
      (
          Bar(simulate_execution_time=1),
          {'max_execution_time': 0.5},
          'Action .*Bar.*has exceeded .* 0.5 seconds',
      ),
      # Foo's limit runs out while its child Bar is sleeping.
      (
          Foo(1, simulate_execution_time=[0, 0, 0, 1]),
          {'max_execution_time': 0.5},
          'Action .*Foo.* has exceeded .* 0.5 seconds',
      ),
      # Foo's limit runs out within `concurrent_map`.
      (
          Foo(1, simulate_execution_time=[0, 0, 1, 0]),
          {'max_execution_time': 0.5},
          'Action .*Foo.* has exceeded .* 0.5 seconds',
      ),
      # Only Bar has a limit, so Bar is the one reported.
      (
          Foo(
              1, simulate_execution_time=[0, 0, 0, 1], max_bar_execution_time=0.5
          ),
          {},
          'Action .*Bar.* has exceeded .* 0.5 seconds',
      ),
      # Bar asks for 1.0s, but Foo has only ~0.5s left when Bar starts, so
      # Bar's effective limit comes from Foo (0.5 < 1) and Foo is reported.
      (
          Foo(
              1,
              simulate_execution_time=[0, 0.5, 0, 1.0],
              max_bar_execution_time=1.0,
          ),
          {'max_execution_time': 1.0},
          'Action .*Foo.* has exceeded .*1.0 seconds',
      ),
  ]
  for action, call_kwargs, error_pattern in cases:
    with self.assertRaisesRegex(
        action_lib.ActionTimeoutError, error_pattern
    ):
      action(lm=lm, **call_kwargs)
482
+
429
483
  def test_log(self):
430
484
  session = action_lib.Session()
431
485
  session.debug('hi', x=1, y=2)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langfun
3
- Version: 0.1.2.dev202508210804
3
+ Version: 0.1.2.dev202508230803
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -34,11 +34,11 @@ langfun/core/subscription.py,sha256=euawEuSZP-BHydaT-AQpfYFL0m5pWPGcW0upFhrojqc,
34
34
  langfun/core/subscription_test.py,sha256=Y4ZdbZEwm83YNZBxHff0QR4QUa4rdaNXA3_jfIcArBo,8717
35
35
  langfun/core/template.py,sha256=GSOZ3OcmRtw-q05GdHrE-Y4-2MGDYsTRKmWGvVPbdhE,24962
36
36
  langfun/core/template_test.py,sha256=AQv_m9qE93WxhEhSlm1xaBgB4hu0UVtA53dljngkUW0,17090
37
- langfun/core/agentic/__init__.py,sha256=qR3jlfUO4rhIoYdRDLz-d22YZf3FvU4FW88vsjiGDQQ,1224
38
- langfun/core/agentic/action.py,sha256=HvCaYb3x9tpa0xhDuFtEGyZ9du_edDpIdaQYPPo87ts,51097
37
+ langfun/core/agentic/__init__.py,sha256=s9zRiAOtiTvp_sWeyGcykjlo6rez8asvLK7tiOELWEU,1336
38
+ langfun/core/agentic/action.py,sha256=CgOmp9ht4J6N1bg0DSqkP1G2YYP3HPi1Z7Es5qqhevs,52824
39
39
  langfun/core/agentic/action_eval.py,sha256=YTilyUEkJl_8FVMgdfO17PurWWaEJ6oA15CuefJJRLk,4887
40
40
  langfun/core/agentic/action_eval_test.py,sha256=7AkOwNbUX-ZgR1R0a7bvUZ5abNTUV7blf_8Mnrwb-II,2811
41
- langfun/core/agentic/action_test.py,sha256=5ZUCbyuJDwNCUE_hl4qt5epzQQTRdZEYN5W8J2iKRQo,16084
41
+ langfun/core/agentic/action_test.py,sha256=rorIo58S1o27CZrtw05b49mLZOaiitl_xGanGZa5HXo,18028
42
42
  langfun/core/coding/__init__.py,sha256=5utju_fwEsImaiftx4oXKl9FAM8p281k8-Esdh_-m1w,835
43
43
  langfun/core/coding/python/__init__.py,sha256=yTXm92oLpQb4A-fZ2qy-bzfhPYND7B-oidtbv1PNaX0,1678
44
44
  langfun/core/coding/python/correction.py,sha256=4PD76Xfv36Xrm8Ji3-GgGDNImtcDqWfMw3z6ircJMlM,7285
@@ -165,8 +165,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
165
165
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
166
166
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
167
167
  langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
168
- langfun-0.1.2.dev202508210804.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
169
- langfun-0.1.2.dev202508210804.dist-info/METADATA,sha256=C-d5bGlbkHHdz0Hd0-qdvG7fnmNMmJ8pU3Hg5zDKhqU,7380
170
- langfun-0.1.2.dev202508210804.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
171
- langfun-0.1.2.dev202508210804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
172
- langfun-0.1.2.dev202508210804.dist-info/RECORD,,
168
+ langfun-0.1.2.dev202508230803.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
169
+ langfun-0.1.2.dev202508230803.dist-info/METADATA,sha256=GjLNmc4jIoMKc8vc79MpHhBbQnaVFt-N5myH4qdLnw8,7380
170
+ langfun-0.1.2.dev202508230803.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
171
+ langfun-0.1.2.dev202508230803.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
172
+ langfun-0.1.2.dev202508230803.dist-info/RECORD,,