langroid 0.33.12__tar.gz → 0.34.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129) hide show
  1. {langroid-0.33.12 → langroid-0.34.0}/PKG-INFO +1 -1
  2. langroid-0.34.0/langroid/agent/batch.py +555 -0
  3. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/chat_agent.py +7 -3
  4. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/doc_chat_agent.py +116 -5
  5. {langroid-0.33.12 → langroid-0.34.0}/langroid/mytypes.py +11 -0
  6. {langroid-0.33.12 → langroid-0.34.0}/langroid/parsing/spider.py +6 -6
  7. {langroid-0.33.12 → langroid-0.34.0}/pyproject.toml +1 -1
  8. langroid-0.33.12/langroid/agent/batch.py +0 -398
  9. {langroid-0.33.12 → langroid-0.34.0}/.gitignore +0 -0
  10. {langroid-0.33.12 → langroid-0.34.0}/LICENSE +0 -0
  11. {langroid-0.33.12 → langroid-0.34.0}/README.md +0 -0
  12. {langroid-0.33.12 → langroid-0.34.0}/langroid/__init__.py +0 -0
  13. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/__init__.py +0 -0
  14. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/base.py +0 -0
  15. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/callbacks/__init__.py +0 -0
  16. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/callbacks/chainlit.py +0 -0
  17. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/chat_document.py +0 -0
  18. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/openai_assistant.py +0 -0
  19. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/__init__.py +0 -0
  20. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/arangodb/__init__.py +0 -0
  21. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/arangodb/arangodb_agent.py +0 -0
  22. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/arangodb/system_messages.py +0 -0
  23. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/arangodb/tools.py +0 -0
  24. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/arangodb/utils.py +0 -0
  25. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/lance_doc_chat_agent.py +0 -0
  26. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/lance_rag/__init__.py +0 -0
  27. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/lance_rag/critic_agent.py +0 -0
  28. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/lance_rag/lance_rag_task.py +0 -0
  29. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/lance_rag/query_planner_agent.py +0 -0
  30. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/lance_tools.py +0 -0
  31. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/neo4j/__init__.py +0 -0
  32. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/neo4j/csv_kg_chat.py +0 -0
  33. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/neo4j/neo4j_chat_agent.py +0 -0
  34. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/neo4j/system_messages.py +0 -0
  35. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/neo4j/tools.py +0 -0
  36. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/relevance_extractor_agent.py +0 -0
  37. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/retriever_agent.py +0 -0
  38. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/sql/__init__.py +0 -0
  39. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/sql/sql_chat_agent.py +0 -0
  40. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/sql/utils/__init__.py +0 -0
  41. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/sql/utils/description_extractors.py +0 -0
  42. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/sql/utils/populate_metadata.py +0 -0
  43. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/sql/utils/system_message.py +0 -0
  44. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/sql/utils/tools.py +0 -0
  45. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/special/table_chat_agent.py +0 -0
  46. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/task.py +0 -0
  47. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/tool_message.py +0 -0
  48. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/tools/__init__.py +0 -0
  49. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/tools/duckduckgo_search_tool.py +0 -0
  50. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/tools/file_tools.py +0 -0
  51. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/tools/google_search_tool.py +0 -0
  52. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/tools/metaphor_search_tool.py +0 -0
  53. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/tools/orchestration.py +0 -0
  54. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/tools/recipient_tool.py +0 -0
  55. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/tools/retrieval_tool.py +0 -0
  56. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/tools/rewind_tool.py +0 -0
  57. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/tools/segment_extract_tool.py +0 -0
  58. {langroid-0.33.12 → langroid-0.34.0}/langroid/agent/xml_tool_message.py +0 -0
  59. {langroid-0.33.12 → langroid-0.34.0}/langroid/cachedb/__init__.py +0 -0
  60. {langroid-0.33.12 → langroid-0.34.0}/langroid/cachedb/base.py +0 -0
  61. {langroid-0.33.12 → langroid-0.34.0}/langroid/cachedb/momento_cachedb.py +0 -0
  62. {langroid-0.33.12 → langroid-0.34.0}/langroid/cachedb/redis_cachedb.py +0 -0
  63. {langroid-0.33.12 → langroid-0.34.0}/langroid/embedding_models/__init__.py +0 -0
  64. {langroid-0.33.12 → langroid-0.34.0}/langroid/embedding_models/base.py +0 -0
  65. {langroid-0.33.12 → langroid-0.34.0}/langroid/embedding_models/models.py +0 -0
  66. {langroid-0.33.12 → langroid-0.34.0}/langroid/embedding_models/protoc/__init__.py +0 -0
  67. {langroid-0.33.12 → langroid-0.34.0}/langroid/embedding_models/protoc/embeddings.proto +0 -0
  68. {langroid-0.33.12 → langroid-0.34.0}/langroid/embedding_models/protoc/embeddings_pb2.py +0 -0
  69. {langroid-0.33.12 → langroid-0.34.0}/langroid/embedding_models/protoc/embeddings_pb2.pyi +0 -0
  70. {langroid-0.33.12 → langroid-0.34.0}/langroid/embedding_models/protoc/embeddings_pb2_grpc.py +0 -0
  71. {langroid-0.33.12 → langroid-0.34.0}/langroid/embedding_models/remote_embeds.py +0 -0
  72. {langroid-0.33.12 → langroid-0.34.0}/langroid/exceptions.py +0 -0
  73. {langroid-0.33.12 → langroid-0.34.0}/langroid/language_models/__init__.py +0 -0
  74. {langroid-0.33.12 → langroid-0.34.0}/langroid/language_models/azure_openai.py +0 -0
  75. {langroid-0.33.12 → langroid-0.34.0}/langroid/language_models/base.py +0 -0
  76. {langroid-0.33.12 → langroid-0.34.0}/langroid/language_models/config.py +0 -0
  77. {langroid-0.33.12 → langroid-0.34.0}/langroid/language_models/mock_lm.py +0 -0
  78. {langroid-0.33.12 → langroid-0.34.0}/langroid/language_models/openai_gpt.py +0 -0
  79. {langroid-0.33.12 → langroid-0.34.0}/langroid/language_models/prompt_formatter/__init__.py +0 -0
  80. {langroid-0.33.12 → langroid-0.34.0}/langroid/language_models/prompt_formatter/base.py +0 -0
  81. {langroid-0.33.12 → langroid-0.34.0}/langroid/language_models/prompt_formatter/hf_formatter.py +0 -0
  82. {langroid-0.33.12 → langroid-0.34.0}/langroid/language_models/prompt_formatter/llama2_formatter.py +0 -0
  83. {langroid-0.33.12 → langroid-0.34.0}/langroid/language_models/utils.py +0 -0
  84. {langroid-0.33.12 → langroid-0.34.0}/langroid/parsing/__init__.py +0 -0
  85. {langroid-0.33.12 → langroid-0.34.0}/langroid/parsing/agent_chats.py +0 -0
  86. {langroid-0.33.12 → langroid-0.34.0}/langroid/parsing/code_parser.py +0 -0
  87. {langroid-0.33.12 → langroid-0.34.0}/langroid/parsing/document_parser.py +0 -0
  88. {langroid-0.33.12 → langroid-0.34.0}/langroid/parsing/para_sentence_split.py +0 -0
  89. {langroid-0.33.12 → langroid-0.34.0}/langroid/parsing/parse_json.py +0 -0
  90. {langroid-0.33.12 → langroid-0.34.0}/langroid/parsing/parser.py +0 -0
  91. {langroid-0.33.12 → langroid-0.34.0}/langroid/parsing/repo_loader.py +0 -0
  92. {langroid-0.33.12 → langroid-0.34.0}/langroid/parsing/routing.py +0 -0
  93. {langroid-0.33.12 → langroid-0.34.0}/langroid/parsing/search.py +0 -0
  94. {langroid-0.33.12 → langroid-0.34.0}/langroid/parsing/table_loader.py +0 -0
  95. {langroid-0.33.12 → langroid-0.34.0}/langroid/parsing/url_loader.py +0 -0
  96. {langroid-0.33.12 → langroid-0.34.0}/langroid/parsing/urls.py +0 -0
  97. {langroid-0.33.12 → langroid-0.34.0}/langroid/parsing/utils.py +0 -0
  98. {langroid-0.33.12 → langroid-0.34.0}/langroid/parsing/web_search.py +0 -0
  99. {langroid-0.33.12 → langroid-0.34.0}/langroid/prompts/__init__.py +0 -0
  100. {langroid-0.33.12 → langroid-0.34.0}/langroid/prompts/dialog.py +0 -0
  101. {langroid-0.33.12 → langroid-0.34.0}/langroid/prompts/prompts_config.py +0 -0
  102. {langroid-0.33.12 → langroid-0.34.0}/langroid/prompts/templates.py +0 -0
  103. {langroid-0.33.12 → langroid-0.34.0}/langroid/py.typed +0 -0
  104. {langroid-0.33.12 → langroid-0.34.0}/langroid/pydantic_v1/__init__.py +0 -0
  105. {langroid-0.33.12 → langroid-0.34.0}/langroid/pydantic_v1/main.py +0 -0
  106. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/__init__.py +0 -0
  107. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/algorithms/__init__.py +0 -0
  108. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/algorithms/graph.py +0 -0
  109. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/configuration.py +0 -0
  110. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/constants.py +0 -0
  111. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/git_utils.py +0 -0
  112. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/globals.py +0 -0
  113. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/logging.py +0 -0
  114. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/object_registry.py +0 -0
  115. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/output/__init__.py +0 -0
  116. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/output/citations.py +0 -0
  117. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/output/printing.py +0 -0
  118. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/output/status.py +0 -0
  119. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/pandas_utils.py +0 -0
  120. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/pydantic_utils.py +0 -0
  121. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/system.py +0 -0
  122. {langroid-0.33.12 → langroid-0.34.0}/langroid/utils/types.py +0 -0
  123. {langroid-0.33.12 → langroid-0.34.0}/langroid/vector_store/__init__.py +0 -0
  124. {langroid-0.33.12 → langroid-0.34.0}/langroid/vector_store/base.py +0 -0
  125. {langroid-0.33.12 → langroid-0.34.0}/langroid/vector_store/chromadb.py +0 -0
  126. {langroid-0.33.12 → langroid-0.34.0}/langroid/vector_store/lancedb.py +0 -0
  127. {langroid-0.33.12 → langroid-0.34.0}/langroid/vector_store/meilisearch.py +0 -0
  128. {langroid-0.33.12 → langroid-0.34.0}/langroid/vector_store/momento.py +0 -0
  129. {langroid-0.33.12 → langroid-0.34.0}/langroid/vector_store/qdrantdb.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langroid
3
- Version: 0.33.12
3
+ Version: 0.34.0
4
4
  Summary: Harness LLMs with Multi-Agent Programming
5
5
  Author-email: Prasad Chalasani <pchalasani@gmail.com>
6
6
  License: MIT
@@ -0,0 +1,555 @@
1
+ import asyncio
2
+ import copy
3
+ import inspect
4
+ import warnings
5
+ from enum import Enum
6
+ from typing import (
7
+ Any,
8
+ Callable,
9
+ Coroutine,
10
+ Iterable,
11
+ List,
12
+ Optional,
13
+ TypeVar,
14
+ Union,
15
+ cast,
16
+ )
17
+
18
+ from dotenv import load_dotenv
19
+
20
+ from langroid.agent.base import Agent
21
+ from langroid.agent.chat_document import ChatDocument
22
+ from langroid.agent.task import Task
23
+ from langroid.parsing.utils import batched
24
+ from langroid.utils.configuration import quiet_mode
25
+ from langroid.utils.logging import setup_colored_logging
26
+ from langroid.utils.output import SuppressLoggerWarnings, status
27
+
28
# One-time module setup performed at import: colored log output, then
# environment variables loaded from a local .env file (python-dotenv).
setup_colored_logging()
load_dotenv()

# Generic type parameters shared by the batch helpers in this module:
# T is the type of each input item, U the type returned by an `output_map`.
T, U = TypeVar("T"), TypeVar("U")
34
+
35
+
36
class ExceptionHandling(str, Enum):
    """Policy for what a batch run does when one item's work raises.

    Because the class mixes in ``str``, every member is also its string value:
    ``ExceptionHandling("raise") is ExceptionHandling.RAISE`` and
    ``ExceptionHandling.RAISE == "raise"`` both hold.
    """

    # Re-raise the exception to the caller of the batch helper.
    RAISE = "raise"
    # Put None in the failed item's result slot.
    RETURN_NONE = "return_none"
    # Put the exception object itself in the failed item's result slot.
    RETURN_EXCEPTION = "return_exception"
42
+
43
+
44
def _convert_exception_handling(
    handle_exceptions: Union[bool, ExceptionHandling]
) -> ExceptionHandling:
    """Normalize a ``handle_exceptions`` argument to an ``ExceptionHandling`` member.

    A bool is the legacy spelling and emits a ``DeprecationWarning``
    (attributed to the caller of this helper): ``True`` becomes
    ``RETURN_NONE`` and ``False`` becomes ``RAISE``. An ``ExceptionHandling``
    member is returned unchanged.

    Raises:
        TypeError: if the value is neither a bool nor an ``ExceptionHandling``.
    """
    if isinstance(handle_exceptions, bool):
        warnings.warn(
            "Boolean handle_exceptions is deprecated. "
            "Use ExceptionHandling enum instead: "
            "RAISE, RETURN_NONE, or RETURN_EXCEPTION.",
            DeprecationWarning,
            stacklevel=2,
        )
        if handle_exceptions:
            return ExceptionHandling.RETURN_NONE
        return ExceptionHandling.RAISE

    if not isinstance(handle_exceptions, ExceptionHandling):
        raise TypeError(
            "handle_exceptions must be bool or ExceptionHandling, "
            f"not {type(handle_exceptions)}"
        )
    return handle_exceptions
69
+
70
+
71
+ async def _process_batch_async(
72
+ inputs: Iterable[str | ChatDocument],
73
+ do_task: Callable[[str | ChatDocument, int], Coroutine[Any, Any, Any]],
74
+ start_idx: int = 0,
75
+ stop_on_first_result: bool = False,
76
+ sequential: bool = False,
77
+ handle_exceptions: Union[bool, ExceptionHandling] = ExceptionHandling.RAISE,
78
+ output_map: Callable[[Any], Any] = lambda x: x,
79
+ ) -> List[Optional[ChatDocument] | BaseException]:
80
+ """
81
+ Unified batch processing logic for both agent methods and tasks.
82
+
83
+ Args:
84
+ inputs: Iterable of inputs to process
85
+ do_task: Task execution function that takes (input, index) and returns result
86
+ start_idx: Starting index for the batch
87
+ stop_on_first_result: Whether to stop after first valid result
88
+ sequential: Whether to process sequentially
89
+ handle_exceptions: How to handle exceptions:
90
+ - RAISE or False: Let exceptions propagate
91
+ - RETURN_NONE or True: Convert exceptions to None in results
92
+ - RETURN_EXCEPTION: Include exception objects in results
93
+ Boolean values are deprecated and will be removed in a future version.
94
+ output_map: Function to map results to final output format
95
+ """
96
+ exception_handling = _convert_exception_handling(handle_exceptions)
97
+
98
+ def handle_error(e: BaseException) -> Any:
99
+ """Handle exceptions based on exception_handling."""
100
+ match exception_handling:
101
+ case ExceptionHandling.RAISE:
102
+ raise e
103
+ case ExceptionHandling.RETURN_NONE:
104
+ return None
105
+ case ExceptionHandling.RETURN_EXCEPTION:
106
+ return e
107
+
108
+ if stop_on_first_result:
109
+ results: List[Optional[ChatDocument] | BaseException] = []
110
+ pending: set[asyncio.Task[Any]] = set()
111
+ # Create task-to-index mapping
112
+ task_indices: dict[asyncio.Task[Any], int] = {}
113
+ try:
114
+ tasks = [
115
+ asyncio.create_task(do_task(input, i + start_idx))
116
+ for i, input in enumerate(inputs)
117
+ ]
118
+ task_indices = {task: i for i, task in enumerate(tasks)}
119
+ results = [None] * len(tasks)
120
+
121
+ done, pending = await asyncio.wait(
122
+ tasks, return_when=asyncio.FIRST_COMPLETED
123
+ )
124
+
125
+ # Process completed tasks
126
+ for task in done:
127
+ index = task_indices[task]
128
+ try:
129
+ result = await task
130
+ results[index] = output_map(result)
131
+ except BaseException as e:
132
+ results[index] = handle_error(e)
133
+
134
+ if any(r is not None for r in results):
135
+ return results
136
+ finally:
137
+ for task in pending:
138
+ task.cancel()
139
+ try:
140
+ await asyncio.gather(*pending, return_exceptions=True)
141
+ except BaseException as e:
142
+ handle_error(e)
143
+ return results
144
+
145
+ elif sequential:
146
+ results = []
147
+ for i, input in enumerate(inputs):
148
+ try:
149
+ result = await do_task(input, i + start_idx)
150
+ results.append(output_map(result))
151
+ except BaseException as e:
152
+ results.append(handle_error(e))
153
+ return results
154
+
155
+ # Parallel execution
156
+ else:
157
+ try:
158
+ return_exceptions = exception_handling != ExceptionHandling.RAISE
159
+ with quiet_mode(), SuppressLoggerWarnings():
160
+ results_with_exceptions = cast(
161
+ list[Optional[ChatDocument | BaseException]],
162
+ await asyncio.gather(
163
+ *(
164
+ do_task(input, i + start_idx)
165
+ for i, input in enumerate(inputs)
166
+ ),
167
+ return_exceptions=return_exceptions,
168
+ ),
169
+ )
170
+
171
+ if exception_handling == ExceptionHandling.RETURN_NONE:
172
+ results = [
173
+ None if isinstance(r, BaseException) else r
174
+ for r in results_with_exceptions
175
+ ]
176
+ else: # ExceptionHandling.RETURN_EXCEPTION
177
+ results = results_with_exceptions
178
+ except BaseException as e:
179
+ results = [handle_error(e) for _ in inputs]
180
+
181
+ return [output_map(r) for r in results]
182
+
183
+
184
def run_batched_tasks(
    inputs: List[str | ChatDocument],
    do_task: Callable[[str | ChatDocument, int], Coroutine[Any, Any, Any]],
    batch_size: Optional[int],
    stop_on_first_result: bool,
    sequential: bool,
    handle_exceptions: Union[bool, ExceptionHandling],
    output_map: Callable[[Any], Any],
    message_template: str,
    message: Optional[str] = None,
) -> List[Any]:
    """
    Drive `_process_batch_async` over all inputs from synchronous code.

    All batches run inside a single `asyncio.run` call. With `batch_size=None`
    everything is processed as one batch. Otherwise the inputs are split with
    `batched`, and each batch's `do_task` indices continue from the number of
    results already collected. Once an earlier batch produced a non-None result
    and `stop_on_first_result` is set, the remaining batches are skipped and
    padded with None.

    Args:
        inputs: List of inputs to process
        do_task: Task execution function
        batch_size: Size of batches, if None process all at once
        stop_on_first_result: Whether to stop after first valid result
        sequential: Whether to process sequentially
        handle_exceptions: How to handle exceptions:
            - RAISE or False: Let exceptions propagate
            - RETURN_NONE or True: Convert exceptions to None in results
            - RETURN_EXCEPTION: Include exception objects in results
            Boolean values are deprecated and will be removed in a future version.
        output_map: Function to map results
        message_template: Template for status message; `{total}` is replaced
            by the number of inputs.
        message: Optional override for status message (used verbatim).

    Returns:
        List[Any]: one result per input, in input order.
    """

    def _status_text(completed: int) -> str:
        # An explicit `message` always wins. Otherwise fill the template and,
        # once some inputs are done, append a progress suffix.
        if message:
            return message
        suffix = f", {completed} complete" if completed > 0 else ""
        return message_template.format(total=len(inputs)) + suffix

    async def _run_chunk(
        chunk: Iterable[str | ChatDocument], offset: int, text: str
    ) -> List[Any]:
        # Process one chunk under a status spinner, starting indices at `offset`.
        with status(text), SuppressLoggerWarnings():
            return await _process_batch_async(
                chunk,
                do_task,
                start_idx=offset,
                stop_on_first_result=stop_on_first_result,
                sequential=sequential,
                handle_exceptions=handle_exceptions,
                output_map=output_map,
            )

    async def _drive() -> List[Any]:
        # Single coroutine so asyncio.run is invoked once, not once per batch.
        if batch_size is None:
            return await _run_chunk(inputs, 0, _status_text(0))

        collected: List[Any] = []
        for chunk in batched(inputs, batch_size):
            offset = len(collected)
            text = _status_text(offset)
            found_already = stop_on_first_result and any(
                r is not None for r in collected
            )
            if found_already:
                collected.extend([None] * len(chunk))
                continue
            collected.extend(await _run_chunk(chunk, offset, text))
        return collected

    return asyncio.run(_drive())
266
+
267
+
268
def run_batch_task_gen(
    gen_task: Callable[[int], Task],
    items: list[T],
    input_map: Callable[[T], str | ChatDocument] = lambda x: str(x),
    output_map: Callable[[ChatDocument | None], U] = lambda x: x,  # type: ignore
    stop_on_first_result: bool = False,
    sequential: bool = True,
    batch_size: Optional[int] = None,
    turns: int = -1,
    message: Optional[str] = None,
    handle_exceptions: Union[bool, ExceptionHandling] = ExceptionHandling.RAISE,
    max_cost: float = 0.0,
    max_tokens: int = 0,
) -> list[Optional[U]]:
    """
    Generate and run copies of a task async/concurrently one per item in `items` list.
    For each item, apply `input_map` to get the initial message to process.
    For each result, apply `output_map` to get the final result.

    Args:
        gen_task (Callable[[int], Task]): generates the task for the item at the
            given (overall, 0-based) index
        items (list[T]): list of items to process
        input_map (Callable[[T], str|ChatDocument]): function to map item to
            initial message to process
        output_map (Callable[[ChatDocument|None], U]): function to map result
            to final result. If stop_on_first_result is enabled, then
            map any invalid output to None. We continue until some non-None
            result is obtained.
        stop_on_first_result (bool): whether to stop after the first valid
            (not-None) result. In this case all other tasks are
            cancelled, and their corresponding result is None in the
            returned list. When enabled, the items of a batch run
            concurrently even if `sequential` is True.
        sequential (bool): whether to run sequentially
            (e.g. some APIs such as ooba don't support concurrent requests)
        batch_size (Optional[int]): The number of tasks to run at a time,
            if None, unbatched
        turns (int): number of turns to run, -1 for infinite
        message (Optional[str]): optionally overrides the console status messages
        handle_exceptions: How to handle exceptions:
            - RAISE or False: Let exceptions propagate
            - RETURN_NONE or True: Convert exceptions to None in results
            - RETURN_EXCEPTION: Include exception objects in results
            Boolean values are deprecated and will be removed in a future version.
        max_cost: float: maximum cost to run the task (default 0.0 for unlimited)
        max_tokens: int: maximum token usage (in and out) (default 0 for unlimited)

    Returns:
        list[Optional[U]]: list of final results. Always list[U] if
        `stop_on_first_result` is disabled
    """
    inputs = [input_map(item) for item in items]

    async def _do_task(msg: str | ChatDocument, i: int) -> Optional[ChatDocument]:
        """Build the i-th task, run it on `msg`, and return its final result."""
        task_i = gen_task(i)
        # Batch runs show a status spinner, so disable per-task streaming/stats.
        if task_i.agent.llm is not None:
            task_i.agent.llm.set_stream(False)
        task_i.agent.config.show_stats = False

        try:
            return await task_i.run_async(
                msg, turns=turns, max_cost=max_cost, max_tokens=max_tokens
            )
        except asyncio.CancelledError:
            # Tell the task to stop, then re-raise (bare `raise` keeps the
            # original traceback) so the batch runner sees the cancellation.
            task_i.kill()
            raise

    return run_batched_tasks(
        inputs=inputs,
        do_task=_do_task,
        batch_size=batch_size,
        stop_on_first_result=stop_on_first_result,
        sequential=sequential,
        handle_exceptions=handle_exceptions,
        output_map=output_map,
        message_template="[bold green]Running {total} tasks:",
        message=message,
    )
350
+
351
+
352
def run_batch_tasks(
    task: Task,
    items: list[T],
    input_map: Callable[[T], str | ChatDocument] = lambda x: str(x),
    output_map: Callable[[ChatDocument | None], U] = lambda x: x,  # type: ignore
    stop_on_first_result: bool = False,
    sequential: bool = True,
    batch_size: Optional[int] = None,
    turns: int = -1,
    max_cost: float = 0.0,
    max_tokens: int = 0,
    handle_exceptions: Union[bool, ExceptionHandling] = ExceptionHandling.RAISE,
) -> List[Optional[U]]:
    """
    Run copies of `task` async/concurrently one per item in `items` list.
    For each item, apply `input_map` to get the initial message to process.
    For each result, apply `output_map` to get the final result.

    Args:
        task (Task): task to run; the i-th item runs on `task.clone(i)`
        items (list[T]): list of items to process
        input_map (Callable[[T], str|ChatDocument]): function to map item to
            initial message to process
        output_map (Callable[[ChatDocument|None], U]): function to map result
            to final result
        stop_on_first_result (bool): whether to stop after the first valid
            (not-None) result; remaining results are None
        sequential (bool): whether to run sequentially
            (e.g. some APIs such as ooba don't support concurrent requests)
        batch_size (Optional[int]): The number of tasks to run at a time,
            if None, unbatched
        turns (int): number of turns to run, -1 for infinite
        max_cost: float: maximum cost to run the task (default 0.0 for unlimited)
        max_tokens: int: maximum token usage (in and out) (default 0 for unlimited)
        handle_exceptions: How to handle exceptions (forwarded to
            `run_batch_task_gen`); defaults to RAISE, which was the previous
            fixed behavior.

    Returns:
        list[Optional[U]]: list of final results. Always list[U] if
        `stop_on_first_result` is disabled
    """
    message = f"[bold green]Running {len(items)} copies of {task.name}..."
    # Forward by keyword so this call does not depend on run_batch_task_gen's
    # positional parameter order.
    return run_batch_task_gen(
        lambda i: task.clone(i),
        items,
        input_map=input_map,
        output_map=output_map,
        stop_on_first_result=stop_on_first_result,
        sequential=sequential,
        batch_size=batch_size,
        turns=turns,
        message=message,
        handle_exceptions=handle_exceptions,
        max_cost=max_cost,
        max_tokens=max_tokens,
    )
401
+
402
+
403
def run_batch_agent_method(
    agent: Agent,
    method: Callable[
        [str | ChatDocument | None], Coroutine[Any, Any, ChatDocument | None]
    ],
    items: List[Any],
    input_map: Callable[[Any], str | ChatDocument] = lambda x: str(x),
    output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
    sequential: bool = True,
    stop_on_first_result: bool = False,
    handle_exceptions: Union[bool, ExceptionHandling] = ExceptionHandling.RAISE,
    batch_size: Optional[int] = None,
) -> List[Any]:
    """
    Run the `method` on copies of `agent`, async/concurrently one per
    item in `items` list.
    ASSUMPTION: The `method` is an async method and has signature:
        method(self, input: str|ChatDocument|None) -> ChatDocument|None
    So this would typically be used for the agent's "responder" methods,
    e.g. `llm_response_async` or `agent_responder_async`.

    For each item, apply `input_map` to get the initial message to process.
    For each result, apply `output_map` to get the final result.

    Each copy is a new instance of `type(agent)` built from a deep copy of
    `agent.config` with LLM streaming and stats display turned off, and named
    "<agent name>-<index>".

    Args:
        agent (Agent): agent whose method to run
        method (str): Async method to run on copies of `agent`.
            The method is assumed to have signature:
            `method(self, input: str|ChatDocument|None) -> ChatDocument|None`
        input_map (Callable[[Any], str|ChatDocument]): function to map item to
            initial message to process
        output_map (Callable[[ChatDocument|None], Any]): function to map result
            to final result
        sequential (bool): whether to run sequentially
            (e.g. some APIs such as ooba don't support concurrent requests)
        stop_on_first_result (bool): whether to stop after the first valid
            (not-None) result; remaining results are None
        handle_exceptions: How to handle exceptions:
            - RAISE or False: Let exceptions propagate
            - RETURN_NONE or True: Convert exceptions to None in results
            - RETURN_EXCEPTION: Include exception objects in results
            Boolean values are deprecated and will be removed in a future version.
        batch_size (Optional[int]): The number of items to process in each batch.
            If None, process all items at once.
    Returns:
        List[Any]: list of final results

    Raises:
        ValueError: if `method` is not a coroutine function.
    """
    # Check if the method is async
    method_name = method.__name__
    if not inspect.iscoroutinefunction(method):
        raise ValueError(f"The method {method_name} is not async.")

    inputs = [input_map(item) for item in items]
    agent_cfg = copy.deepcopy(agent.config)
    assert agent_cfg.llm is not None, "agent must have llm config"
    agent_cfg.llm.stream = False
    agent_cfg.show_stats = False
    agent_cls = type(agent)
    agent_name = agent_cfg.name

    async def _do_task(msg: str | ChatDocument, i: int) -> Any:
        # Give every copy its own config. Mutating the shared `agent_cfg.name`
        # made suffixes accumulate across calls ("X-0-1-2...") and left all
        # concurrently created copies holding the same config object.
        cfg_i = copy.deepcopy(agent_cfg)
        cfg_i.name = f"{agent_name}-{i}"
        agent_i = agent_cls(cfg_i)
        method_i = getattr(agent_i, method_name, None)
        if method_i is None:
            raise ValueError(f"Agent {agent_name} has no method {method_name}")
        return await method_i(msg)

    # The template goes through str.format(total=...) in run_batched_tasks, so
    # literal braces in the agent name must be escaped.
    safe_name = str(agent_name).replace("{", "{{").replace("}", "}}")
    return run_batched_tasks(
        inputs=inputs,
        do_task=_do_task,
        batch_size=batch_size,
        stop_on_first_result=stop_on_first_result,
        sequential=sequential,
        handle_exceptions=handle_exceptions,
        output_map=output_map,
        message_template=f"[bold green]Running {{total}} copies of {safe_name}...",
    )
481
+
482
+
483
def llm_response_batch(
    agent: Agent,
    items: List[Any],
    input_map: Callable[[Any], str | ChatDocument] = lambda x: str(x),
    output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
    sequential: bool = True,
    stop_on_first_result: bool = False,
    batch_size: Optional[int] = None,
    handle_exceptions: Union[bool, ExceptionHandling] = ExceptionHandling.RAISE,
) -> List[Any]:
    """
    Run `llm_response_async` on a fresh copy of `agent` for each item.

    Convenience wrapper around `run_batch_agent_method`, which documents
    each argument in detail.

    Args:
        agent (Agent): agent whose `llm_response_async` is run on copies
        items (List[Any]): items to process
        input_map: maps each item to the message passed to the method
        output_map: maps each method result to the final result
        sequential (bool): whether to run sequentially
        stop_on_first_result (bool): whether to stop after the first
            non-None result
        batch_size (Optional[int]): items per batch; None processes all at once
        handle_exceptions: how per-item exceptions are handled (RAISE,
            RETURN_NONE or RETURN_EXCEPTION); defaults to RAISE, which was
            the previous fixed behavior.

    Returns:
        List[Any]: one mapped result per item, in item order.
    """
    return run_batch_agent_method(
        agent,
        agent.llm_response_async,
        items,
        input_map=input_map,
        output_map=output_map,
        sequential=sequential,
        stop_on_first_result=stop_on_first_result,
        handle_exceptions=handle_exceptions,
        batch_size=batch_size,
    )
502
+
503
+
504
def agent_response_batch(
    agent: Agent,
    items: List[Any],
    input_map: Callable[[Any], str | ChatDocument] = lambda x: str(x),
    output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
    sequential: bool = True,
    stop_on_first_result: bool = False,
    batch_size: Optional[int] = None,
    handle_exceptions: Union[bool, ExceptionHandling] = ExceptionHandling.RAISE,
) -> List[Any]:
    """
    Run `agent_response_async` on a fresh copy of `agent` for each item.

    Convenience wrapper around `run_batch_agent_method`, which documents
    each argument in detail.

    Args:
        agent (Agent): agent whose `agent_response_async` is run on copies
        items (List[Any]): items to process
        input_map: maps each item to the message passed to the method
        output_map: maps each method result to the final result
        sequential (bool): whether to run sequentially
        stop_on_first_result (bool): whether to stop after the first
            non-None result
        batch_size (Optional[int]): items per batch; None processes all at once
        handle_exceptions: how per-item exceptions are handled (RAISE,
            RETURN_NONE or RETURN_EXCEPTION); defaults to RAISE, which was
            the previous fixed behavior.

    Returns:
        List[Any]: one mapped result per item, in item order.
    """
    return run_batch_agent_method(
        agent,
        agent.agent_response_async,
        items,
        input_map=input_map,
        output_map=output_map,
        sequential=sequential,
        stop_on_first_result=stop_on_first_result,
        handle_exceptions=handle_exceptions,
        batch_size=batch_size,
    )
523
+
524
+
525
def run_batch_function(
    function: Callable[[T], U],
    items: list[T],
    sequential: bool = True,
    batch_size: Optional[int] = None,
) -> List[U]:
    """Apply the synchronous `function` to every item and collect the results.

    Each call is wrapped in a coroutine and driven by `asyncio.run`: once for
    all items, or once per batch when `batch_size` is given. A status line is
    shown while each run is in progress. Results are returned in item order.
    """

    async def _call(entry: T) -> U:
        return function(entry)

    async def _run_group(group: Iterable[T]) -> List[U]:
        # Concurrent gather unless sequential order was requested.
        if not sequential:
            return await asyncio.gather(*(_call(entry) for entry in group))
        return [await _call(entry) for entry in group]

    if batch_size is None:
        with status(f"[bold green]Running {len(items)} tasks:"):
            return asyncio.run(_run_group(items))

    gathered: List[U] = []
    for chunk in batched(items, batch_size):
        with status(f"[bold green]Running batch of {len(chunk)} tasks:"):
            gathered.extend(asyncio.run(_run_group(chunk)))
    return gathered
@@ -1845,14 +1845,16 @@ class ChatAgent(Agent):
1845
1845
  self.update_last_message(message, role=Role.USER)
1846
1846
  return answer_doc
1847
1847
 
1848
- def llm_response_forget(self, message: str) -> ChatDocument:
1848
+ def llm_response_forget(
1849
+ self, message: Optional[str | ChatDocument] = None
1850
+ ) -> ChatDocument:
1849
1851
  """
1850
1852
  LLM Response to single message, and restore message_history.
1851
1853
  In effect a "one-off" message & response that leaves agent
1852
1854
  message history state intact.
1853
1855
 
1854
1856
  Args:
1855
- message (str): user message
1857
+ message (str|ChatDocument): message to respond to.
1856
1858
 
1857
1859
  Returns:
1858
1860
  A Document object with the response.
@@ -1879,7 +1881,9 @@ class ChatAgent(Agent):
1879
1881
 
1880
1882
  return response
1881
1883
 
1882
- async def llm_response_forget_async(self, message: str) -> ChatDocument:
1884
+ async def llm_response_forget_async(
1885
+ self, message: Optional[str | ChatDocument] = None
1886
+ ) -> ChatDocument:
1883
1887
  """
1884
1888
  Async version of `llm_response_forget`. See there for details.
1885
1889
  """