langroid 0.33.6__py3-none-any.whl → 0.33.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +106 -0
- langroid/agent/__init__.py +41 -0
- langroid/agent/base.py +1983 -0
- langroid/agent/batch.py +398 -0
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +598 -0
- langroid/agent/chat_agent.py +1899 -0
- langroid/agent/chat_document.py +454 -0
- langroid/agent/openai_assistant.py +882 -0
- langroid/agent/special/__init__.py +59 -0
- langroid/agent/special/arangodb/__init__.py +0 -0
- langroid/agent/special/arangodb/arangodb_agent.py +656 -0
- langroid/agent/special/arangodb/system_messages.py +186 -0
- langroid/agent/special/arangodb/tools.py +107 -0
- langroid/agent/special/arangodb/utils.py +36 -0
- langroid/agent/special/doc_chat_agent.py +1466 -0
- langroid/agent/special/lance_doc_chat_agent.py +262 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +198 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +82 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +260 -0
- langroid/agent/special/lance_tools.py +61 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +433 -0
- langroid/agent/special/neo4j/system_messages.py +120 -0
- langroid/agent/special/neo4j/tools.py +32 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +56 -0
- langroid/agent/special/sql/__init__.py +17 -0
- langroid/agent/special/sql/sql_chat_agent.py +654 -0
- langroid/agent/special/sql/utils/__init__.py +21 -0
- langroid/agent/special/sql/utils/description_extractors.py +190 -0
- langroid/agent/special/sql/utils/populate_metadata.py +85 -0
- langroid/agent/special/sql/utils/system_message.py +35 -0
- langroid/agent/special/sql/utils/tools.py +64 -0
- langroid/agent/special/table_chat_agent.py +263 -0
- langroid/agent/task.py +2095 -0
- langroid/agent/tool_message.py +393 -0
- langroid/agent/tools/__init__.py +38 -0
- langroid/agent/tools/duckduckgo_search_tool.py +50 -0
- langroid/agent/tools/file_tools.py +234 -0
- langroid/agent/tools/google_search_tool.py +39 -0
- langroid/agent/tools/metaphor_search_tool.py +68 -0
- langroid/agent/tools/orchestration.py +303 -0
- langroid/agent/tools/recipient_tool.py +235 -0
- langroid/agent/tools/retrieval_tool.py +32 -0
- langroid/agent/tools/rewind_tool.py +137 -0
- langroid/agent/tools/segment_extract_tool.py +41 -0
- langroid/agent/xml_tool_message.py +382 -0
- langroid/cachedb/__init__.py +17 -0
- langroid/cachedb/base.py +58 -0
- langroid/cachedb/momento_cachedb.py +108 -0
- langroid/cachedb/redis_cachedb.py +153 -0
- langroid/embedding_models/__init__.py +39 -0
- langroid/embedding_models/base.py +74 -0
- langroid/embedding_models/models.py +461 -0
- langroid/embedding_models/protoc/__init__.py +0 -0
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/exceptions.py +71 -0
- langroid/language_models/__init__.py +53 -0
- langroid/language_models/azure_openai.py +153 -0
- langroid/language_models/base.py +678 -0
- langroid/language_models/config.py +18 -0
- langroid/language_models/mock_lm.py +124 -0
- langroid/language_models/openai_gpt.py +1964 -0
- langroid/language_models/prompt_formatter/__init__.py +16 -0
- langroid/language_models/prompt_formatter/base.py +40 -0
- langroid/language_models/prompt_formatter/hf_formatter.py +132 -0
- langroid/language_models/prompt_formatter/llama2_formatter.py +75 -0
- langroid/language_models/utils.py +151 -0
- langroid/mytypes.py +84 -0
- langroid/parsing/__init__.py +52 -0
- langroid/parsing/agent_chats.py +38 -0
- langroid/parsing/code_parser.py +121 -0
- langroid/parsing/document_parser.py +718 -0
- langroid/parsing/para_sentence_split.py +62 -0
- langroid/parsing/parse_json.py +155 -0
- langroid/parsing/parser.py +313 -0
- langroid/parsing/repo_loader.py +790 -0
- langroid/parsing/routing.py +36 -0
- langroid/parsing/search.py +275 -0
- langroid/parsing/spider.py +102 -0
- langroid/parsing/table_loader.py +94 -0
- langroid/parsing/url_loader.py +111 -0
- langroid/parsing/urls.py +273 -0
- langroid/parsing/utils.py +373 -0
- langroid/parsing/web_search.py +156 -0
- langroid/prompts/__init__.py +9 -0
- langroid/prompts/dialog.py +17 -0
- langroid/prompts/prompts_config.py +5 -0
- langroid/prompts/templates.py +141 -0
- langroid/pydantic_v1/__init__.py +10 -0
- langroid/pydantic_v1/main.py +4 -0
- langroid/utils/__init__.py +19 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +98 -0
- langroid/utils/constants.py +30 -0
- langroid/utils/git_utils.py +252 -0
- langroid/utils/globals.py +49 -0
- langroid/utils/logging.py +135 -0
- langroid/utils/object_registry.py +66 -0
- langroid/utils/output/__init__.py +20 -0
- langroid/utils/output/citations.py +41 -0
- langroid/utils/output/printing.py +99 -0
- langroid/utils/output/status.py +40 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +602 -0
- langroid/utils/system.py +286 -0
- langroid/utils/types.py +93 -0
- langroid/vector_store/__init__.py +50 -0
- langroid/vector_store/base.py +359 -0
- langroid/vector_store/chromadb.py +214 -0
- langroid/vector_store/lancedb.py +406 -0
- langroid/vector_store/meilisearch.py +299 -0
- langroid/vector_store/momento.py +278 -0
- langroid/vector_store/qdrantdb.py +468 -0
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/METADATA +95 -94
- langroid-0.33.7.dist-info/RECORD +127 -0
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/WHEEL +1 -1
- langroid-0.33.6.dist-info/RECORD +0 -7
- langroid-0.33.6.dist-info/entry_points.txt +0 -4
- pyproject.toml +0 -356
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/licenses/LICENSE +0 -0
langroid/agent/batch.py
ADDED
@@ -0,0 +1,398 @@
|
|
1
|
+
import asyncio
|
2
|
+
import copy
|
3
|
+
import inspect
|
4
|
+
from typing import Any, Callable, Coroutine, Iterable, List, Optional, TypeVar, cast
|
5
|
+
|
6
|
+
from dotenv import load_dotenv
|
7
|
+
|
8
|
+
from langroid.agent.base import Agent
|
9
|
+
from langroid.agent.chat_document import ChatDocument
|
10
|
+
from langroid.agent.task import Task
|
11
|
+
from langroid.parsing.utils import batched
|
12
|
+
from langroid.utils.configuration import quiet_mode
|
13
|
+
from langroid.utils.logging import setup_colored_logging
|
14
|
+
from langroid.utils.output import SuppressLoggerWarnings, status
|
15
|
+
|
16
|
+
# Generic type variables used by the batch helpers below:
# T is the type of an input item, U the type produced by `output_map`.
T = TypeVar("T")
U = TypeVar("U")

# Import-time side effects of this module: configure colored logging, then
# load environment variables via python-dotenv's `load_dotenv()`.
setup_colored_logging()
load_dotenv()
|
22
|
+
|
23
|
+
|
24
|
+
def run_batch_task_gen(
    gen_task: Callable[[int], Task],
    items: list[T],
    input_map: Callable[[T], str | ChatDocument] = lambda x: str(x),
    output_map: Callable[[ChatDocument | None], U] = lambda x: x,  # type: ignore
    stop_on_first_result: bool = False,
    sequential: bool = True,
    batch_size: Optional[int] = None,
    turns: int = -1,
    message: Optional[str] = None,
    handle_exceptions: bool = False,
    max_cost: float = 0.0,
    max_tokens: int = 0,
) -> list[Optional[U]]:
    """
    Generate and run copies of a task async/concurrently one per item in `items` list.
    For each item, apply `input_map` to get the initial message to process.
    For each result, apply `output_map` to get the final result.

    Args:
        gen_task (Callable[[int], Task]): generates the tasks to run; it is
            called with the (0-based, global) index of the item being processed
        items (list[T]): list of items to process
        input_map (Callable[[T], str|ChatDocument]): function to map item to
            initial message to process
        output_map (Callable[[ChatDocument|None], U]): function to map result
            to final result. If stop_on_first_result is enabled, then
            map any invalid output to None. We continue until some non-None
            result is obtained.
        stop_on_first_result (bool): whether to stop after the first valid
            (not-None) result. In this case all tasks (of a batch) are started
            concurrently regardless of `sequential`, all other tasks are
            cancelled once a valid result is found, and their corresponding
            result is None in the returned list.
        sequential (bool): whether to run sequentially
            (e.g. some APIs such as ooba don't support concurrent requests)
        batch_size (Optional[int]): The number of tasks to run at a time,
            if None, unbatched
        turns (int): number of turns to run, -1 for infinite
        message (Optional[str]): optionally overrides the console status messages
        handle_exceptions (bool): Whether to replace exceptions with outputs of
            None (i.e. `output_map(None)`) instead of propagating them
        max_cost (float): maximum cost to run the task (default 0.0 for unlimited)
        max_tokens (int): maximum token usage (in and out) (default 0 for unlimited)

    Returns:
        list[Optional[U]]: list of final results, one per item, in item order.
        Always list[U] if `stop_on_first_result` is disabled.
    """
    inputs = [input_map(item) for item in items]

    async def _do_task(
        input: str | ChatDocument,
        i: int,
        return_idx: Optional[int] = None,
    ) -> BaseException | Optional[ChatDocument] | tuple[int, Optional[ChatDocument]]:
        # Run a freshly generated task on one input. When `return_idx` is
        # given, the result is tagged with it so out-of-order completions
        # (stop_on_first_result mode) can be slotted into place.
        task_i = gen_task(i)
        if task_i.agent.llm is not None:
            task_i.agent.llm.set_stream(False)
        task_i.agent.config.show_stats = False
        try:
            result = await task_i.run_async(
                input, turns=turns, max_cost=max_cost, max_tokens=max_tokens
            )
        except asyncio.CancelledError as e:
            # We were cancelled (e.g. a valid result was already found):
            # stop the underlying Task as well.
            task_i.kill()
            if handle_exceptions:
                return e
            raise e
        except BaseException as e:
            if handle_exceptions:
                return e
            raise e
        return result if return_idx is None else (return_idx, result)

    async def _do_all(
        inputs: Iterable[str | ChatDocument], start_idx: int = 0
    ) -> list[Optional[U]]:
        # Run one group of inputs; `start_idx` is the global index of the
        # group's first input (forwarded to `gen_task`).
        # Materialize once so a one-shot iterable can be both measured and
        # iterated.
        input_list = list(inputs)

        if stop_on_first_result:
            outputs: list[Optional[U]] = [None] * len(input_list)
            pending = {
                asyncio.create_task(_do_task(inp, i + start_idx, return_idx=i))
                for i, inp in enumerate(input_list)
            }
            # The try/finally encloses the whole wait loop: remaining tasks are
            # cancelled only once we stop waiting (a valid result was found or
            # an error escaped), not after every individual completion.
            try:
                while pending:
                    done, pending = await asyncio.wait(
                        pending, return_when=asyncio.FIRST_COMPLETED
                    )
                    for task in done:
                        idx_result = task.result()
                        if not isinstance(idx_result, tuple):
                            # An exception object returned by `_do_task`
                            # (only possible when handle_exceptions=True).
                            continue
                        index, output = idx_result
                        outputs[index] = output_map(output)

                    if any(r is not None for r in outputs):
                        return outputs
            finally:
                # Cancel all remaining tasks
                for task in pending:
                    task.cancel()
                # Wait for cancellations to complete
                try:
                    await asyncio.gather(*pending, return_exceptions=True)
                except BaseException as e:
                    if not handle_exceptions:
                        raise e
            return outputs

        results: list[Optional[ChatDocument]] = []
        if sequential:
            for i, inp in enumerate(input_list):
                result: Optional[ChatDocument] | BaseException = await _do_task(
                    inp, i + start_idx
                )  # type: ignore
                # With handle_exceptions=True failures come back as exception
                # objects; they are mapped to None.
                if isinstance(result, BaseException):
                    result = None
                results.append(result)
        else:
            results_with_exceptions = cast(
                list[Optional[ChatDocument | BaseException]],
                await asyncio.gather(
                    *(
                        _do_task(inp, i + start_idx)
                        for i, inp in enumerate(input_list)
                    ),
                ),
            )
            results = [
                r if not isinstance(r, BaseException) else None
                for r in results_with_exceptions
            ]

        return list(map(output_map, results))

    results: List[Optional[U]] = []
    if batch_size is None:
        msg = message or f"[bold green]Running {len(items)} tasks:"

        with status(msg), SuppressLoggerWarnings():
            results = asyncio.run(_do_all(inputs))
    else:
        for batch in batched(inputs, batch_size):
            start_idx = len(results)
            complete_str = f", {start_idx} complete" if start_idx > 0 else ""
            msg = message or f"[bold green]Running {len(items)} tasks{complete_str}:"

            if stop_on_first_result and any(r is not None for r in results):
                # A valid result was already found in an earlier batch:
                # skip the remaining batches, padding their slots with None.
                results.extend([None] * len(batch))
            else:
                with status(msg), SuppressLoggerWarnings():
                    results.extend(asyncio.run(_do_all(batch, start_idx=start_idx)))

    return results
|
182
|
+
|
183
|
+
|
184
|
+
def run_batch_tasks(
    task: Task,
    items: list[T],
    input_map: Callable[[T], str | ChatDocument] = lambda x: str(x),
    output_map: Callable[[ChatDocument | None], U] = lambda x: x,  # type: ignore
    stop_on_first_result: bool = False,
    sequential: bool = True,
    batch_size: Optional[int] = None,
    turns: int = -1,
    max_cost: float = 0.0,
    max_tokens: int = 0,
    handle_exceptions: bool = False,
) -> List[Optional[U]]:
    """
    Run copies of `task` async/concurrently one per item in `items` list.
    For each item, apply `input_map` to get the initial message to process.
    For each result, apply `output_map` to get the final result.

    The copy used for the item at index `i` is `task.clone(i)`; the actual
    work is delegated to `run_batch_task_gen`.

    Args:
        task (Task): task to run
        items (list[T]): list of items to process
        input_map (Callable[[T], str|ChatDocument]): function to map item to
            initial message to process
        output_map (Callable[[ChatDocument|None], U]): function to map result
            to final result
        stop_on_first_result (bool): whether to stop after the first valid
            (not-None) mapped result; other results are then None
        sequential (bool): whether to run sequentially
            (e.g. some APIs such as ooba don't support concurrent requests)
        batch_size (Optional[int]): The number of tasks to run at a time,
            if None, unbatched
        turns (int): number of turns to run, -1 for infinite
        max_cost (float): maximum cost to run the task (default 0.0 for unlimited)
        max_tokens (int): maximum token usage (in and out) (default 0 for unlimited)
        handle_exceptions (bool): whether to replace exceptions raised by
            individual copies with outputs of None instead of propagating them
            (default False, i.e. the first exception aborts the batch)

    Returns:
        list[Optional[U]]: list of final results. Always list[U] if
        `stop_on_first_result` is disabled
    """
    message = f"[bold green]Running {len(items)} copies of {task.name}..."
    return run_batch_task_gen(
        lambda i: task.clone(i),
        items,
        input_map=input_map,
        output_map=output_map,
        stop_on_first_result=stop_on_first_result,
        sequential=sequential,
        batch_size=batch_size,
        turns=turns,
        message=message,
        handle_exceptions=handle_exceptions,
        max_cost=max_cost,
        max_tokens=max_tokens,
    )
|
233
|
+
|
234
|
+
|
235
|
+
def run_batch_agent_method(
    agent: Agent,
    method: Callable[
        [str | ChatDocument | None], Coroutine[Any, Any, ChatDocument | None]
    ],
    items: List[Any],
    input_map: Callable[[Any], str | ChatDocument] = lambda x: str(x),
    output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
    sequential: bool = True,
    stop_on_first_result: bool = False,
) -> List[Any]:
    """
    Run the `method` on copies of `agent`, async/concurrently one per
    item in `items` list.
    ASSUMPTION: The `method` is an async method and has signature:
        method(self, input: str|ChatDocument|None) -> ChatDocument|None
    So this would typically be used for the agent's "responder" methods,
    e.g. `llm_response_async` or `agent_response_async`.

    For each item, apply `input_map` to get the initial message to process.
    For each result, apply `output_map` to get the final result.

    Each copy is constructed as `type(agent)(cfg)`, where `cfg` is a deep copy
    of `agent.config` with LLM streaming and stats display turned off and the
    name set to "<agent name>-<i>". The method is looked up by name on each copy.

    Args:
        agent (Agent): agent whose method to run; must have an llm config
        method (Callable): Async method to run on copies of `agent`.
            The method is assumed to have signature:
            `method(self, input: str|ChatDocument|None) -> ChatDocument|None`
        items (List[Any]): items to process
        input_map (Callable[[Any], str|ChatDocument]): function to map item to
            initial message to process
        output_map (Callable[[ChatDocument|None], Any]): function to map result
            to final result
        sequential (bool): whether to run sequentially
            (e.g. some APIs such as ooba don't support concurrent requests)
        stop_on_first_result (bool): if True, all copies are started
            concurrently (regardless of `sequential`) and we stop as soon as
            some mapped result is not None; unfinished copies are cancelled
            and their entries are None.
    Returns:
        List[Any]: list of final results, one per item

    Raises:
        ValueError: if `method` is not async, or a copy has no such method.
        AssertionError: if the agent has no llm config.
    """
    # Check if the method is async
    method_name = method.__name__
    if not inspect.iscoroutinefunction(method):
        raise ValueError(f"The method {method_name} is not async.")

    inputs = [input_map(item) for item in items]
    agent_cfg = copy.deepcopy(agent.config)
    assert agent_cfg.llm is not None, "agent must have llm config"
    agent_cfg.llm.stream = False
    agent_cfg.show_stats = False
    agent_cls = type(agent)
    agent_name = agent_cfg.name

    async def _do_task(input: str | ChatDocument, i: int) -> Any:
        # Each copy gets its own config. Renaming one shared config would make
        # names accumulate ("name-0-1-2...") and, if agents keep a reference
        # to their config, leak later renames into earlier copies.
        cfg_i = copy.deepcopy(agent_cfg)
        cfg_i.name = f"{agent_name}-{i}"
        agent_i = agent_cls(cfg_i)
        method_i = getattr(agent_i, method_name, None)
        if method_i is None:
            raise ValueError(f"Agent {agent_name} has no method {method_name}")
        result = await method_i(input)
        return output_map(result)

    async def _do_all() -> List[Any]:
        if stop_on_first_result:
            tasks = [
                asyncio.create_task(_do_task(inp, i))
                for i, inp in enumerate(inputs)
            ]
            # O(1) task -> item-index lookup (instead of `tasks.index`)
            index_of = {task: i for i, task in enumerate(tasks)}
            outputs: List[Any] = [None] * len(tasks)
            # Initialized before `try` so the `finally` never sees an unbound
            # name; an empty set also skips `asyncio.wait`, which rejects
            # empty inputs.
            pending = set(tasks)
            try:
                # Keep waiting until some (mapped) result is non-None,
                # consistent with `run_batch_task_gen`.
                while pending:
                    done, pending = await asyncio.wait(
                        pending, return_when=asyncio.FIRST_COMPLETED
                    )
                    for task in done:
                        outputs[index_of[task]] = await task
                    if any(r is not None for r in outputs):
                        break
            finally:
                for task in pending:
                    task.cancel()
                await asyncio.gather(*pending, return_exceptions=True)
            return outputs
        elif sequential:
            results = []
            for i, inp in enumerate(inputs):
                result = await _do_task(inp, i)
                results.append(result)
            return results
        with quiet_mode(), SuppressLoggerWarnings():
            return await asyncio.gather(
                *(_do_task(inp, i) for i, inp in enumerate(inputs))
            )

    n = len(items)
    with status(f"[bold green]Running {n} copies of {agent_name}..."):
        results = asyncio.run(_do_all())

    return results
|
328
|
+
|
329
|
+
|
330
|
+
def llm_response_batch(
    agent: Agent,
    items: List[Any],
    input_map: Callable[[Any], str | ChatDocument] = lambda x: str(x),
    output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
    sequential: bool = True,
    stop_on_first_result: bool = False,
) -> List[Any]:
    """
    Run `agent.llm_response_async` on copies of `agent`, one per item in `items`.

    Convenience wrapper around `run_batch_agent_method`; all arguments keep
    the meaning they have there.

    Returns:
        List[Any]: the mapped responses, one entry per item.
    """
    responder = agent.llm_response_async
    return run_batch_agent_method(
        agent=agent,
        method=responder,
        items=items,
        input_map=input_map,
        output_map=output_map,
        sequential=sequential,
        stop_on_first_result=stop_on_first_result,
    )
|
347
|
+
|
348
|
+
|
349
|
+
def agent_response_batch(
    agent: Agent,
    items: List[Any],
    input_map: Callable[[Any], str | ChatDocument] = lambda x: str(x),
    output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
    sequential: bool = True,
    stop_on_first_result: bool = False,
) -> List[Any]:
    """
    Run `agent.agent_response_async` on copies of `agent`, one per item in `items`.

    Convenience wrapper around `run_batch_agent_method`; all arguments keep
    the meaning they have there.

    Returns:
        List[Any]: the mapped responses, one entry per item.
    """
    responder = agent.agent_response_async
    return run_batch_agent_method(
        agent=agent,
        method=responder,
        items=items,
        input_map=input_map,
        output_map=output_map,
        sequential=sequential,
        stop_on_first_result=stop_on_first_result,
    )
|
366
|
+
|
367
|
+
|
368
|
+
def run_batch_function(
    function: Callable[[T], U],
    items: list[T],
    sequential: bool = True,
    batch_size: Optional[int] = None,
) -> List[U]:
    """
    Apply `function` to every element of `items`, returning results in order.

    Each call is wrapped in a coroutine and driven by `asyncio.run`; with
    `sequential=False` the coroutines of a group are combined with
    `asyncio.gather`. Note that `function` itself is synchronous, so the
    calls do not actually overlap.

    Args:
        function (Callable[[T], U]): function to apply to each item
        items (list[T]): items to process
        sequential (bool): await the calls one by one (True) or gather them
        batch_size (Optional[int]): if given, process `items` in groups of
            this size (one `asyncio.run` per group); if None, all at once

    Returns:
        List[U]: `function(item)` for each item, in the order of `items`.
    """

    async def _call(element: T) -> U:
        return function(element)

    async def _run_group(group: Iterable[T]) -> List[U]:
        if not sequential:
            return await asyncio.gather(*(_call(element) for element in group))
        collected: List[U] = []
        for element in group:
            collected.append(await _call(element))
        return collected

    if batch_size is None:
        with status(f"[bold green]Running {len(items)} tasks:"):
            return asyncio.run(_run_group(items))

    outputs: List[U] = []
    for group in batched(items, batch_size):
        with status(f"[bold green]Running batch of {len(group)} tasks:"):
            outputs.extend(asyncio.run(_run_group(group)))
    return outputs
|
File without changes
|