langfun 0.1.2.dev202510140804__py3-none-any.whl → 0.1.2.dev202510160805__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/core/agentic/__init__.py +4 -0
- langfun/core/agentic/action.py +278 -89
- langfun/core/agentic/action_test.py +33 -5
- langfun/core/eval/v2/progress_tracking_test.py +3 -0
- langfun/core/structured/querying.py +12 -12
- langfun/core/structured/querying_test.py +4 -4
- {langfun-0.1.2.dev202510140804.dist-info → langfun-0.1.2.dev202510160805.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202510140804.dist-info → langfun-0.1.2.dev202510160805.dist-info}/RECORD +11 -11
- {langfun-0.1.2.dev202510140804.dist-info → langfun-0.1.2.dev202510160805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202510140804.dist-info → langfun-0.1.2.dev202510160805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202510140804.dist-info → langfun-0.1.2.dev202510160805.dist-info}/top_level.txt +0 -0
langfun/core/agentic/__init__.py
CHANGED
|
@@ -26,6 +26,10 @@ from langfun.core.agentic.action import ExecutionTrace
|
|
|
26
26
|
from langfun.core.agentic.action import ParallelExecutions
|
|
27
27
|
from langfun.core.agentic.action import Session
|
|
28
28
|
|
|
29
|
+
from langfun.core.agentic.action import SessionEventHandler
|
|
30
|
+
from langfun.core.agentic.action import SessionEventHandlerChain
|
|
31
|
+
from langfun.core.agentic.action import SessionLogging
|
|
32
|
+
|
|
29
33
|
from langfun.core.agentic.action_eval import ActionEval
|
|
30
34
|
from langfun.core.agentic.action_eval import ActionEvalV1
|
|
31
35
|
|
langfun/core/agentic/action.py
CHANGED
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import abc
|
|
17
17
|
import contextlib
|
|
18
|
+
import dataclasses
|
|
18
19
|
import functools
|
|
19
20
|
import threading
|
|
20
21
|
import time
|
|
@@ -1161,26 +1162,271 @@ class RootAction(Action):
|
|
|
1161
1162
|
raise NotImplementedError('Shall not be called.')
|
|
1162
1163
|
|
|
1163
1164
|
|
|
1165
|
+
class SessionEventHandler:
|
|
1166
|
+
"""Interface for handling session events."""
|
|
1167
|
+
|
|
1168
|
+
def on_session_start(
|
|
1169
|
+
self,
|
|
1170
|
+
session: 'Session'
|
|
1171
|
+
) -> None:
|
|
1172
|
+
"""Called when a session starts."""
|
|
1173
|
+
|
|
1174
|
+
def on_session_end(
|
|
1175
|
+
self,
|
|
1176
|
+
session: 'Session'
|
|
1177
|
+
) -> None:
|
|
1178
|
+
"""Called when a session ends."""
|
|
1179
|
+
|
|
1180
|
+
def on_action_start(
|
|
1181
|
+
self,
|
|
1182
|
+
session: 'Session',
|
|
1183
|
+
action: ActionInvocation
|
|
1184
|
+
) -> None:
|
|
1185
|
+
"""Called when an action starts."""
|
|
1186
|
+
|
|
1187
|
+
def on_action_end(
|
|
1188
|
+
self,
|
|
1189
|
+
session: 'Session',
|
|
1190
|
+
action: ActionInvocation
|
|
1191
|
+
) -> None:
|
|
1192
|
+
"""Called when an action ends."""
|
|
1193
|
+
|
|
1194
|
+
def on_action_progress(
|
|
1195
|
+
self,
|
|
1196
|
+
session: 'Session',
|
|
1197
|
+
action: ActionInvocation,
|
|
1198
|
+
title: str,
|
|
1199
|
+
**kwargs
|
|
1200
|
+
) -> None:
|
|
1201
|
+
"""Called when an action progress is updated."""
|
|
1202
|
+
|
|
1203
|
+
def on_query_start(
|
|
1204
|
+
self,
|
|
1205
|
+
session: 'Session',
|
|
1206
|
+
action: ActionInvocation,
|
|
1207
|
+
query: lf_structured.QueryInvocation,
|
|
1208
|
+
) -> None:
|
|
1209
|
+
"""Called when a query starts."""
|
|
1210
|
+
|
|
1211
|
+
def on_query_end(
|
|
1212
|
+
self,
|
|
1213
|
+
session: 'Session',
|
|
1214
|
+
action: ActionInvocation,
|
|
1215
|
+
query: lf_structured.QueryInvocation,
|
|
1216
|
+
) -> None:
|
|
1217
|
+
"""Called when a query ends."""
|
|
1218
|
+
|
|
1219
|
+
|
|
1220
|
+
@dataclasses.dataclass
|
|
1221
|
+
class SessionEventHandlerChain(SessionEventHandler):
|
|
1222
|
+
"""A session event handler that chains multiple event handlers."""
|
|
1223
|
+
|
|
1224
|
+
handlers: list[SessionEventHandler]
|
|
1225
|
+
|
|
1226
|
+
def on_session_start(self, session: 'Session') -> None:
|
|
1227
|
+
"""Called when a session starts."""
|
|
1228
|
+
for handler in self.handlers:
|
|
1229
|
+
handler.on_session_start(session)
|
|
1230
|
+
|
|
1231
|
+
def on_session_end(self, session: 'Session') -> None:
|
|
1232
|
+
"""Called when a session ends."""
|
|
1233
|
+
for handler in self.handlers:
|
|
1234
|
+
handler.on_session_end(session)
|
|
1235
|
+
|
|
1236
|
+
def on_action_start(
|
|
1237
|
+
self,
|
|
1238
|
+
session: 'Session',
|
|
1239
|
+
action: ActionInvocation) -> None:
|
|
1240
|
+
"""Called when an action starts."""
|
|
1241
|
+
for handler in self.handlers:
|
|
1242
|
+
handler.on_action_start(session, action)
|
|
1243
|
+
|
|
1244
|
+
def on_action_end(
|
|
1245
|
+
self,
|
|
1246
|
+
session: 'Session',
|
|
1247
|
+
action: ActionInvocation) -> None:
|
|
1248
|
+
"""Called when an action ends."""
|
|
1249
|
+
for handler in self.handlers:
|
|
1250
|
+
handler.on_action_end(session, action)
|
|
1251
|
+
|
|
1252
|
+
def on_action_progress(
|
|
1253
|
+
self,
|
|
1254
|
+
session: 'Session',
|
|
1255
|
+
action: ActionInvocation,
|
|
1256
|
+
title: str,
|
|
1257
|
+
**kwargs
|
|
1258
|
+
) -> None:
|
|
1259
|
+
"""Called when an action progress is updated."""
|
|
1260
|
+
for handler in self.handlers:
|
|
1261
|
+
handler.on_action_progress(session, action, title, **kwargs)
|
|
1262
|
+
|
|
1263
|
+
def on_query_start(
|
|
1264
|
+
self,
|
|
1265
|
+
session: 'Session',
|
|
1266
|
+
action: ActionInvocation,
|
|
1267
|
+
query: lf_structured.QueryInvocation,
|
|
1268
|
+
) -> None:
|
|
1269
|
+
"""Called when a query starts."""
|
|
1270
|
+
for handler in self.handlers:
|
|
1271
|
+
handler.on_query_start(session, action, query)
|
|
1272
|
+
|
|
1273
|
+
def on_query_end(
|
|
1274
|
+
self,
|
|
1275
|
+
session: 'Session',
|
|
1276
|
+
action: ActionInvocation,
|
|
1277
|
+
query: lf_structured.QueryInvocation,
|
|
1278
|
+
) -> None:
|
|
1279
|
+
"""Called when a query ends."""
|
|
1280
|
+
for handler in self.handlers:
|
|
1281
|
+
handler.on_query_end(session, action, query)
|
|
1282
|
+
|
|
1283
|
+
|
|
1284
|
+
@dataclasses.dataclass
|
|
1285
|
+
class SessionLogging(SessionEventHandler):
|
|
1286
|
+
"""An event handler that logs Session events."""
|
|
1287
|
+
|
|
1288
|
+
verbose: bool = False
|
|
1289
|
+
|
|
1290
|
+
def on_session_end(self, session: 'Session'):
|
|
1291
|
+
if session.has_error:
|
|
1292
|
+
session.error(
|
|
1293
|
+
f'Trajectory failed in {session.elapse:.2f} seconds.',
|
|
1294
|
+
error=session.final_error,
|
|
1295
|
+
metadata=session.root.metadata,
|
|
1296
|
+
keep=True,
|
|
1297
|
+
)
|
|
1298
|
+
elif self.verbose:
|
|
1299
|
+
session.info(
|
|
1300
|
+
f'Trajectory succeeded in {session.elapse:.2f} seconds.',
|
|
1301
|
+
result=session.final_result,
|
|
1302
|
+
metadata=session.root.metadata,
|
|
1303
|
+
keep=False,
|
|
1304
|
+
)
|
|
1305
|
+
|
|
1306
|
+
def on_action_start(
|
|
1307
|
+
self,
|
|
1308
|
+
session: 'Session',
|
|
1309
|
+
action: ActionInvocation
|
|
1310
|
+
) -> None:
|
|
1311
|
+
if self.verbose:
|
|
1312
|
+
session.info(
|
|
1313
|
+
'Action execution started.',
|
|
1314
|
+
action=action.action,
|
|
1315
|
+
keep=False,
|
|
1316
|
+
)
|
|
1317
|
+
|
|
1318
|
+
def on_action_end(
|
|
1319
|
+
self,
|
|
1320
|
+
session: 'Session',
|
|
1321
|
+
action: ActionInvocation
|
|
1322
|
+
) -> None:
|
|
1323
|
+
if action.has_error:
|
|
1324
|
+
session.warning(
|
|
1325
|
+
(
|
|
1326
|
+
f'Action execution failed in '
|
|
1327
|
+
f'{action.execution.elapse:.2f} seconds.'
|
|
1328
|
+
),
|
|
1329
|
+
action=action.action,
|
|
1330
|
+
error=action.error,
|
|
1331
|
+
keep=True,
|
|
1332
|
+
)
|
|
1333
|
+
elif self.verbose:
|
|
1334
|
+
session.info(
|
|
1335
|
+
(
|
|
1336
|
+
f'Action execution succeeded in '
|
|
1337
|
+
f'{action.execution.elapse:.2f} seconds.'
|
|
1338
|
+
),
|
|
1339
|
+
action=action.action,
|
|
1340
|
+
result=action.result,
|
|
1341
|
+
keep=False,
|
|
1342
|
+
)
|
|
1343
|
+
|
|
1344
|
+
def on_query_start(
|
|
1345
|
+
self,
|
|
1346
|
+
session: 'Session',
|
|
1347
|
+
action: ActionInvocation,
|
|
1348
|
+
query: lf_structured.QueryInvocation,
|
|
1349
|
+
) -> None:
|
|
1350
|
+
if self.verbose:
|
|
1351
|
+
session.info(
|
|
1352
|
+
'Querying LLM started.',
|
|
1353
|
+
lm=query.lm.model_id,
|
|
1354
|
+
output_type=(
|
|
1355
|
+
lf_structured.annotation(query.schema.spec)
|
|
1356
|
+
if query.schema is not None else None
|
|
1357
|
+
),
|
|
1358
|
+
keep=False,
|
|
1359
|
+
)
|
|
1360
|
+
|
|
1361
|
+
def on_query_end(
|
|
1362
|
+
self,
|
|
1363
|
+
session: 'Session',
|
|
1364
|
+
action: ActionInvocation,
|
|
1365
|
+
query: lf_structured.QueryInvocation,
|
|
1366
|
+
) -> None:
|
|
1367
|
+
if query.has_error:
|
|
1368
|
+
session.warning(
|
|
1369
|
+
(
|
|
1370
|
+
f'Querying LLM failed in '
|
|
1371
|
+
f'{time.time() - query.start_time:.2f} seconds.'
|
|
1372
|
+
),
|
|
1373
|
+
lm=query.lm.model_id,
|
|
1374
|
+
output_type=(
|
|
1375
|
+
lf_structured.annotation(query.schema.spec)
|
|
1376
|
+
if query.schema is not None else None
|
|
1377
|
+
),
|
|
1378
|
+
error=query.error,
|
|
1379
|
+
keep=True,
|
|
1380
|
+
)
|
|
1381
|
+
elif self.verbose:
|
|
1382
|
+
session.info(
|
|
1383
|
+
(
|
|
1384
|
+
f'Querying LLM succeeded in '
|
|
1385
|
+
f'{time.time() - query.start_time:.2f} seconds.'
|
|
1386
|
+
),
|
|
1387
|
+
lm=query.lm.model_id,
|
|
1388
|
+
output_type=(
|
|
1389
|
+
lf_structured.annotation(query.schema.spec)
|
|
1390
|
+
if query.schema is not None else None
|
|
1391
|
+
),
|
|
1392
|
+
keep=False,
|
|
1393
|
+
)
|
|
1394
|
+
|
|
1395
|
+
|
|
1164
1396
|
class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
1165
1397
|
"""Session for performing an agentic task."""
|
|
1166
1398
|
|
|
1167
1399
|
root: Annotated[
|
|
1168
1400
|
ActionInvocation,
|
|
1169
1401
|
'The root action invocation of the session.'
|
|
1170
|
-
]
|
|
1402
|
+
]
|
|
1171
1403
|
|
|
1172
1404
|
id: Annotated[
|
|
1173
1405
|
str | None,
|
|
1174
|
-
'An optional identifier for the
|
|
1175
|
-
]
|
|
1406
|
+
'An optional identifier for the session, which will be used for logging.'
|
|
1407
|
+
]
|
|
1176
1408
|
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1409
|
+
event_handler: Annotated[
|
|
1410
|
+
SessionEventHandler,
|
|
1411
|
+
'Event handler for the session.'
|
|
1412
|
+
]
|
|
1413
|
+
|
|
1414
|
+
@pg.explicit_method_override
|
|
1415
|
+
def __init__(
|
|
1416
|
+
self,
|
|
1417
|
+
id: str | None = None, # pylint: disable=redefined-builtin
|
|
1418
|
+
*,
|
|
1419
|
+
verbose: bool = False,
|
|
1420
|
+
event_handler: SessionEventHandler | None = None,
|
|
1421
|
+
root: ActionInvocation | None = None,
|
|
1422
|
+
**kwargs
|
|
1423
|
+
):
|
|
1424
|
+
super().__init__(
|
|
1425
|
+
id=id,
|
|
1426
|
+
root=root or ActionInvocation(RootAction()),
|
|
1427
|
+
event_handler=event_handler or SessionLogging(verbose=verbose),
|
|
1428
|
+
**kwargs
|
|
1429
|
+
)
|
|
1184
1430
|
|
|
1185
1431
|
#
|
|
1186
1432
|
# Shortcut methods for accessing the root action invocation.
|
|
@@ -1271,6 +1517,7 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
|
1271
1517
|
def start(self) -> None:
|
|
1272
1518
|
"""Starts the session."""
|
|
1273
1519
|
self.root.execution.start()
|
|
1520
|
+
self.event_handler.on_session_start(self)
|
|
1274
1521
|
|
|
1275
1522
|
def end(
|
|
1276
1523
|
self,
|
|
@@ -1279,21 +1526,8 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
|
1279
1526
|
metadata: dict[str, Any] | None = None,
|
|
1280
1527
|
) -> None:
|
|
1281
1528
|
"""Ends the session."""
|
|
1282
|
-
if error is not None:
|
|
1283
|
-
self.error(
|
|
1284
|
-
f'Trajectory failed in {self.elapse:.2f} seconds.',
|
|
1285
|
-
error=error,
|
|
1286
|
-
metadata=metadata,
|
|
1287
|
-
keep=True,
|
|
1288
|
-
)
|
|
1289
|
-
elif self.verbose:
|
|
1290
|
-
self.info(
|
|
1291
|
-
f'Trajectory succeeded in {self.elapse:.2f} seconds.',
|
|
1292
|
-
result=result,
|
|
1293
|
-
metadata=metadata,
|
|
1294
|
-
keep=False,
|
|
1295
|
-
)
|
|
1296
1529
|
self.root.end(result, error, metadata)
|
|
1530
|
+
self.event_handler.on_session_end(self)
|
|
1297
1531
|
|
|
1298
1532
|
def check_execution_time(self) -> None:
|
|
1299
1533
|
"""Checks the execution time of the current action."""
|
|
@@ -1312,6 +1546,20 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
|
1312
1546
|
'seconds.'
|
|
1313
1547
|
)
|
|
1314
1548
|
|
|
1549
|
+
def update_progress(self, title: str, **kwargs: Any) -> None:
|
|
1550
|
+
"""Update the progress of current action's execution.
|
|
1551
|
+
|
|
1552
|
+
Args:
|
|
1553
|
+
title: The title of the progress update.
|
|
1554
|
+
**kwargs: Additional keyword arguments to pass to the event handler.
|
|
1555
|
+
"""
|
|
1556
|
+
self.event_handler.on_action_progress(
|
|
1557
|
+
self,
|
|
1558
|
+
self._current_action,
|
|
1559
|
+
title,
|
|
1560
|
+
**kwargs
|
|
1561
|
+
)
|
|
1562
|
+
|
|
1315
1563
|
def __enter__(self):
|
|
1316
1564
|
"""Enters the session."""
|
|
1317
1565
|
self.start()
|
|
@@ -1371,34 +1619,10 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
|
1371
1619
|
self._current_execution = invocation.execution
|
|
1372
1620
|
# Start the execution of the current action.
|
|
1373
1621
|
self._current_action.start()
|
|
1374
|
-
|
|
1375
|
-
self.info(
|
|
1376
|
-
'Action execution started.',
|
|
1377
|
-
action=invocation.action,
|
|
1378
|
-
keep=False,
|
|
1379
|
-
)
|
|
1622
|
+
self.event_handler.on_action_start(self, self._current_action)
|
|
1380
1623
|
yield invocation
|
|
1381
1624
|
finally:
|
|
1382
|
-
|
|
1383
|
-
self.warning(
|
|
1384
|
-
(
|
|
1385
|
-
f'Action execution failed in '
|
|
1386
|
-
f'{invocation.execution.elapse:.2f} seconds.'
|
|
1387
|
-
),
|
|
1388
|
-
action=invocation.action,
|
|
1389
|
-
error=invocation.error,
|
|
1390
|
-
keep=True,
|
|
1391
|
-
)
|
|
1392
|
-
elif self.verbose:
|
|
1393
|
-
self.info(
|
|
1394
|
-
(
|
|
1395
|
-
f'Action execution succeeded in '
|
|
1396
|
-
f'{invocation.execution.elapse:.2f} seconds.'
|
|
1397
|
-
),
|
|
1398
|
-
action=invocation.action,
|
|
1399
|
-
result=invocation.result,
|
|
1400
|
-
keep=False,
|
|
1401
|
-
)
|
|
1625
|
+
self.event_handler.on_action_end(self, self._current_action)
|
|
1402
1626
|
self._current_execution = parent_execution
|
|
1403
1627
|
self._current_action = parent_action
|
|
1404
1628
|
|
|
@@ -1446,51 +1670,16 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
|
1446
1670
|
skip_notification=False, raise_on_no_change=False
|
|
1447
1671
|
)
|
|
1448
1672
|
execution.append(invocation)
|
|
1449
|
-
|
|
1450
|
-
self.info(
|
|
1451
|
-
'Querying LLM started.',
|
|
1452
|
-
lm=invocation.lm.model_id,
|
|
1453
|
-
output_type=(
|
|
1454
|
-
lf_structured.annotation(invocation.schema.spec)
|
|
1455
|
-
if invocation.schema is not None else None
|
|
1456
|
-
),
|
|
1457
|
-
keep=False,
|
|
1458
|
-
)
|
|
1673
|
+
self.event_handler.on_query_start(self, self._current_action, invocation)
|
|
1459
1674
|
|
|
1460
1675
|
def _query_end(invocation: lf_structured.QueryInvocation):
|
|
1461
1676
|
self._current_execution.merge_usage_summary(invocation.usage_summary)
|
|
1462
|
-
|
|
1463
|
-
self.warning(
|
|
1464
|
-
(
|
|
1465
|
-
f'Querying LLM failed in '
|
|
1466
|
-
f'{time.time() - invocation.start_time:.2f} seconds.'
|
|
1467
|
-
),
|
|
1468
|
-
lm=invocation.lm.model_id,
|
|
1469
|
-
output_type=(
|
|
1470
|
-
lf_structured.annotation(invocation.schema.spec)
|
|
1471
|
-
if invocation.schema is not None else None
|
|
1472
|
-
),
|
|
1473
|
-
error=invocation.error,
|
|
1474
|
-
keep=True,
|
|
1475
|
-
)
|
|
1476
|
-
elif self.verbose:
|
|
1477
|
-
self.info(
|
|
1478
|
-
(
|
|
1479
|
-
f'Querying LLM succeeded in '
|
|
1480
|
-
f'{time.time() - invocation.start_time:.2f} seconds.'
|
|
1481
|
-
),
|
|
1482
|
-
lm=invocation.lm.model_id,
|
|
1483
|
-
output_type=(
|
|
1484
|
-
lf_structured.annotation(invocation.schema.spec)
|
|
1485
|
-
if invocation.schema is not None else None
|
|
1486
|
-
),
|
|
1487
|
-
keep=False,
|
|
1488
|
-
)
|
|
1677
|
+
self.event_handler.on_query_end(self, self._current_action, invocation)
|
|
1489
1678
|
|
|
1490
1679
|
with self.track_phase(phase), lf_structured.track_queries(
|
|
1491
1680
|
include_child_scopes=False,
|
|
1492
|
-
|
|
1493
|
-
|
|
1681
|
+
start_callback=_query_start,
|
|
1682
|
+
end_callback=_query_end,
|
|
1494
1683
|
) as queries:
|
|
1495
1684
|
try:
|
|
1496
1685
|
yield queries
|
|
@@ -34,6 +34,7 @@ class Bar(action_lib.Action):
|
|
|
34
34
|
time.sleep(self.simulate_execution_time)
|
|
35
35
|
session.query('bar', lm=lm)
|
|
36
36
|
session.add_metadata(note='bar')
|
|
37
|
+
session.update_progress('Query completed')
|
|
37
38
|
if self.simulate_action_error:
|
|
38
39
|
raise ValueError('Bar error')
|
|
39
40
|
return 2 + pg.contextual_value('baz', 0)
|
|
@@ -128,7 +129,7 @@ class SessionTest(unittest.TestCase):
|
|
|
128
129
|
self.assertIsNone(foo.result)
|
|
129
130
|
self.assertIsNone(foo.metadata)
|
|
130
131
|
|
|
131
|
-
session = action_lib.Session(id='agent@1')
|
|
132
|
+
session = action_lib.Session(id='agent@1', verbose=True)
|
|
132
133
|
self.assertEqual(session.id, 'agent@1')
|
|
133
134
|
self.assertFalse(session.has_started)
|
|
134
135
|
self.assertFalse(session.has_stopped)
|
|
@@ -137,7 +138,7 @@ class SessionTest(unittest.TestCase):
|
|
|
137
138
|
_ = session.to_html()
|
|
138
139
|
|
|
139
140
|
with session:
|
|
140
|
-
result = foo(session, lm=lm
|
|
141
|
+
result = foo(session, lm=lm)
|
|
141
142
|
|
|
142
143
|
self.assertTrue(session.has_started)
|
|
143
144
|
self.assertTrue(session.has_stopped)
|
|
@@ -375,7 +376,7 @@ class SessionTest(unittest.TestCase):
|
|
|
375
376
|
self.assertFalse(session.has_stopped)
|
|
376
377
|
|
|
377
378
|
session.start()
|
|
378
|
-
result = foo(session, lm=lm
|
|
379
|
+
result = foo(session, lm=lm)
|
|
379
380
|
session.end(result)
|
|
380
381
|
|
|
381
382
|
self.assertTrue(session.has_started)
|
|
@@ -395,7 +396,7 @@ class SessionTest(unittest.TestCase):
|
|
|
395
396
|
session = action_lib.Session(id='agent@1')
|
|
396
397
|
with self.assertRaisesRegex(ValueError, 'Bar error'):
|
|
397
398
|
with session:
|
|
398
|
-
foo(session, lm=lm
|
|
399
|
+
foo(session, lm=lm)
|
|
399
400
|
self.assertTrue(session.has_started)
|
|
400
401
|
self.assertTrue(session.has_stopped)
|
|
401
402
|
self.assertTrue(session.has_error)
|
|
@@ -408,7 +409,7 @@ class SessionTest(unittest.TestCase):
|
|
|
408
409
|
foo = Foo(1, simulate_action_error=True)
|
|
409
410
|
session = action_lib.Session(id='agent@1')
|
|
410
411
|
with self.assertRaisesRegex(ValueError, 'Please call `Session.start'):
|
|
411
|
-
foo(session, lm=lm
|
|
412
|
+
foo(session, lm=lm)
|
|
412
413
|
|
|
413
414
|
def test_succeed_with_multiple_actions(self):
|
|
414
415
|
lm = fake.StaticResponse('lm response')
|
|
@@ -489,6 +490,33 @@ class SessionTest(unittest.TestCase):
|
|
|
489
490
|
):
|
|
490
491
|
foo(lm=lm, max_execution_time=1.0)
|
|
491
492
|
|
|
493
|
+
def test_event_handler(self):
|
|
494
|
+
|
|
495
|
+
class MyActionHandler(pg.Object, action_lib.SessionEventHandler):
|
|
496
|
+
def _on_bound(self):
|
|
497
|
+
super()._on_bound()
|
|
498
|
+
self.progresses = []
|
|
499
|
+
|
|
500
|
+
def on_action_progress(self, session, action, title, **kwargs):
|
|
501
|
+
self.progresses.append((action.id, title))
|
|
502
|
+
|
|
503
|
+
handler = MyActionHandler()
|
|
504
|
+
session = action_lib.Session(
|
|
505
|
+
id='agent@1',
|
|
506
|
+
event_handler=action_lib.SessionEventHandlerChain(
|
|
507
|
+
handlers=[handler, action_lib.SessionLogging()]
|
|
508
|
+
)
|
|
509
|
+
)
|
|
510
|
+
bar = Bar()
|
|
511
|
+
with session:
|
|
512
|
+
bar(session, lm=fake.StaticResponse('lm response'))
|
|
513
|
+
session.update_progress('Trajectory completed')
|
|
514
|
+
|
|
515
|
+
self.assertEqual(handler.progresses, [
|
|
516
|
+
('agent@1:/a1', 'Query completed'),
|
|
517
|
+
('agent@1:', 'Trajectory completed'),
|
|
518
|
+
])
|
|
519
|
+
|
|
492
520
|
def test_log(self):
|
|
493
521
|
session = action_lib.Session()
|
|
494
522
|
session.debug('hi', x=1, y=2)
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
import contextlib
|
|
15
15
|
import io
|
|
16
16
|
import os
|
|
17
|
+
import sys
|
|
17
18
|
import tempfile
|
|
18
19
|
import unittest
|
|
19
20
|
|
|
@@ -49,6 +50,7 @@ class TqdmProgressTrackerTest(unittest.TestCase):
|
|
|
49
50
|
string_io = io.StringIO()
|
|
50
51
|
with contextlib.redirect_stderr(string_io):
|
|
51
52
|
_ = experiment.run(root_dir, 'new', plugins=[])
|
|
53
|
+
sys.stderr.flush()
|
|
52
54
|
self.assertIn('All: 100%', string_io.getvalue())
|
|
53
55
|
|
|
54
56
|
def test_with_example_ids(self):
|
|
@@ -59,6 +61,7 @@ class TqdmProgressTrackerTest(unittest.TestCase):
|
|
|
59
61
|
string_io = io.StringIO()
|
|
60
62
|
with contextlib.redirect_stderr(string_io):
|
|
61
63
|
_ = experiment.run(root_dir, 'new', example_ids=[1], plugins=[])
|
|
64
|
+
sys.stderr.flush()
|
|
62
65
|
self.assertIn('All: 100%', string_io.getvalue())
|
|
63
66
|
|
|
64
67
|
|
|
@@ -1209,14 +1209,14 @@ class _QueryTracker:
|
|
|
1209
1209
|
)
|
|
1210
1210
|
] = True
|
|
1211
1211
|
|
|
1212
|
-
|
|
1212
|
+
start_callback: Annotated[
|
|
1213
1213
|
Callable[[QueryInvocation], None] | None,
|
|
1214
1214
|
(
|
|
1215
1215
|
'A callback function to be called when a query is started.'
|
|
1216
1216
|
)
|
|
1217
1217
|
] = None
|
|
1218
1218
|
|
|
1219
|
-
|
|
1219
|
+
end_callback: Annotated[
|
|
1220
1220
|
Callable[[QueryInvocation], None] | None,
|
|
1221
1221
|
(
|
|
1222
1222
|
'A callback function to be called when a query is completed.'
|
|
@@ -1232,21 +1232,21 @@ class _QueryTracker:
|
|
|
1232
1232
|
|
|
1233
1233
|
def track(self, invocation: QueryInvocation) -> None:
|
|
1234
1234
|
self.tracked_queries.append(invocation)
|
|
1235
|
-
if self.
|
|
1236
|
-
self.
|
|
1235
|
+
if self.start_callback is not None:
|
|
1236
|
+
self.start_callback(invocation)
|
|
1237
1237
|
|
|
1238
1238
|
def mark_completed(self, invocation: QueryInvocation) -> None:
|
|
1239
1239
|
assert invocation in self.tracked_queries, invocation
|
|
1240
|
-
if self.
|
|
1241
|
-
self.
|
|
1240
|
+
if self.end_callback is not None:
|
|
1241
|
+
self.end_callback(invocation)
|
|
1242
1242
|
|
|
1243
1243
|
|
|
1244
1244
|
@contextlib.contextmanager
|
|
1245
1245
|
def track_queries(
|
|
1246
1246
|
include_child_scopes: bool = True,
|
|
1247
1247
|
*,
|
|
1248
|
-
|
|
1249
|
-
|
|
1248
|
+
start_callback: Callable[[QueryInvocation], None] | None = None,
|
|
1249
|
+
end_callback: Callable[[QueryInvocation], None] | None = None,
|
|
1250
1250
|
) -> Iterator[list[QueryInvocation]]:
|
|
1251
1251
|
"""Track all queries made during the context.
|
|
1252
1252
|
|
|
@@ -1264,8 +1264,8 @@ def track_queries(
|
|
|
1264
1264
|
include_child_scopes: If True, the queries made in child scopes will be
|
|
1265
1265
|
included in the returned list. Otherwise, only the queries made in the
|
|
1266
1266
|
current scope will be included.
|
|
1267
|
-
|
|
1268
|
-
|
|
1267
|
+
start_callback: A callback function to be called when a query is started.
|
|
1268
|
+
end_callback: A callback function to be called when a query is completed.
|
|
1269
1269
|
|
|
1270
1270
|
Yields:
|
|
1271
1271
|
A list of `QueryInvocation` objects representing the queries made during
|
|
@@ -1274,8 +1274,8 @@ def track_queries(
|
|
|
1274
1274
|
trackers = lf.context_value('__query_trackers__', [])
|
|
1275
1275
|
tracker = _QueryTracker(
|
|
1276
1276
|
include_child_scopes=include_child_scopes,
|
|
1277
|
-
|
|
1278
|
-
|
|
1277
|
+
start_callback=start_callback,
|
|
1278
|
+
end_callback=end_callback
|
|
1279
1279
|
)
|
|
1280
1280
|
|
|
1281
1281
|
with lf.context(
|
|
@@ -1590,7 +1590,7 @@ class TrackQueriesTest(unittest.TestCase):
|
|
|
1590
1590
|
'bar',
|
|
1591
1591
|
])
|
|
1592
1592
|
state = {}
|
|
1593
|
-
def
|
|
1593
|
+
def start_callback(query):
|
|
1594
1594
|
self.assertFalse(query.is_completed)
|
|
1595
1595
|
self.assertIsNone(query.end_time)
|
|
1596
1596
|
elapse1 = query.elapse
|
|
@@ -1615,7 +1615,7 @@ class TrackQueriesTest(unittest.TestCase):
|
|
|
1615
1615
|
state['end'] = query
|
|
1616
1616
|
|
|
1617
1617
|
with querying.track_queries(
|
|
1618
|
-
|
|
1618
|
+
start_callback=start_callback, end_callback=end_callback
|
|
1619
1619
|
) as queries:
|
|
1620
1620
|
querying.query('foo', lm=lm)
|
|
1621
1621
|
self.assertIs(state['start'], queries[0])
|
|
@@ -1626,7 +1626,7 @@ class TrackQueriesTest(unittest.TestCase):
|
|
|
1626
1626
|
'bar',
|
|
1627
1627
|
])
|
|
1628
1628
|
state = {}
|
|
1629
|
-
def
|
|
1629
|
+
def start_callback(query):
|
|
1630
1630
|
self.assertFalse(query.is_completed)
|
|
1631
1631
|
self.assertIsNone(query.end_time)
|
|
1632
1632
|
self.assertIsNotNone(query.usage_summary)
|
|
@@ -1648,7 +1648,7 @@ class TrackQueriesTest(unittest.TestCase):
|
|
|
1648
1648
|
|
|
1649
1649
|
with self.assertRaises(mapping.MappingError):
|
|
1650
1650
|
with querying.track_queries(
|
|
1651
|
-
|
|
1651
|
+
start_callback=start_callback, end_callback=end_callback
|
|
1652
1652
|
) as queries:
|
|
1653
1653
|
querying.query('foo', int, lm=lm)
|
|
1654
1654
|
self.assertIs(state['start'], queries[0])
|
|
@@ -34,11 +34,11 @@ langfun/core/subscription.py,sha256=euawEuSZP-BHydaT-AQpfYFL0m5pWPGcW0upFhrojqc,
|
|
|
34
34
|
langfun/core/subscription_test.py,sha256=Y4ZdbZEwm83YNZBxHff0QR4QUa4rdaNXA3_jfIcArBo,8717
|
|
35
35
|
langfun/core/template.py,sha256=GSOZ3OcmRtw-q05GdHrE-Y4-2MGDYsTRKmWGvVPbdhE,24962
|
|
36
36
|
langfun/core/template_test.py,sha256=AQv_m9qE93WxhEhSlm1xaBgB4hu0UVtA53dljngkUW0,17090
|
|
37
|
-
langfun/core/agentic/__init__.py,sha256=
|
|
38
|
-
langfun/core/agentic/action.py,sha256=
|
|
37
|
+
langfun/core/agentic/__init__.py,sha256=vsnuvjaz9-nysBjdihGf43JC8AyLPhPJwIOevyONyAQ,1517
|
|
38
|
+
langfun/core/agentic/action.py,sha256=ojwaPIV_a_khKPR6x1Fk5i2dsUTSe3VjKaxnZ92b0nE,58243
|
|
39
39
|
langfun/core/agentic/action_eval.py,sha256=YTilyUEkJl_8FVMgdfO17PurWWaEJ6oA15CuefJJRLk,4887
|
|
40
40
|
langfun/core/agentic/action_eval_test.py,sha256=7AkOwNbUX-ZgR1R0a7bvUZ5abNTUV7blf_8Mnrwb-II,2811
|
|
41
|
-
langfun/core/agentic/action_test.py,sha256=
|
|
41
|
+
langfun/core/agentic/action_test.py,sha256=a2D4FOuob7MviuPZR2Wy6xNKnjlTLxhK8HUy8WIyt08,19076
|
|
42
42
|
langfun/core/coding/__init__.py,sha256=5utju_fwEsImaiftx4oXKl9FAM8p281k8-Esdh_-m1w,835
|
|
43
43
|
langfun/core/coding/python/__init__.py,sha256=yTXm92oLpQb4A-fZ2qy-bzfhPYND7B-oidtbv1PNaX0,1678
|
|
44
44
|
langfun/core/coding/python/correction.py,sha256=4PD76Xfv36Xrm8Ji3-GgGDNImtcDqWfMw3z6ircJMlM,7285
|
|
@@ -85,7 +85,7 @@ langfun/core/eval/v2/metrics_test.py,sha256=LibZXvWEJDVRY-Mza_bQT-SbmbXCHUnFhL7Z
|
|
|
85
85
|
langfun/core/eval/v2/progress.py,sha256=azZgssQgNdv3IgjKEaQBuGI5ucFDNbdi02P4z_nQ8GE,10292
|
|
86
86
|
langfun/core/eval/v2/progress_test.py,sha256=YU7VHzmy5knPZwj9vpBN3rQQH2tukj9eKHkuBCI62h8,2540
|
|
87
87
|
langfun/core/eval/v2/progress_tracking.py,sha256=zNhNPGlnJnHELEfFpbTMCSXFn8d1IJ57OOYkfFaBFfM,6097
|
|
88
|
-
langfun/core/eval/v2/progress_tracking_test.py,sha256=
|
|
88
|
+
langfun/core/eval/v2/progress_tracking_test.py,sha256=0d13LQyUKy1_bkscN0-vcBcQ36HNp89kgJ_N0jl2URM,2339
|
|
89
89
|
langfun/core/eval/v2/reporting.py,sha256=yUIPCAMnp7InIzpv1DDWrcLO-75iiOUTpscj7smkfrA,8335
|
|
90
90
|
langfun/core/eval/v2/reporting_test.py,sha256=CMK-vwho8cNRJwlbkCqm_v5fykE7Y3V6SaIOCY0CDyA,5671
|
|
91
91
|
langfun/core/eval/v2/runners.py,sha256=bEniZDNu44AQgvqpwLsvBU4V_7WltAe-NPhYgIsLj1E,16848
|
|
@@ -146,8 +146,8 @@ langfun/core/structured/mapping.py,sha256=1YBW8PKpJKXS7DKukfzKNioL84PrKUcB4KOUud
|
|
|
146
146
|
langfun/core/structured/mapping_test.py,sha256=OntYvfDitAf0tAnzQty3YS90vyEn6FY1Mi93r_ViEk8,9594
|
|
147
147
|
langfun/core/structured/parsing.py,sha256=bLi7o1AdyDWc6TwxmYg70t_oTgmLkJWBafakdF8n2RI,14195
|
|
148
148
|
langfun/core/structured/parsing_test.py,sha256=vRfCSzA9q7C1cElkAnDvbRepULEa_vclqDIv-heypDw,22745
|
|
149
|
-
langfun/core/structured/querying.py,sha256=
|
|
150
|
-
langfun/core/structured/querying_test.py,sha256=
|
|
149
|
+
langfun/core/structured/querying.py,sha256=P8miXfg6Q6HEvChnrLYUq5frIYREtduoTPTyQuysMGc,39649
|
|
150
|
+
langfun/core/structured/querying_test.py,sha256=ulRgKniUd-pMEWFKo_B5NY0pv_mVCSJ_-34QIGspZRo,50353
|
|
151
151
|
langfun/core/structured/schema.py,sha256=xtgrr3t5tcYQ2gi_fkTKz2IgDMf84gpiykmBdfnV6Io,29486
|
|
152
152
|
langfun/core/structured/schema_generation.py,sha256=pEWeTd8tQWYnEHukas6GVl4uGerLsQ2aNybtnm4Qgxc,5352
|
|
153
153
|
langfun/core/structured/schema_generation_test.py,sha256=RM9s71kMNg2jTePwInkiW9fK1ACN37eyPeF8OII-0zw,2950
|
|
@@ -181,8 +181,8 @@ langfun/env/event_handlers/event_logger.py,sha256=3dbPjBe53dBgntYHlyLlj_77hVecPS
|
|
|
181
181
|
langfun/env/event_handlers/event_logger_test.py,sha256=PGof3rPllNnyzs3Yp8kaOHLeTkVrzUgCJwlODTrVRKI,9111
|
|
182
182
|
langfun/env/event_handlers/metric_writer.py,sha256=NgJKsd6xWOtEd0IjYi7coGEaqGYkkPcDjXN9CQ3vxPU,18043
|
|
183
183
|
langfun/env/event_handlers/metric_writer_test.py,sha256=flRqK10wonhJk4idGD_8jjEjrfjgH0R-qcu-7Bj1G5s,5335
|
|
184
|
-
langfun-0.1.2.
|
|
185
|
-
langfun-0.1.2.
|
|
186
|
-
langfun-0.1.2.
|
|
187
|
-
langfun-0.1.2.
|
|
188
|
-
langfun-0.1.2.
|
|
184
|
+
langfun-0.1.2.dev202510160805.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
|
185
|
+
langfun-0.1.2.dev202510160805.dist-info/METADATA,sha256=au899IUDN58JgGfkb7QaynFP9i-qZruve58dFnL8b4g,7380
|
|
186
|
+
langfun-0.1.2.dev202510160805.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
187
|
+
langfun-0.1.2.dev202510160805.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
|
188
|
+
langfun-0.1.2.dev202510160805.dist-info/RECORD,,
|
|
File without changes
|
{langfun-0.1.2.dev202510140804.dist-info → langfun-0.1.2.dev202510160805.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|
{langfun-0.1.2.dev202510140804.dist-info → langfun-0.1.2.dev202510160805.dist-info}/top_level.txt
RENAMED
|
File without changes
|