langfun 0.1.2.dev202505130804__py3-none-any.whl → 0.1.2.dev202505150805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/core/agentic/action.py +237 -108
- langfun/core/agentic/action_eval.py +4 -6
- langfun/core/agentic/action_test.py +15 -9
- langfun/core/coding/python/correction.py +4 -0
- langfun/core/console.py +6 -3
- langfun/core/language_model.py +4 -2
- langfun/core/llms/anthropic.py +4 -8
- langfun/core/llms/anthropic_test.py +38 -13
- langfun/core/llms/gemini.py +2 -2
- langfun/core/logging.py +3 -4
- langfun/core/structured/mapping.py +6 -0
- langfun/core/structured/querying.py +324 -91
- langfun/core/structured/querying_test.py +242 -2
- langfun/core/structured/schema.py +8 -0
- langfun/core/structured/schema_generation.py +1 -0
- langfun/core/structured/schema_test.py +6 -3
- {langfun-0.1.2.dev202505130804.dist-info → langfun-0.1.2.dev202505150805.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202505130804.dist-info → langfun-0.1.2.dev202505150805.dist-info}/RECORD +21 -21
- {langfun-0.1.2.dev202505130804.dist-info → langfun-0.1.2.dev202505150805.dist-info}/WHEEL +1 -1
- {langfun-0.1.2.dev202505130804.dist-info → langfun-0.1.2.dev202505150805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202505130804.dist-info → langfun-0.1.2.dev202505150805.dist-info}/top_level.txt +0 -0
langfun/core/agentic/action.py
CHANGED
@@ -209,7 +209,7 @@ class Action(pg.Object):
|
|
209
209
|
) -> Any:
|
210
210
|
"""Executes the action."""
|
211
211
|
if session is None:
|
212
|
-
session = Session()
|
212
|
+
session = Session(verbose=verbose)
|
213
213
|
session.start()
|
214
214
|
|
215
215
|
if show_progress:
|
@@ -220,45 +220,13 @@ class Action(pg.Object):
|
|
220
220
|
else:
|
221
221
|
self._session = None
|
222
222
|
|
223
|
-
with session.track_action(self)
|
224
|
-
if verbose:
|
225
|
-
session.info('Action execution started.', keep=False, action=self)
|
226
|
-
|
223
|
+
with session.track_action(self):
|
227
224
|
try:
|
228
|
-
result = self.call(session=session,
|
225
|
+
result = self.call(session=session, **kwargs)
|
229
226
|
self._invocation.end(result)
|
230
|
-
if verbose:
|
231
|
-
session.info(
|
232
|
-
(
|
233
|
-
f'Action execution succeeded in '
|
234
|
-
f'{self._invocation.execution.elapse:.2f} seconds.'
|
235
|
-
),
|
236
|
-
keep=False,
|
237
|
-
result=result
|
238
|
-
)
|
239
227
|
except BaseException as e:
|
240
228
|
error = pg.utils.ErrorInfo.from_exception(e)
|
241
229
|
self._invocation.end(result=None, error=error)
|
242
|
-
if invocation.parent_action is session.root:
|
243
|
-
session.error(
|
244
|
-
(
|
245
|
-
f'Top-level action execution failed in '
|
246
|
-
f'{self._invocation.execution.elapse:.2f} seconds.'
|
247
|
-
),
|
248
|
-
keep=True,
|
249
|
-
action=self,
|
250
|
-
error=error
|
251
|
-
)
|
252
|
-
else:
|
253
|
-
session.warning(
|
254
|
-
(
|
255
|
-
f'Action execution failed in '
|
256
|
-
f'{self._invocation.execution.elapse:.2f} seconds.'
|
257
|
-
),
|
258
|
-
keep=False,
|
259
|
-
action=self,
|
260
|
-
error=error
|
261
|
-
)
|
262
230
|
if self._session is not None:
|
263
231
|
self._session.end(result=None, error=error)
|
264
232
|
raise
|
@@ -477,21 +445,26 @@ class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
477
445
|
def __getitem__(self, index: int) -> TracedItem:
|
478
446
|
return self.items[index]
|
479
447
|
|
448
|
+
def merge_usage_summary(self, usage_summary: lf.UsageSummary) -> None:
|
449
|
+
if usage_summary.total.num_requests == 0:
|
450
|
+
return
|
451
|
+
current_invocation = self
|
452
|
+
while current_invocation is not None:
|
453
|
+
current_invocation.usage_summary.merge(usage_summary)
|
454
|
+
current_invocation = typing.cast(
|
455
|
+
ExecutionTrace,
|
456
|
+
current_invocation.sym_ancestor(
|
457
|
+
lambda x: isinstance(x, ExecutionTrace)
|
458
|
+
)
|
459
|
+
)
|
460
|
+
|
480
461
|
def append(self, item: TracedItem) -> None:
|
481
462
|
"""Appends an item to the sequence."""
|
482
463
|
with pg.notify_on_change(False):
|
483
464
|
self.items.append(item)
|
484
465
|
|
485
466
|
if isinstance(item, lf_structured.QueryInvocation):
|
486
|
-
|
487
|
-
while current_invocation is not None:
|
488
|
-
current_invocation.usage_summary.merge(item.usage_summary)
|
489
|
-
current_invocation = typing.cast(
|
490
|
-
ExecutionTrace,
|
491
|
-
current_invocation.sym_ancestor(
|
492
|
-
lambda x: isinstance(x, ExecutionTrace)
|
493
|
-
)
|
494
|
-
)
|
467
|
+
self.merge_usage_summary(item.usage_summary)
|
495
468
|
|
496
469
|
if self._tab_control is not None:
|
497
470
|
self._tab_control.append(self._execution_item_tab(item))
|
@@ -519,15 +492,46 @@ class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
519
492
|
def execution_summary(self) -> dict[str, Any]:
|
520
493
|
"""Execution summary string."""
|
521
494
|
return pg.Dict(
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
495
|
+
subtree=dict(
|
496
|
+
num_actions=len(self.all_actions),
|
497
|
+
num_action_failures=len([
|
498
|
+
a for a in self.all_actions if a.has_error
|
499
|
+
]),
|
500
|
+
num_queries=len(self.all_queries),
|
501
|
+
num_oop_failures=len([
|
502
|
+
q for q in self.all_queries if q.has_oop_error
|
503
|
+
]),
|
504
|
+
num_non_oop_failures=len([
|
505
|
+
q for q in self.all_queries
|
506
|
+
if q.has_error and not q.has_oop_error
|
507
|
+
]),
|
508
|
+
total_query_time=sum(q.elapse for q in self.all_queries),
|
509
|
+
),
|
510
|
+
current_level=dict(
|
511
|
+
num_actions=len(self.actions),
|
512
|
+
num_action_failures=len([
|
513
|
+
a for a in self.actions if a.has_error
|
514
|
+
]),
|
515
|
+
num_queries=len(self.queries),
|
516
|
+
num_oop_failures=len([
|
517
|
+
q for q in self.queries if q.has_oop_error
|
518
|
+
]),
|
519
|
+
num_non_oop_failures=len([
|
520
|
+
q for q in self.queries
|
521
|
+
if q.has_error and not q.has_oop_error
|
522
|
+
]),
|
523
|
+
execution_breakdown=[
|
524
|
+
dict(
|
525
|
+
action=action.action.__class__.__name__,
|
526
|
+
usage=dict(
|
527
|
+
total_tokens=action.usage_summary.total.total_tokens,
|
528
|
+
estimated_cost=action.usage_summary.total.estimated_cost,
|
529
|
+
),
|
530
|
+
execution_time=action.execution.elapse,
|
531
|
+
)
|
532
|
+
for action in self.actions
|
533
|
+
]
|
534
|
+
)
|
531
535
|
)
|
532
536
|
|
533
537
|
#
|
@@ -894,6 +898,11 @@ class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
894
898
|
"""Returns the usage summary of the action."""
|
895
899
|
return self.execution.usage_summary
|
896
900
|
|
901
|
+
@property
|
902
|
+
def elapse(self) -> float:
|
903
|
+
"""Returns the elapsed time of the action."""
|
904
|
+
return self.execution.elapse
|
905
|
+
|
897
906
|
def start(self) -> None:
|
898
907
|
"""Starts the execution of the action."""
|
899
908
|
self.execution.start()
|
@@ -1025,9 +1034,6 @@ class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
1025
1034
|
self.usage_summary.to_html( # pylint: disable=g-long-ternary
|
1026
1035
|
extra_flags=dict(as_badge=True)
|
1027
1036
|
)
|
1028
|
-
if (interactive
|
1029
|
-
or self.usage_summary.total.num_requests > 0)
|
1030
|
-
else None
|
1031
1037
|
),
|
1032
1038
|
],
|
1033
1039
|
css_classes=['execution-tab-title']
|
@@ -1069,12 +1075,78 @@ class RootAction(Action):
|
|
1069
1075
|
class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
1070
1076
|
"""Session for performing an agentic task."""
|
1071
1077
|
|
1072
|
-
root:
|
1078
|
+
root: Annotated[
|
1079
|
+
ActionInvocation,
|
1080
|
+
'The root action invocation of the session.'
|
1081
|
+
] = ActionInvocation(RootAction())
|
1082
|
+
|
1073
1083
|
id: Annotated[
|
1074
1084
|
str | None,
|
1075
1085
|
'An optional identifier for the sessin, which will be used for logging.'
|
1076
1086
|
] = None
|
1077
1087
|
|
1088
|
+
verbose: Annotated[
|
1089
|
+
bool,
|
1090
|
+
(
|
1091
|
+
'If True, the session will be logged with verbose action and query '
|
1092
|
+
'activities.'
|
1093
|
+
)
|
1094
|
+
] = False
|
1095
|
+
|
1096
|
+
#
|
1097
|
+
# Shortcut methods for accessing the root action invocation.
|
1098
|
+
#
|
1099
|
+
|
1100
|
+
@property
|
1101
|
+
def all_queries(self) -> list[lf_structured.QueryInvocation]:
|
1102
|
+
"""Returns all queries made by the session."""
|
1103
|
+
return self.root.all_queries
|
1104
|
+
|
1105
|
+
@property
|
1106
|
+
def all_actions(self) -> list[ActionInvocation]:
|
1107
|
+
"""Returns all actions made by the session."""
|
1108
|
+
return self.root.all_actions
|
1109
|
+
|
1110
|
+
@property
|
1111
|
+
def all_logs(self) -> list[lf.logging.LogEntry]:
|
1112
|
+
"""Returns all logs made by the session."""
|
1113
|
+
return self.root.all_logs
|
1114
|
+
|
1115
|
+
@property
|
1116
|
+
def usage_summary(self) -> lf.UsageSummary:
|
1117
|
+
"""Returns the usage summary of the session."""
|
1118
|
+
return self.root.usage_summary
|
1119
|
+
|
1120
|
+
@property
|
1121
|
+
def has_started(self) -> bool:
|
1122
|
+
"""Returns True if the session has started."""
|
1123
|
+
return self.root.execution.has_started
|
1124
|
+
|
1125
|
+
@property
|
1126
|
+
def has_stopped(self) -> bool:
|
1127
|
+
"""Returns True if the session has stopped."""
|
1128
|
+
return self.root.execution.has_stopped
|
1129
|
+
|
1130
|
+
@property
|
1131
|
+
def has_error(self) -> bool:
|
1132
|
+
"""Returns True if the session has an error."""
|
1133
|
+
return self.root.has_error
|
1134
|
+
|
1135
|
+
@property
|
1136
|
+
def final_result(self) -> Any:
|
1137
|
+
"""Returns the final result of the session."""
|
1138
|
+
return self.root.result
|
1139
|
+
|
1140
|
+
@property
|
1141
|
+
def final_error(self) -> pg.utils.ErrorInfo | None:
|
1142
|
+
"""Returns the error of the session."""
|
1143
|
+
return self.root.error
|
1144
|
+
|
1145
|
+
@property
|
1146
|
+
def elapse(self) -> float:
|
1147
|
+
"""Returns the elapsed time of the session."""
|
1148
|
+
return self.root.elapse
|
1149
|
+
|
1078
1150
|
# NOTE(daiyip): Action execution may involve multi-threading, hence current
|
1079
1151
|
# action and execution are thread-local.
|
1080
1152
|
|
@@ -1118,6 +1190,20 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
1118
1190
|
metadata: dict[str, Any] | None = None,
|
1119
1191
|
) -> None:
|
1120
1192
|
"""Ends the session."""
|
1193
|
+
if error is not None:
|
1194
|
+
self.error(
|
1195
|
+
f'Trajectory failed in {self.elapse:.2f} seconds.',
|
1196
|
+
error=error,
|
1197
|
+
metadata=metadata,
|
1198
|
+
keep=True,
|
1199
|
+
)
|
1200
|
+
elif self.verbose:
|
1201
|
+
self.info(
|
1202
|
+
f'Trajectory succeeded in {self.elapse:.2f} seconds.',
|
1203
|
+
result=result,
|
1204
|
+
metadata=metadata,
|
1205
|
+
keep=False,
|
1206
|
+
)
|
1121
1207
|
self.root.end(result, error, metadata)
|
1122
1208
|
|
1123
1209
|
def __enter__(self):
|
@@ -1169,8 +1255,34 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
1169
1255
|
self._current_execution = invocation.execution
|
1170
1256
|
# Start the execution of the current action.
|
1171
1257
|
self._current_action.start()
|
1258
|
+
if self.verbose:
|
1259
|
+
self.info(
|
1260
|
+
'Action execution started.',
|
1261
|
+
action=invocation.action,
|
1262
|
+
keep=False,
|
1263
|
+
)
|
1172
1264
|
yield invocation
|
1173
1265
|
finally:
|
1266
|
+
if invocation.has_error:
|
1267
|
+
self.warning(
|
1268
|
+
(
|
1269
|
+
f'Action execution failed in '
|
1270
|
+
f'{invocation.execution.elapse:.2f} seconds.'
|
1271
|
+
),
|
1272
|
+
action=invocation.action,
|
1273
|
+
error=invocation.error,
|
1274
|
+
keep=True,
|
1275
|
+
)
|
1276
|
+
elif self.verbose:
|
1277
|
+
self.info(
|
1278
|
+
(
|
1279
|
+
f'Action execution succeeded in '
|
1280
|
+
f'{invocation.execution.elapse:.2f} seconds.'
|
1281
|
+
),
|
1282
|
+
action=invocation.action,
|
1283
|
+
result=invocation.result,
|
1284
|
+
keep=False,
|
1285
|
+
)
|
1174
1286
|
self._current_execution = parent_execution
|
1175
1287
|
self._current_action = parent_action
|
1176
1288
|
|
@@ -1208,18 +1320,63 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
1208
1320
|
A list of `lf.QueryInvocation` objects, each for a single `lf.query`
|
1209
1321
|
call.
|
1210
1322
|
"""
|
1211
|
-
|
1212
|
-
|
1213
|
-
|
1214
|
-
|
1215
|
-
|
1216
|
-
|
1217
|
-
|
1218
|
-
|
1219
|
-
|
1220
|
-
|
1221
|
-
|
1222
|
-
|
1323
|
+
def _query_start(invocation: lf_structured.QueryInvocation):
|
1324
|
+
execution = self._current_execution
|
1325
|
+
invocation.rebind(
|
1326
|
+
id=f'{execution.id}/q{len(execution.queries) + 1}',
|
1327
|
+
skip_notification=False, raise_on_no_change=False
|
1328
|
+
)
|
1329
|
+
execution.append(invocation)
|
1330
|
+
if self.verbose:
|
1331
|
+
self.info(
|
1332
|
+
'Querying LLM started.',
|
1333
|
+
lm=invocation.lm.model_id,
|
1334
|
+
output_type=(
|
1335
|
+
lf_structured.annotation(invocation.schema.spec)
|
1336
|
+
if invocation.schema is not None else None
|
1337
|
+
),
|
1338
|
+
keep=False,
|
1339
|
+
)
|
1340
|
+
|
1341
|
+
def _query_end(invocation: lf_structured.QueryInvocation):
|
1342
|
+
self._current_execution.merge_usage_summary(invocation.usage_summary)
|
1343
|
+
if invocation.has_error:
|
1344
|
+
self.warning(
|
1345
|
+
(
|
1346
|
+
f'Querying LLM failed in '
|
1347
|
+
f'{time.time() - invocation.start_time:.2f} seconds.'
|
1348
|
+
),
|
1349
|
+
lm=invocation.lm.model_id,
|
1350
|
+
output_type=(
|
1351
|
+
lf_structured.annotation(invocation.schema.spec)
|
1352
|
+
if invocation.schema is not None else None
|
1353
|
+
),
|
1354
|
+
error=invocation.error,
|
1355
|
+
keep=True,
|
1356
|
+
)
|
1357
|
+
elif self.verbose:
|
1358
|
+
self.info(
|
1359
|
+
(
|
1360
|
+
f'Querying LLM succeeded in '
|
1361
|
+
f'{time.time() - invocation.start_time:.2f} seconds.'
|
1362
|
+
),
|
1363
|
+
lm=invocation.lm.model_id,
|
1364
|
+
output_type=(
|
1365
|
+
lf_structured.annotation(invocation.schema.spec)
|
1366
|
+
if invocation.schema is not None else None
|
1367
|
+
),
|
1368
|
+
keep=False,
|
1369
|
+
)
|
1370
|
+
|
1371
|
+
with self.track_phase(phase), lf_structured.track_queries(
|
1372
|
+
include_child_scopes=False,
|
1373
|
+
start_callabck=_query_start,
|
1374
|
+
end_callabck=_query_end,
|
1375
|
+
) as queries:
|
1376
|
+
try:
|
1377
|
+
yield queries
|
1378
|
+
finally:
|
1379
|
+
pass
|
1223
1380
|
|
1224
1381
|
#
|
1225
1382
|
# Operations with activity tracking.
|
@@ -1272,24 +1429,14 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
1272
1429
|
The result of the query.
|
1273
1430
|
"""
|
1274
1431
|
with self.track_queries():
|
1275
|
-
|
1276
|
-
|
1277
|
-
|
1278
|
-
|
1279
|
-
|
1280
|
-
|
1281
|
-
|
1282
|
-
|
1283
|
-
**kwargs
|
1284
|
-
)
|
1285
|
-
except BaseException as e:
|
1286
|
-
elapse = time.time() - start_time
|
1287
|
-
self.warning(
|
1288
|
-
f'Failed to query LLM ({lm.model_id}) in {elapse:.2f} seconds.',
|
1289
|
-
error=pg.utils.ErrorInfo.from_exception(e),
|
1290
|
-
keep=False,
|
1291
|
-
)
|
1292
|
-
raise
|
1432
|
+
return lf_structured.query(
|
1433
|
+
prompt,
|
1434
|
+
schema=schema,
|
1435
|
+
default=default,
|
1436
|
+
lm=lm,
|
1437
|
+
examples=examples,
|
1438
|
+
**kwargs
|
1439
|
+
)
|
1293
1440
|
|
1294
1441
|
def concurrent_map(
|
1295
1442
|
self,
|
@@ -1437,7 +1584,9 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
1437
1584
|
for_action = self._current_action
|
1438
1585
|
elif isinstance(for_action, Action):
|
1439
1586
|
for_action = for_action.invocation
|
1440
|
-
assert for_action is not None
|
1587
|
+
assert for_action is not None, (
|
1588
|
+
f'Action must be called before it can be logged: {for_action}'
|
1589
|
+
)
|
1441
1590
|
|
1442
1591
|
log_entry = lf.logging.log(
|
1443
1592
|
level,
|
@@ -1522,26 +1671,6 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
1522
1671
|
result=self.root
|
1523
1672
|
)
|
1524
1673
|
|
1525
|
-
@property
|
1526
|
-
def final_result(self) -> Any:
|
1527
|
-
"""Returns the final result of the session."""
|
1528
|
-
return self.root.result
|
1529
|
-
|
1530
|
-
@property
|
1531
|
-
def has_started(self) -> bool:
|
1532
|
-
"""Returns whether the session has started."""
|
1533
|
-
return self.root.execution.has_started
|
1534
|
-
|
1535
|
-
@property
|
1536
|
-
def has_stopped(self) -> bool:
|
1537
|
-
"""Returns whether the session has stopped."""
|
1538
|
-
return self.root.execution.has_stopped
|
1539
|
-
|
1540
|
-
@property
|
1541
|
-
def has_error(self) -> bool:
|
1542
|
-
"""Returns whether the session has an error."""
|
1543
|
-
return self.root.has_error
|
1544
|
-
|
1545
1674
|
@property
|
1546
1675
|
def current_action(self) -> ActionInvocation:
|
1547
1676
|
"""Returns the current invocation."""
|
@@ -36,17 +36,15 @@ class ActionEval(lf.eval.v2.Evaluation):
|
|
36
36
|
action = example_input.action
|
37
37
|
|
38
38
|
# We explicitly create a session here to use a custom session ID.
|
39
|
-
with action_lib.Session(
|
39
|
+
with action_lib.Session(
|
40
|
+
id=f'{self.id}#example-{example.id}', verbose=True
|
41
|
+
) as session:
|
40
42
|
|
41
43
|
# NOTE(daiyip): Setting session as metadata before action execution, so we
|
42
44
|
# could use `Evaluation.state.in_progress_examples` to access the session
|
43
45
|
# for status reporting from other threads.
|
44
46
|
example.metadata['session'] = session
|
45
|
-
|
46
|
-
with lf.logging.use_log_level('fatal'):
|
47
|
-
kwargs = self.action_args.copy()
|
48
|
-
kwargs.update(verbose=True)
|
49
|
-
action(session=session, **kwargs)
|
47
|
+
action(session=session, **self.action_args)
|
50
48
|
|
51
49
|
return session.final_result, dict(session=session)
|
52
50
|
|
@@ -155,15 +155,19 @@ class SessionTest(unittest.TestCase):
|
|
155
155
|
)
|
156
156
|
|
157
157
|
# The root space should have one action (foo), no queries, and no logs.
|
158
|
-
self.assertEqual(len(
|
159
|
-
self.assertEqual(len(
|
160
|
-
self.assertEqual(len(
|
158
|
+
self.assertEqual(len(root.actions), 1)
|
159
|
+
self.assertEqual(len(root.queries), 0)
|
160
|
+
self.assertEqual(len(root.logs), 0)
|
161
161
|
# 1 query from Bar, 2 from Foo and 3 from parallel executions.
|
162
|
-
self.assertEqual(len(
|
162
|
+
self.assertEqual(len(session.all_queries), 6)
|
163
|
+
self.assertEqual(len(root.all_queries), 6)
|
163
164
|
# 2 actions: Foo and Bar.
|
164
|
-
self.assertEqual(len(
|
165
|
+
self.assertEqual(len(session.all_actions), 2)
|
166
|
+
self.assertEqual(len(root.all_actions), 2)
|
165
167
|
# 1 log from Bar and 1 from Foo.
|
166
|
-
self.assertEqual(len(
|
168
|
+
self.assertEqual(len(session.all_logs), 2)
|
169
|
+
self.assertEqual(len(root.all_logs), 2)
|
170
|
+
self.assertIs(session.usage_summary, root.usage_summary)
|
167
171
|
self.assertEqual(root.usage_summary.total.num_requests, 6)
|
168
172
|
|
169
173
|
# Inspecting the top-level action (Foo)
|
@@ -276,7 +280,7 @@ class SessionTest(unittest.TestCase):
|
|
276
280
|
foo_invocation = root.execution[0]
|
277
281
|
self.assertIsInstance(foo_invocation, action_lib.ActionInvocation)
|
278
282
|
self.assertTrue(foo_invocation.has_error)
|
279
|
-
self.assertEqual(len(foo_invocation.execution.items),
|
283
|
+
self.assertEqual(len(foo_invocation.execution.items), 3)
|
280
284
|
|
281
285
|
def test_succeeded_with_implicit_session(self):
|
282
286
|
lm = fake.StaticResponse('lm response')
|
@@ -304,7 +308,7 @@ class SessionTest(unittest.TestCase):
|
|
304
308
|
self.assertTrue(session.has_started)
|
305
309
|
self.assertTrue(session.has_stopped)
|
306
310
|
self.assertTrue(session.has_error)
|
307
|
-
self.assertIsInstance(session.
|
311
|
+
self.assertIsInstance(session.final_error, pg.utils.ErrorInfo)
|
308
312
|
self.assertIn('Bar error', str(session.root.error))
|
309
313
|
|
310
314
|
def test_succeeded_with_explicit_session(self):
|
@@ -409,7 +413,9 @@ class SessionTest(unittest.TestCase):
|
|
409
413
|
self.assertTrue(session.has_stopped)
|
410
414
|
self.assertTrue(session.has_error)
|
411
415
|
self.assertIsInstance(session.root.error, pg.utils.ErrorInfo)
|
412
|
-
self.assertEqual(len(session.root.execution),
|
416
|
+
self.assertEqual(len(session.root.execution), 3)
|
417
|
+
self.assertEqual(len(session.root.actions), 2)
|
418
|
+
self.assertEqual(len(session.root.logs), 1)
|
413
419
|
self.assertFalse(session.root.execution[0].has_error)
|
414
420
|
self.assertTrue(session.root.execution[1].has_error)
|
415
421
|
|
@@ -37,6 +37,7 @@ def run_with_correction(
|
|
37
37
|
lm: lf.LanguageModel | None = None,
|
38
38
|
max_attempts: int = 5,
|
39
39
|
sandbox: bool | None = None,
|
40
|
+
permission: pg.coding.CodePermission = pg.coding.CodePermission.ALL,
|
40
41
|
timeout: int | None = 5,
|
41
42
|
returns_code: bool = False,
|
42
43
|
returns_stdout: bool = False,
|
@@ -58,6 +59,7 @@ def run_with_correction(
|
|
58
59
|
process. If None, run in sandbox first, if the output could not be
|
59
60
|
serialized and pass to current process, run the code again in current
|
60
61
|
process.
|
62
|
+
permission: The permission to run the code.
|
61
63
|
timeout: The timeout for running the corrected code. If None, there is no
|
62
64
|
timeout. Applicable only when sandbox is set to True.
|
63
65
|
returns_code: If True, the return value is a tuple of (result, final code).
|
@@ -88,6 +90,7 @@ def run_with_correction(
|
|
88
90
|
global_vars=global_vars,
|
89
91
|
sandbox=sandbox,
|
90
92
|
timeout=timeout,
|
93
|
+
permission=permission,
|
91
94
|
returns_stdout=returns_stdout,
|
92
95
|
outputs_intermediate=outputs_intermediate,
|
93
96
|
)
|
@@ -102,6 +105,7 @@ def run_with_correction(
|
|
102
105
|
global_vars=global_vars,
|
103
106
|
sandbox=sandbox,
|
104
107
|
timeout=timeout,
|
108
|
+
permission=permission,
|
105
109
|
outputs_intermediate=outputs_intermediate,
|
106
110
|
)
|
107
111
|
)
|
langfun/core/console.py
CHANGED
@@ -52,10 +52,13 @@ def write(
|
|
52
52
|
)
|
53
53
|
|
54
54
|
|
55
|
+
_notebook = None
|
55
56
|
try:
|
56
|
-
|
57
|
-
|
58
|
-
|
57
|
+
ipython_module = sys.modules['IPython']
|
58
|
+
if 'IPKernelApp' in ipython_module.get_ipython().config:
|
59
|
+
_notebook = ipython_module.display
|
60
|
+
except (KeyError, AttributeError): # pylint: disable=broad-except
|
61
|
+
pass
|
59
62
|
|
60
63
|
|
61
64
|
def under_notebook() -> bool:
|
langfun/core/language_model.py
CHANGED
@@ -1453,7 +1453,8 @@ class UsageSummary(pg.Object, pg.views.HtmlTreeView.Extension):
|
|
1453
1453
|
self._usage_badge.update(
|
1454
1454
|
self._badge_text(),
|
1455
1455
|
tooltip=pg.format(
|
1456
|
-
self, verbose=False, custom_format=self._tooltip_format
|
1456
|
+
self, verbose=False, custom_format=self._tooltip_format,
|
1457
|
+
hide_default_values=True,
|
1457
1458
|
),
|
1458
1459
|
styles=dict(color=self._badge_color()),
|
1459
1460
|
)
|
@@ -1500,7 +1501,8 @@ class UsageSummary(pg.Object, pg.views.HtmlTreeView.Extension):
|
|
1500
1501
|
usage_badge = pg.views.html.controls.Badge(
|
1501
1502
|
self._badge_text(),
|
1502
1503
|
tooltip=pg.format(
|
1503
|
-
self, custom_format=self._tooltip_format, verbose=False
|
1504
|
+
self, custom_format=self._tooltip_format, verbose=False,
|
1505
|
+
hide_default_values=True,
|
1504
1506
|
),
|
1505
1507
|
css_classes=['usage-summary'],
|
1506
1508
|
styles=dict(color=self._badge_color()),
|
langfun/core/llms/anthropic.py
CHANGED
@@ -509,17 +509,13 @@ class Anthropic(rest.REST):
|
|
509
509
|
raise ValueError(f'Unsupported modality: {chunk!r}.')
|
510
510
|
return chunk
|
511
511
|
|
512
|
-
messages = []
|
513
512
|
if system_message := prompt.get('system_message'):
|
514
513
|
assert isinstance(system_message, lf.SystemMessage), type(system_message)
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
)
|
519
|
-
)
|
520
|
-
messages.append(
|
514
|
+
request['system'] = system_message.text
|
515
|
+
|
516
|
+
messages = [
|
521
517
|
prompt.as_format('anthropic', chunk_preprocessor=modality_check)
|
522
|
-
|
518
|
+
]
|
523
519
|
request.update(messages=messages)
|
524
520
|
return request
|
525
521
|
|
@@ -31,21 +31,46 @@ def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
|
|
31
31
|
|
32
32
|
response = requests.Response()
|
33
33
|
response.status_code = 200
|
34
|
+
|
35
|
+
# Construct base text from user/assistant messages payload
|
36
|
+
messages_payload_text = '\n'.join(
|
37
|
+
c['content'][0]['text']
|
38
|
+
for c in json.get('messages', [])
|
39
|
+
if c.get('content')
|
40
|
+
and isinstance(c['content'], list)
|
41
|
+
and c['content']
|
42
|
+
and c['content'][0].get('type') == 'text'
|
43
|
+
and 'text' in c['content'][0]
|
44
|
+
)
|
45
|
+
|
46
|
+
# Check for a system prompt in the request payload
|
47
|
+
system_prompt_text = json.get('system')
|
48
|
+
|
49
|
+
processed_text_parts = []
|
50
|
+
if system_prompt_text:
|
51
|
+
processed_text_parts.append(system_prompt_text)
|
52
|
+
if messages_payload_text:
|
53
|
+
processed_text_parts.append(messages_payload_text)
|
54
|
+
|
55
|
+
processed_text = '\n'.join(processed_text_parts)
|
56
|
+
|
57
|
+
response_content_text = (
|
58
|
+
f'{processed_text} with temperature={json.get("temperature")}, '
|
59
|
+
f'top_k={json.get("top_k")}, '
|
60
|
+
f'top_p={json.get("top_p")}, '
|
61
|
+
f'max_tokens={json.get("max_tokens")}, '
|
62
|
+
f'stop={json.get("stop_sequences")}.'
|
63
|
+
)
|
64
|
+
|
34
65
|
response._content = pg.to_json_str({
|
35
|
-
'content': [{
|
36
|
-
'type': 'text',
|
37
|
-
'text': (
|
38
|
-
'\n'.join(c['content'][0]['text'] for c in json['messages']) +
|
39
|
-
f' with temperature={json.get("temperature")}, '
|
40
|
-
f'top_k={json.get("top_k")}, '
|
41
|
-
f'top_p={json.get("top_p")}, '
|
42
|
-
f'max_tokens={json.get("max_tokens")}, '
|
43
|
-
f'stop={json.get("stop_sequences")}.'
|
44
|
-
),
|
45
|
-
}],
|
66
|
+
'content': [{'type': 'text', 'text': response_content_text}],
|
46
67
|
'usage': {
|
47
|
-
'input_tokens':
|
48
|
-
|
68
|
+
'input_tokens': (
|
69
|
+
2
|
70
|
+
), # Placeholder: adjust if tests need accurate token counts
|
71
|
+
'output_tokens': (
|
72
|
+
1
|
73
|
+
), # Placeholder: adjust if tests need accurate token counts
|
49
74
|
},
|
50
75
|
}).encode()
|
51
76
|
return response
|