langroid 0.10.2__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/__init__.py +1 -2
- langroid/agent/base.py +138 -54
- langroid/agent/batch.py +116 -28
- langroid/agent/chat_agent.py +25 -4
- langroid/agent/chat_document.py +5 -1
- langroid/agent/special/doc_chat_agent.py +2 -2
- langroid/agent/task.py +131 -26
- langroid/agent/tool_message.py +15 -43
- langroid/agent/tools/__init__.py +4 -0
- langroid/agent/tools/orchestration.py +87 -8
- langroid/language_models/mock_lm.py +28 -7
- langroid/parsing/web_search.py +7 -4
- langroid/utils/.chainlit/config.toml +121 -0
- langroid/utils/.chainlit/translations/en-US.json +231 -0
- langroid/utils/types.py +93 -0
- {langroid-0.10.2.dist-info → langroid-0.12.0.dist-info}/METADATA +4 -2
- {langroid-0.10.2.dist-info → langroid-0.12.0.dist-info}/RECORD +20 -17
- pyproject.toml +2 -2
- {langroid-0.10.2.dist-info → langroid-0.12.0.dist-info}/LICENSE +0 -0
- {langroid-0.10.2.dist-info → langroid-0.12.0.dist-info}/WHEEL +0 -0
langroid/agent/__init__.py
CHANGED
@@ -6,7 +6,7 @@ from .chat_document import (
|
|
6
6
|
ChatDocument,
|
7
7
|
)
|
8
8
|
from .chat_agent import ChatAgentConfig, ChatAgent
|
9
|
-
from .tool_message import ToolMessage
|
9
|
+
from .tool_message import ToolMessage
|
10
10
|
from .task import Task
|
11
11
|
|
12
12
|
from . import base
|
@@ -29,7 +29,6 @@ __all__ = [
|
|
29
29
|
"ChatAgent",
|
30
30
|
"ChatAgentConfig",
|
31
31
|
"ToolMessage",
|
32
|
-
"FinalResultTool",
|
33
32
|
"Task",
|
34
33
|
"base",
|
35
34
|
"chat_document",
|
langroid/agent/base.py
CHANGED
@@ -18,7 +18,10 @@ from typing import (
|
|
18
18
|
Set,
|
19
19
|
Tuple,
|
20
20
|
Type,
|
21
|
+
TypeVar,
|
21
22
|
cast,
|
23
|
+
get_args,
|
24
|
+
get_origin,
|
22
25
|
no_type_check,
|
23
26
|
)
|
24
27
|
|
@@ -46,7 +49,6 @@ from langroid.parsing.parse_json import extract_top_level_json
|
|
46
49
|
from langroid.parsing.parser import Parser, ParsingConfig
|
47
50
|
from langroid.prompts.prompts_config import PromptsConfig
|
48
51
|
from langroid.pydantic_v1 import (
|
49
|
-
BaseModel,
|
50
52
|
BaseSettings,
|
51
53
|
Field,
|
52
54
|
ValidationError,
|
@@ -56,6 +58,7 @@ from langroid.utils.configuration import settings
|
|
56
58
|
from langroid.utils.constants import DONE, NO_ANSWER, PASS, PASS_TO, SEND_TO
|
57
59
|
from langroid.utils.object_registry import ObjectRegistry
|
58
60
|
from langroid.utils.output import status
|
61
|
+
from langroid.utils.types import from_string, to_string
|
59
62
|
from langroid.vector_store.base import VectorStore, VectorStoreConfig
|
60
63
|
|
61
64
|
ORCHESTRATION_STRINGS = [DONE, PASS, PASS_TO, SEND_TO]
|
@@ -63,6 +66,8 @@ console = Console(quiet=settings.quiet)
|
|
63
66
|
|
64
67
|
logger = logging.getLogger(__name__)
|
65
68
|
|
69
|
+
T = TypeVar("T")
|
70
|
+
|
66
71
|
|
67
72
|
class AgentConfig(BaseSettings):
|
68
73
|
"""
|
@@ -78,6 +83,7 @@ class AgentConfig(BaseSettings):
|
|
78
83
|
prompts: Optional[PromptsConfig] = PromptsConfig()
|
79
84
|
show_stats: bool = True # show token usage/cost stats?
|
80
85
|
add_to_registry: bool = True # register agent in ObjectRegistry?
|
86
|
+
respond_tools_only: bool = False # respond only to tool messages (not plain text)?
|
81
87
|
|
82
88
|
@validator("name")
|
83
89
|
def check_name_alphanum(cls, v: str) -> str:
|
@@ -341,6 +347,7 @@ class Agent(ABC):
|
|
341
347
|
def create_agent_response(
|
342
348
|
self,
|
343
349
|
content: str | None = None,
|
350
|
+
content_any: Any = None,
|
344
351
|
tool_messages: List[ToolMessage] = [],
|
345
352
|
oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
|
346
353
|
oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
|
@@ -352,6 +359,7 @@ class Agent(ABC):
|
|
352
359
|
return self.response_template(
|
353
360
|
Entity.AGENT,
|
354
361
|
content=content,
|
362
|
+
content_any=content_any,
|
355
363
|
tool_messages=tool_messages,
|
356
364
|
oai_tool_calls=oai_tool_calls,
|
357
365
|
oai_tool_choice=oai_tool_choice,
|
@@ -543,6 +551,7 @@ class Agent(ABC):
|
|
543
551
|
self,
|
544
552
|
e: Entity,
|
545
553
|
content: str | None = None,
|
554
|
+
content_any: Any = None,
|
546
555
|
tool_messages: List[ToolMessage] = [],
|
547
556
|
oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
|
548
557
|
oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
|
@@ -553,6 +562,7 @@ class Agent(ABC):
|
|
553
562
|
"""Template for response from entity `e`."""
|
554
563
|
return ChatDocument(
|
555
564
|
content=content or "",
|
565
|
+
content_any=content_any,
|
556
566
|
tool_messages=tool_messages,
|
557
567
|
oai_tool_calls=oai_tool_calls,
|
558
568
|
oai_tool_id2result=oai_tool_id2result,
|
@@ -566,6 +576,7 @@ class Agent(ABC):
|
|
566
576
|
def create_user_response(
|
567
577
|
self,
|
568
578
|
content: str | None = None,
|
579
|
+
content_any: Any = None,
|
569
580
|
tool_messages: List[ToolMessage] = [],
|
570
581
|
oai_tool_calls: List[OpenAIToolCall] | None = None,
|
571
582
|
oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
|
@@ -577,6 +588,7 @@ class Agent(ABC):
|
|
577
588
|
return self.response_template(
|
578
589
|
e=Entity.USER,
|
579
590
|
content=content,
|
591
|
+
content_any=content_any,
|
580
592
|
tool_messages=tool_messages,
|
581
593
|
oai_tool_calls=oai_tool_calls,
|
582
594
|
oai_tool_choice=oai_tool_choice,
|
@@ -677,9 +689,26 @@ class Agent(ABC):
|
|
677
689
|
|
678
690
|
return True
|
679
691
|
|
692
|
+
def can_respond(self, message: Optional[str | ChatDocument] = None) -> bool:
|
693
|
+
"""
|
694
|
+
Whether the agent can respond to a message.
|
695
|
+
Used in Task.py to skip a sub-task when we know it would not respond.
|
696
|
+
Args:
|
697
|
+
message (str|ChatDocument): message or ChatDocument object to respond to.
|
698
|
+
"""
|
699
|
+
tools = self.get_tool_messages(message)
|
700
|
+
if len(tools) == 0 and self.config.respond_tools_only:
|
701
|
+
return False
|
702
|
+
if message is not None and self.has_only_unhandled_tools(message):
|
703
|
+
# The message has tools that are NOT enabled to be handled by this agent,
|
704
|
+
# which means the agent cannot respond to it.
|
705
|
+
return False
|
706
|
+
return True
|
707
|
+
|
680
708
|
def create_llm_response(
|
681
709
|
self,
|
682
710
|
content: str | None = None,
|
711
|
+
content_any: Any = None,
|
683
712
|
tool_messages: List[ToolMessage] = [],
|
684
713
|
oai_tool_calls: None | List[OpenAIToolCall] = None,
|
685
714
|
oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
|
@@ -691,6 +720,7 @@ class Agent(ABC):
|
|
691
720
|
return self.response_template(
|
692
721
|
Entity.LLM,
|
693
722
|
content=content,
|
723
|
+
content_any=content_any,
|
694
724
|
tool_messages=tool_messages,
|
695
725
|
oai_tool_calls=oai_tool_calls,
|
696
726
|
oai_tool_choice=oai_tool_choice,
|
@@ -856,6 +886,8 @@ class Agent(ABC):
|
|
856
886
|
Does the msg have at least one tool, and ALL tools are
|
857
887
|
disabled for handling by this agent?
|
858
888
|
"""
|
889
|
+
if msg is None:
|
890
|
+
return False
|
859
891
|
tools = self.get_tool_messages(msg, all_tools=True)
|
860
892
|
if len(tools) == 0:
|
861
893
|
return False
|
@@ -863,7 +895,7 @@ class Agent(ABC):
|
|
863
895
|
|
864
896
|
def get_tool_messages(
|
865
897
|
self,
|
866
|
-
msg: str | ChatDocument,
|
898
|
+
msg: str | ChatDocument | None,
|
867
899
|
all_tools: bool = False,
|
868
900
|
) -> List[ToolMessage]:
|
869
901
|
"""
|
@@ -874,6 +906,9 @@ class Agent(ABC):
|
|
874
906
|
- otherwise, return only the tools handled by this agent.
|
875
907
|
"""
|
876
908
|
|
909
|
+
if msg is None:
|
910
|
+
return []
|
911
|
+
|
877
912
|
if isinstance(msg, str):
|
878
913
|
json_tools = self.get_json_tool_messages(msg)
|
879
914
|
if all_tools:
|
@@ -1070,7 +1105,7 @@ class Agent(ABC):
|
|
1070
1105
|
fallback_result = self.handle_message_fallback(msg)
|
1071
1106
|
if fallback_result is None:
|
1072
1107
|
return None
|
1073
|
-
return self.
|
1108
|
+
return self.to_ChatDocument(
|
1074
1109
|
fallback_result,
|
1075
1110
|
chat_doc=msg if isinstance(msg, ChatDocument) else None,
|
1076
1111
|
)
|
@@ -1109,7 +1144,13 @@ class Agent(ABC):
|
|
1109
1144
|
results = [err_str for _ in tools]
|
1110
1145
|
else:
|
1111
1146
|
results = [self.handle_tool_message(t, chat_doc=chat_doc) for t in tools]
|
1147
|
+
# if there's a solitary ChatDocument|str result, return it as is
|
1148
|
+
if len(results) == 1 and isinstance(results[0], (str, ChatDocument)):
|
1149
|
+
return results[0]
|
1150
|
+
# extract content from ChatDocument results so we have all str|None
|
1151
|
+
results = [r.content if isinstance(r, ChatDocument) else r for r in results]
|
1112
1152
|
|
1153
|
+
# now all results are str|None
|
1113
1154
|
tool_names = [t.default_value("request") for t in tools]
|
1114
1155
|
if has_ids:
|
1115
1156
|
id2result = OrderedDict(
|
@@ -1132,35 +1173,16 @@ class Agent(ABC):
|
|
1132
1173
|
(name, r) for name, r in zip(tool_names, results) if r is not None
|
1133
1174
|
]
|
1134
1175
|
if len(name_results_list) == 0:
|
1135
|
-
return None
|
1176
|
+
return None
|
1177
|
+
|
1136
1178
|
# there was a non-None result
|
1137
|
-
chat_doc_results = [
|
1138
|
-
r for _, r in name_results_list if isinstance(r, ChatDocument)
|
1139
|
-
]
|
1140
|
-
if len(chat_doc_results) > 1:
|
1141
|
-
logger.warning(
|
1142
|
-
"""There were multiple ChatDocument results from tools,
|
1143
|
-
which is unexpected. The first one will be returned, and the others
|
1144
|
-
will be ignored.
|
1145
|
-
"""
|
1146
|
-
)
|
1147
|
-
if len(chat_doc_results) > 0:
|
1148
|
-
return chat_doc_results[0]
|
1149
1179
|
|
1150
1180
|
if has_ids and len(id2result) > 1:
|
1151
1181
|
# if there are multiple OpenAI Tool results, return them as a dict
|
1152
1182
|
return id2result
|
1153
1183
|
|
1154
|
-
if len(name_results_list) == 1 and isinstance(name_results_list[0][1], str):
|
1155
|
-
# single str result -- return it as is
|
1156
|
-
return name_results_list[0][1]
|
1157
|
-
|
1158
1184
|
# multi-results: prepend the tool name to each result
|
1159
|
-
str_results = [
|
1160
|
-
f"Result from {name}: {r}"
|
1161
|
-
for name, r in name_results_list
|
1162
|
-
if isinstance(r, str)
|
1163
|
-
]
|
1185
|
+
str_results = [f"Result from {name}: {r}" for name, r in name_results_list]
|
1164
1186
|
final = "\n\n".join(str_results)
|
1165
1187
|
return final
|
1166
1188
|
|
@@ -1260,20 +1282,41 @@ class Agent(ABC):
|
|
1260
1282
|
raise ve
|
1261
1283
|
return message
|
1262
1284
|
|
1263
|
-
def
|
1285
|
+
def to_ChatDocument(
|
1264
1286
|
self,
|
1265
1287
|
msg: Any,
|
1266
1288
|
orig_tool_name: str | None = None,
|
1267
1289
|
chat_doc: Optional[ChatDocument] = None,
|
1268
|
-
|
1290
|
+
author_entity: Entity = Entity.AGENT,
|
1291
|
+
) -> Optional[ChatDocument]:
|
1269
1292
|
"""
|
1270
|
-
|
1293
|
+
Convert result of a responder (agent_response or llm_response, or task.run()),
|
1294
|
+
or tool handler, or handle_message_fallback,
|
1295
|
+
to a ChatDocument, to enabling handling by other
|
1296
|
+
responders/tasks in a task loop possibly involving multiple agents.
|
1297
|
+
|
1298
|
+
Args:
|
1299
|
+
msg (Any): The result of a responder or tool handler or task.run()
|
1300
|
+
orig_tool_name (str): The original tool name that generated the response,
|
1301
|
+
if any.
|
1302
|
+
chat_doc (ChatDocument): The original ChatDocument object that `msg`
|
1303
|
+
is a response to.
|
1304
|
+
author_entity (Entity): The intended author of the result ChatDocument
|
1271
1305
|
"""
|
1272
|
-
if isinstance(msg,
|
1306
|
+
if msg is None or isinstance(msg, ChatDocument):
|
1307
|
+
return msg
|
1308
|
+
|
1309
|
+
is_agent_author = author_entity == Entity.AGENT
|
1310
|
+
|
1311
|
+
if isinstance(msg, str):
|
1312
|
+
return self.response_template(author_entity, content=msg, content_any=msg)
|
1313
|
+
elif isinstance(msg, ToolMessage):
|
1273
1314
|
# result is a ToolMessage, so...
|
1274
1315
|
result_tool_name = msg.default_value("request")
|
1275
|
-
if
|
1276
|
-
|
1316
|
+
if (
|
1317
|
+
is_agent_author
|
1318
|
+
and result_tool_name in self.llm_tools_handled
|
1319
|
+
and (orig_tool_name is None or orig_tool_name != result_tool_name)
|
1277
1320
|
):
|
1278
1321
|
# TODO: do we need to remove the tool message from the chat_doc?
|
1279
1322
|
# if (chat_doc is not None and
|
@@ -1281,30 +1324,73 @@ class Agent(ABC):
|
|
1281
1324
|
# chat_doc.tool_messages.remove(msg)
|
1282
1325
|
# if we can handle it, do so
|
1283
1326
|
result = self.handle_tool_message(msg, chat_doc=chat_doc)
|
1327
|
+
if result is not None and isinstance(result, ChatDocument):
|
1328
|
+
return result
|
1284
1329
|
else:
|
1285
1330
|
# else wrap it in an agent response and return it so
|
1286
1331
|
# orchestrator can find a respondent
|
1287
|
-
|
1288
|
-
elif isinstance(msg, (ChatDocument, str)):
|
1289
|
-
result = msg
|
1290
|
-
elif isinstance(msg, BaseModel):
|
1291
|
-
result = msg.json()
|
1332
|
+
return self.response_template(author_entity, tool_messages=[msg])
|
1292
1333
|
else:
|
1293
|
-
|
1294
|
-
|
1295
|
-
|
1296
|
-
|
1297
|
-
|
1298
|
-
|
1299
|
-
|
1300
|
-
|
1301
|
-
|
1302
|
-
|
1303
|
-
|
1304
|
-
|
1305
|
-
|
1306
|
-
|
1307
|
-
return
|
1334
|
+
result = to_string(msg)
|
1335
|
+
|
1336
|
+
return (
|
1337
|
+
None
|
1338
|
+
if result is None
|
1339
|
+
else self.response_template(author_entity, content=result, content_any=msg)
|
1340
|
+
)
|
1341
|
+
|
1342
|
+
def from_ChatDocument(self, msg: ChatDocument, output_type: Type[T]) -> Optional[T]:
|
1343
|
+
"""
|
1344
|
+
Extract a desired output_type from a ChatDocument object.
|
1345
|
+
We use this fallback order:
|
1346
|
+
- if `msg.content_any` exists and matches the output_type, return it
|
1347
|
+
- if `msg.content` exists and output_type is str return it
|
1348
|
+
- if output_type is a ToolMessage, return the first tool in `msg.tool_messages`
|
1349
|
+
- if output_type is a list of ToolMessage,
|
1350
|
+
return all tools in `msg.tool_messages`
|
1351
|
+
- search for a tool in `msg.tool_messages` that has a field of output_type,
|
1352
|
+
and if found, return that field value
|
1353
|
+
- return None if all the above fail
|
1354
|
+
"""
|
1355
|
+
content = msg.content
|
1356
|
+
if output_type is str and content != "":
|
1357
|
+
return cast(T, content)
|
1358
|
+
content_any = msg.content_any
|
1359
|
+
if content_any is not None and isinstance(content_any, output_type):
|
1360
|
+
return cast(T, content_any)
|
1361
|
+
|
1362
|
+
tools = self.get_tool_messages(msg, all_tools=True)
|
1363
|
+
|
1364
|
+
if get_origin(output_type) is list:
|
1365
|
+
list_element_type = get_args(output_type)[0]
|
1366
|
+
if issubclass(list_element_type, ToolMessage):
|
1367
|
+
# list_element_type is a subclass of ToolMessage:
|
1368
|
+
# We output a list of objects derived from list_element_type
|
1369
|
+
return cast(
|
1370
|
+
T,
|
1371
|
+
[t for t in tools if isinstance(t, list_element_type)],
|
1372
|
+
)
|
1373
|
+
elif get_origin(output_type) is None and issubclass(output_type, ToolMessage):
|
1374
|
+
# output_type is a subclass of ToolMessage:
|
1375
|
+
# return the first tool that has this specific output_type
|
1376
|
+
for tool in tools:
|
1377
|
+
if isinstance(tool, output_type):
|
1378
|
+
return cast(T, tool)
|
1379
|
+
return None
|
1380
|
+
elif get_origin(output_type) is None and output_type in (str, int, float, bool):
|
1381
|
+
# attempt to get the output_type from the content,
|
1382
|
+
# if it's a primitive type
|
1383
|
+
primitive_value = from_string(content, output_type) # type: ignore
|
1384
|
+
if primitive_value is not None:
|
1385
|
+
return cast(T, primitive_value)
|
1386
|
+
|
1387
|
+
# then search for output_type as a field in a tool
|
1388
|
+
for tool in tools:
|
1389
|
+
value = tool.get_value_of_type(output_type)
|
1390
|
+
if value is not None:
|
1391
|
+
return cast(T, value)
|
1392
|
+
|
1393
|
+
return None
|
1308
1394
|
|
1309
1395
|
def handle_tool_message(
|
1310
1396
|
self,
|
@@ -1335,9 +1421,7 @@ class Agent(ABC):
|
|
1335
1421
|
maybe_result = handler_method(tool, chat_doc=chat_doc)
|
1336
1422
|
else:
|
1337
1423
|
maybe_result = handler_method(tool)
|
1338
|
-
result = self.
|
1339
|
-
maybe_result, tool_name, chat_doc
|
1340
|
-
)
|
1424
|
+
result = self.to_ChatDocument(maybe_result, tool_name, chat_doc)
|
1341
1425
|
except Exception as e:
|
1342
1426
|
# raise the error here since we are sure it's
|
1343
1427
|
# not a pydantic validation error,
|
langroid/agent/batch.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import asyncio
|
2
2
|
import copy
|
3
3
|
import inspect
|
4
|
-
from typing import Any, Callable, Coroutine, Iterable, List, Optional, TypeVar
|
4
|
+
from typing import Any, Callable, Coroutine, Iterable, List, Optional, TypeVar, cast
|
5
5
|
|
6
6
|
from dotenv import load_dotenv
|
7
7
|
|
@@ -26,6 +26,7 @@ def run_batch_task_gen(
|
|
26
26
|
items: list[T],
|
27
27
|
input_map: Callable[[T], str | ChatDocument] = lambda x: str(x),
|
28
28
|
output_map: Callable[[ChatDocument | None], U] = lambda x: x, # type: ignore
|
29
|
+
stop_on_first_result: bool = False,
|
29
30
|
sequential: bool = True,
|
30
31
|
batch_size: Optional[int] = None,
|
31
32
|
turns: int = -1,
|
@@ -33,7 +34,7 @@ def run_batch_task_gen(
|
|
33
34
|
handle_exceptions: bool = False,
|
34
35
|
max_cost: float = 0.0,
|
35
36
|
max_tokens: int = 0,
|
36
|
-
) -> list[U]:
|
37
|
+
) -> list[Optional[U]]:
|
37
38
|
"""
|
38
39
|
Generate and run copies of a task async/concurrently one per item in `items` list.
|
39
40
|
For each item, apply `input_map` to get the initial message to process.
|
@@ -44,7 +45,13 @@ def run_batch_task_gen(
|
|
44
45
|
input_map (Callable[[T], str|ChatDocument]): function to map item to
|
45
46
|
initial message to process
|
46
47
|
output_map (Callable[[ChatDocument|str], U]): function to map result
|
47
|
-
to final result
|
48
|
+
to final result. If stop_on_first_result is enabled, then
|
49
|
+
map any invalid output to None. We continue until some non-None
|
50
|
+
result is obtained.
|
51
|
+
stop_on_first_result (bool): whether to stop after the first valid
|
52
|
+
(not-None) result. In this case all other tasks are
|
53
|
+
cancelled, and their corresponding result is None in the
|
54
|
+
returned list.
|
48
55
|
sequential (bool): whether to run sequentially
|
49
56
|
(e.g. some APIs such as ooba don't support concurrent requests)
|
50
57
|
batch_size (Optional[int]): The number of tasks to run at a time,
|
@@ -57,39 +64,91 @@ def run_batch_task_gen(
|
|
57
64
|
|
58
65
|
|
59
66
|
Returns:
|
60
|
-
list[
|
67
|
+
list[Optional[U]]: list of final results. Always list[U] if
|
68
|
+
`stop_on_first_result` is disabled
|
61
69
|
"""
|
62
70
|
inputs = [input_map(item) for item in items]
|
63
71
|
|
64
|
-
async def _do_task(
|
72
|
+
async def _do_task(
|
73
|
+
input: str | ChatDocument,
|
74
|
+
i: int,
|
75
|
+
return_idx: Optional[int] = None,
|
76
|
+
) -> BaseException | Optional[ChatDocument] | tuple[int, Optional[ChatDocument]]:
|
65
77
|
task_i = gen_task(i)
|
66
78
|
if task_i.agent.llm is not None:
|
67
79
|
task_i.agent.llm.set_stream(False)
|
68
80
|
task_i.agent.config.show_stats = False
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
81
|
+
try:
|
82
|
+
result = await task_i.run_async(
|
83
|
+
input, turns=turns, max_cost=max_cost, max_tokens=max_tokens
|
84
|
+
)
|
85
|
+
if return_idx is not None:
|
86
|
+
return return_idx, result
|
87
|
+
else:
|
88
|
+
return result
|
89
|
+
except asyncio.CancelledError as e:
|
90
|
+
task_i.kill()
|
91
|
+
if handle_exceptions:
|
92
|
+
return e
|
93
|
+
else:
|
94
|
+
raise e
|
95
|
+
except BaseException as e:
|
96
|
+
if handle_exceptions:
|
97
|
+
return e
|
98
|
+
else:
|
99
|
+
raise e
|
74
100
|
|
75
101
|
async def _do_all(
|
76
102
|
inputs: Iterable[str | ChatDocument], start_idx: int = 0
|
77
|
-
) -> list[U]:
|
103
|
+
) -> list[Optional[U]]:
|
78
104
|
results: list[Optional[ChatDocument]] = []
|
79
|
-
if
|
80
|
-
|
105
|
+
if stop_on_first_result:
|
106
|
+
outputs: list[Optional[U]] = [None] * len(list(inputs))
|
107
|
+
tasks = set(
|
108
|
+
asyncio.create_task(_do_task(input, i + start_idx, return_idx=i))
|
109
|
+
for i, input in enumerate(inputs)
|
110
|
+
)
|
111
|
+
while tasks:
|
81
112
|
try:
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
113
|
+
done, tasks = await asyncio.wait(
|
114
|
+
tasks, return_when=asyncio.FIRST_COMPLETED
|
115
|
+
)
|
116
|
+
for task in done:
|
117
|
+
idx_result = task.result()
|
118
|
+
if not isinstance(idx_result, tuple):
|
119
|
+
continue
|
120
|
+
index, output = idx_result
|
121
|
+
outputs[index] = output_map(output)
|
122
|
+
|
123
|
+
if any(r is not None for r in outputs):
|
124
|
+
return outputs
|
125
|
+
finally:
|
126
|
+
# Cancel all remaining tasks
|
127
|
+
for task in tasks:
|
128
|
+
task.cancel()
|
129
|
+
# Wait for cancellations to complete
|
130
|
+
try:
|
131
|
+
await asyncio.gather(*tasks, return_exceptions=True)
|
132
|
+
except BaseException as e:
|
133
|
+
if not handle_exceptions:
|
134
|
+
raise e
|
135
|
+
return outputs
|
136
|
+
elif sequential:
|
137
|
+
for i, input in enumerate(inputs):
|
138
|
+
result: Optional[ChatDocument] | BaseException = await _do_task(
|
139
|
+
input, i + start_idx
|
140
|
+
) # type: ignore
|
141
|
+
|
142
|
+
if isinstance(result, BaseException):
|
143
|
+
result = None
|
144
|
+
|
88
145
|
results.append(result)
|
89
146
|
else:
|
90
|
-
results_with_exceptions =
|
91
|
-
|
92
|
-
|
147
|
+
results_with_exceptions = cast(
|
148
|
+
list[Optional[ChatDocument | BaseException]],
|
149
|
+
await asyncio.gather(
|
150
|
+
*(_do_task(input, i + start_idx) for i, input in enumerate(inputs)),
|
151
|
+
),
|
93
152
|
)
|
94
153
|
|
95
154
|
results = [
|
@@ -99,7 +158,7 @@ def run_batch_task_gen(
|
|
99
158
|
|
100
159
|
return list(map(output_map, results))
|
101
160
|
|
102
|
-
results: List[U] = []
|
161
|
+
results: List[Optional[U]] = []
|
103
162
|
if batch_size is None:
|
104
163
|
msg = message or f"[bold green]Running {len(items)} tasks:"
|
105
164
|
|
@@ -113,8 +172,11 @@ def run_batch_task_gen(
|
|
113
172
|
complete_str = f", {start_idx} complete" if start_idx > 0 else ""
|
114
173
|
msg = message or f"[bold green]Running {len(items)} tasks{complete_str}:"
|
115
174
|
|
116
|
-
|
117
|
-
results.extend(
|
175
|
+
if stop_on_first_result and any(r is not None for r in results):
|
176
|
+
results.extend([None] * len(batch))
|
177
|
+
else:
|
178
|
+
with status(msg), SuppressLoggerWarnings():
|
179
|
+
results.extend(asyncio.run(_do_all(batch, start_idx=start_idx)))
|
118
180
|
|
119
181
|
return results
|
120
182
|
|
@@ -124,12 +186,13 @@ def run_batch_tasks(
|
|
124
186
|
items: list[T],
|
125
187
|
input_map: Callable[[T], str | ChatDocument] = lambda x: str(x),
|
126
188
|
output_map: Callable[[ChatDocument | None], U] = lambda x: x, # type: ignore
|
189
|
+
stop_on_first_result: bool = False,
|
127
190
|
sequential: bool = True,
|
128
191
|
batch_size: Optional[int] = None,
|
129
192
|
turns: int = -1,
|
130
193
|
max_cost: float = 0.0,
|
131
194
|
max_tokens: int = 0,
|
132
|
-
) -> List[U]:
|
195
|
+
) -> List[Optional[U]]:
|
133
196
|
"""
|
134
197
|
Run copies of `task` async/concurrently one per item in `items` list.
|
135
198
|
For each item, apply `input_map` to get the initial message to process.
|
@@ -150,7 +213,8 @@ def run_batch_tasks(
|
|
150
213
|
max_tokens: int: maximum token usage (in and out) (default 0 for unlimited)
|
151
214
|
|
152
215
|
Returns:
|
153
|
-
list[
|
216
|
+
list[Optional[U]]: list of final results. Always list[U] if
|
217
|
+
`stop_on_first_result` is disabled
|
154
218
|
"""
|
155
219
|
message = f"[bold green]Running {len(items)} copies of {task.name}..."
|
156
220
|
return run_batch_task_gen(
|
@@ -158,6 +222,7 @@ def run_batch_tasks(
|
|
158
222
|
items,
|
159
223
|
input_map,
|
160
224
|
output_map,
|
225
|
+
stop_on_first_result,
|
161
226
|
sequential,
|
162
227
|
batch_size,
|
163
228
|
turns,
|
@@ -176,6 +241,7 @@ def run_batch_agent_method(
|
|
176
241
|
input_map: Callable[[Any], str | ChatDocument] = lambda x: str(x),
|
177
242
|
output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
|
178
243
|
sequential: bool = True,
|
244
|
+
stop_on_first_result: bool = False,
|
179
245
|
) -> List[Any]:
|
180
246
|
"""
|
181
247
|
Run the `method` on copies of `agent`, async/concurrently one per
|
@@ -225,7 +291,25 @@ def run_batch_agent_method(
|
|
225
291
|
return output_map(result)
|
226
292
|
|
227
293
|
async def _do_all() -> List[Any]:
|
228
|
-
if
|
294
|
+
if stop_on_first_result:
|
295
|
+
tasks = [
|
296
|
+
asyncio.create_task(_do_task(input, i))
|
297
|
+
for i, input in enumerate(inputs)
|
298
|
+
]
|
299
|
+
results = [None] * len(tasks)
|
300
|
+
try:
|
301
|
+
done, pending = await asyncio.wait(
|
302
|
+
tasks, return_when=asyncio.FIRST_COMPLETED
|
303
|
+
)
|
304
|
+
for task in done:
|
305
|
+
index = tasks.index(task)
|
306
|
+
results[index] = await task
|
307
|
+
finally:
|
308
|
+
for task in pending:
|
309
|
+
task.cancel()
|
310
|
+
await asyncio.gather(*pending, return_exceptions=True)
|
311
|
+
return results
|
312
|
+
elif sequential:
|
229
313
|
results = []
|
230
314
|
for i, input in enumerate(inputs):
|
231
315
|
result = await _do_task(input, i)
|
@@ -249,6 +333,7 @@ def llm_response_batch(
|
|
249
333
|
input_map: Callable[[Any], str | ChatDocument] = lambda x: str(x),
|
250
334
|
output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
|
251
335
|
sequential: bool = True,
|
336
|
+
stop_on_first_result: bool = False,
|
252
337
|
) -> List[Any]:
|
253
338
|
return run_batch_agent_method(
|
254
339
|
agent,
|
@@ -257,6 +342,7 @@ def llm_response_batch(
|
|
257
342
|
input_map=input_map,
|
258
343
|
output_map=output_map,
|
259
344
|
sequential=sequential,
|
345
|
+
stop_on_first_result=stop_on_first_result,
|
260
346
|
)
|
261
347
|
|
262
348
|
|
@@ -266,6 +352,7 @@ def agent_response_batch(
|
|
266
352
|
input_map: Callable[[Any], str | ChatDocument] = lambda x: str(x),
|
267
353
|
output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
|
268
354
|
sequential: bool = True,
|
355
|
+
stop_on_first_result: bool = False,
|
269
356
|
) -> List[Any]:
|
270
357
|
return run_batch_agent_method(
|
271
358
|
agent,
|
@@ -274,4 +361,5 @@ def agent_response_batch(
|
|
274
361
|
input_map=input_map,
|
275
362
|
output_map=output_map,
|
276
363
|
sequential=sequential,
|
364
|
+
stop_on_first_result=stop_on_first_result,
|
277
365
|
)
|