langroid 0.33.6__py3-none-any.whl → 0.33.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +106 -0
- langroid/agent/__init__.py +41 -0
- langroid/agent/base.py +1983 -0
- langroid/agent/batch.py +398 -0
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +598 -0
- langroid/agent/chat_agent.py +1899 -0
- langroid/agent/chat_document.py +454 -0
- langroid/agent/openai_assistant.py +882 -0
- langroid/agent/special/__init__.py +59 -0
- langroid/agent/special/arangodb/__init__.py +0 -0
- langroid/agent/special/arangodb/arangodb_agent.py +656 -0
- langroid/agent/special/arangodb/system_messages.py +186 -0
- langroid/agent/special/arangodb/tools.py +107 -0
- langroid/agent/special/arangodb/utils.py +36 -0
- langroid/agent/special/doc_chat_agent.py +1466 -0
- langroid/agent/special/lance_doc_chat_agent.py +262 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +198 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +82 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +260 -0
- langroid/agent/special/lance_tools.py +61 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +433 -0
- langroid/agent/special/neo4j/system_messages.py +120 -0
- langroid/agent/special/neo4j/tools.py +32 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +56 -0
- langroid/agent/special/sql/__init__.py +17 -0
- langroid/agent/special/sql/sql_chat_agent.py +654 -0
- langroid/agent/special/sql/utils/__init__.py +21 -0
- langroid/agent/special/sql/utils/description_extractors.py +190 -0
- langroid/agent/special/sql/utils/populate_metadata.py +85 -0
- langroid/agent/special/sql/utils/system_message.py +35 -0
- langroid/agent/special/sql/utils/tools.py +64 -0
- langroid/agent/special/table_chat_agent.py +263 -0
- langroid/agent/task.py +2099 -0
- langroid/agent/tool_message.py +393 -0
- langroid/agent/tools/__init__.py +38 -0
- langroid/agent/tools/duckduckgo_search_tool.py +50 -0
- langroid/agent/tools/file_tools.py +234 -0
- langroid/agent/tools/google_search_tool.py +39 -0
- langroid/agent/tools/metaphor_search_tool.py +68 -0
- langroid/agent/tools/orchestration.py +303 -0
- langroid/agent/tools/recipient_tool.py +235 -0
- langroid/agent/tools/retrieval_tool.py +32 -0
- langroid/agent/tools/rewind_tool.py +137 -0
- langroid/agent/tools/segment_extract_tool.py +41 -0
- langroid/agent/xml_tool_message.py +382 -0
- langroid/cachedb/__init__.py +17 -0
- langroid/cachedb/base.py +58 -0
- langroid/cachedb/momento_cachedb.py +108 -0
- langroid/cachedb/redis_cachedb.py +153 -0
- langroid/embedding_models/__init__.py +39 -0
- langroid/embedding_models/base.py +74 -0
- langroid/embedding_models/models.py +461 -0
- langroid/embedding_models/protoc/__init__.py +0 -0
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/exceptions.py +71 -0
- langroid/language_models/__init__.py +53 -0
- langroid/language_models/azure_openai.py +153 -0
- langroid/language_models/base.py +678 -0
- langroid/language_models/config.py +18 -0
- langroid/language_models/mock_lm.py +124 -0
- langroid/language_models/openai_gpt.py +1964 -0
- langroid/language_models/prompt_formatter/__init__.py +16 -0
- langroid/language_models/prompt_formatter/base.py +40 -0
- langroid/language_models/prompt_formatter/hf_formatter.py +132 -0
- langroid/language_models/prompt_formatter/llama2_formatter.py +75 -0
- langroid/language_models/utils.py +151 -0
- langroid/mytypes.py +84 -0
- langroid/parsing/__init__.py +52 -0
- langroid/parsing/agent_chats.py +38 -0
- langroid/parsing/code_parser.py +121 -0
- langroid/parsing/document_parser.py +718 -0
- langroid/parsing/para_sentence_split.py +62 -0
- langroid/parsing/parse_json.py +155 -0
- langroid/parsing/parser.py +313 -0
- langroid/parsing/repo_loader.py +790 -0
- langroid/parsing/routing.py +36 -0
- langroid/parsing/search.py +275 -0
- langroid/parsing/spider.py +102 -0
- langroid/parsing/table_loader.py +94 -0
- langroid/parsing/url_loader.py +115 -0
- langroid/parsing/urls.py +273 -0
- langroid/parsing/utils.py +373 -0
- langroid/parsing/web_search.py +156 -0
- langroid/prompts/__init__.py +9 -0
- langroid/prompts/dialog.py +17 -0
- langroid/prompts/prompts_config.py +5 -0
- langroid/prompts/templates.py +141 -0
- langroid/pydantic_v1/__init__.py +10 -0
- langroid/pydantic_v1/main.py +4 -0
- langroid/utils/__init__.py +19 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +98 -0
- langroid/utils/constants.py +30 -0
- langroid/utils/git_utils.py +252 -0
- langroid/utils/globals.py +49 -0
- langroid/utils/logging.py +135 -0
- langroid/utils/object_registry.py +66 -0
- langroid/utils/output/__init__.py +20 -0
- langroid/utils/output/citations.py +41 -0
- langroid/utils/output/printing.py +99 -0
- langroid/utils/output/status.py +40 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +602 -0
- langroid/utils/system.py +286 -0
- langroid/utils/types.py +93 -0
- langroid/vector_store/__init__.py +50 -0
- langroid/vector_store/base.py +359 -0
- langroid/vector_store/chromadb.py +214 -0
- langroid/vector_store/lancedb.py +406 -0
- langroid/vector_store/meilisearch.py +299 -0
- langroid/vector_store/momento.py +278 -0
- langroid/vector_store/qdrantdb.py +468 -0
- {langroid-0.33.6.dist-info → langroid-0.33.8.dist-info}/METADATA +95 -94
- langroid-0.33.8.dist-info/RECORD +127 -0
- {langroid-0.33.6.dist-info → langroid-0.33.8.dist-info}/WHEEL +1 -1
- langroid-0.33.6.dist-info/RECORD +0 -7
- langroid-0.33.6.dist-info/entry_points.txt +0 -4
- pyproject.toml +0 -356
- {langroid-0.33.6.dist-info → langroid-0.33.8.dist-info}/licenses/LICENSE +0 -0
langroid/agent/base.py
ADDED
@@ -0,0 +1,1983 @@
import asyncio
import copy
import inspect
import json
import logging
import re
from abc import ABC
from collections import OrderedDict
from contextlib import ExitStack
from types import SimpleNamespace
from typing import (
    Any,
    Callable,
    Coroutine,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    cast,
    get_args,
    get_origin,
    no_type_check,
)

from rich import print
from rich.console import Console
from rich.markup import escape
from rich.prompt import Prompt

from langroid.agent.chat_document import ChatDocMetaData, ChatDocument
from langroid.agent.tool_message import ToolMessage
from langroid.agent.xml_tool_message import XMLToolMessage
from langroid.exceptions import XMLException
from langroid.language_models.base import (
    LanguageModel,
    LLMConfig,
    LLMFunctionCall,
    LLMMessage,
    LLMResponse,
    LLMTokenUsage,
    OpenAIToolCall,
    StreamingIfAllowed,
    ToolChoiceTypes,
)
from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig
from langroid.mytypes import Entity
from langroid.parsing.parse_json import extract_top_level_json
from langroid.parsing.parser import Parser, ParsingConfig
from langroid.prompts.prompts_config import PromptsConfig
from langroid.pydantic_v1 import (
    BaseSettings,
    Field,
    ValidationError,
    validator,
)
from langroid.utils.configuration import settings
from langroid.utils.constants import (
    DONE,
    NO_ANSWER,
    PASS,
    PASS_TO,
    SEND_TO,
)
from langroid.utils.object_registry import ObjectRegistry
from langroid.utils.output import status
from langroid.utils.types import from_string, to_string
from langroid.vector_store.base import VectorStore, VectorStoreConfig

ORCHESTRATION_STRINGS = [DONE, PASS, PASS_TO, SEND_TO]
console = Console(quiet=settings.quiet)

logger = logging.getLogger(__name__)

T = TypeVar("T")


class AgentConfig(BaseSettings):
    """
    General config settings for an LLM agent. This is nested, combining configs of
    various components.
    """

    name: str = "LLM-Agent"
    debug: bool = False
    vecdb: Optional[VectorStoreConfig] = None
    llm: Optional[LLMConfig] = OpenAIGPTConfig()
    parsing: Optional[ParsingConfig] = ParsingConfig()
    prompts: Optional[PromptsConfig] = PromptsConfig()
    show_stats: bool = True  # show token usage/cost stats?
    add_to_registry: bool = True  # register agent in ObjectRegistry?
    respond_tools_only: bool = False  # respond only to tool messages (not plain text)?
    # allow multiple tool messages in a single response?
    allow_multiple_tools: bool = True
    human_prompt: str = (
        "Human (respond or q, x to exit current level, " "or hit enter to continue)"
    )

    @validator("name")
    def check_name_alphanum(cls, v: str) -> str:
        if not re.match(r"^[a-zA-Z0-9_-]+$", v):
            raise ValueError(
                "The name must only contain alphanumeric characters, "
                "underscores, or hyphens, with no spaces"
            )
        return v
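
# Illustrative sketch (not part of the package source): how AgentConfig might
# be customized. Field names come from the class above; the specific values
# here are assumptions for illustration only.
#
#   config = AgentConfig(
#       name="my-agent",          # must match ^[a-zA-Z0-9_-]+$ per the validator
#       show_stats=False,         # suppress token usage/cost stats
#       respond_tools_only=True,  # agent_response reacts only to tool messages
#   )
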
def noop_fn(*args: List[Any], **kwargs: Dict[str, Any]) -> None:
    pass


async def async_noop_fn(*args: List[Any], **kwargs: Dict[str, Any]) -> None:
    pass


async def async_lambda_noop_fn() -> Callable[..., Coroutine[Any, Any, None]]:
    return async_noop_fn


class Agent(ABC):
    """
    An Agent is an abstraction that encapsulates mainly two components:

    - a language model (LLM)
    - a vector store (vecdb)

    plus associated components such as a parser, and variables that hold
    information about any tool/function-calling messages that have been defined.
    """

    id: str = Field(default_factory=lambda: ObjectRegistry.new_id())
    # OpenAI tool-calls awaiting response; update when a tool result with Role.TOOL
    # is added to self.message_history
    oai_tool_calls: List[OpenAIToolCall] = []
    # Index of ALL tool calls generated by the agent
    oai_tool_id2call: Dict[str, OpenAIToolCall] = {}

    def __init__(self, config: AgentConfig = AgentConfig()):
        self.config = config
        self.lock = asyncio.Lock()  # for async access to update self.llm.usage_cost
        self.dialog: List[Tuple[str, str]] = []  # seq of LLM (prompt, response) tuples
        self.llm_tools_map: Dict[str, Type[ToolMessage]] = {}
        self.llm_tools_handled: Set[str] = set()
        self.llm_tools_usable: Set[str] = set()
        self.llm_tools_known: Set[str] = set()  # all known tools, handled/used or not
        # Indicates which tool-names are allowed to be inferred when
        # the LLM "forgets" to include the request field in its
        # tool-call.
        self.enabled_requests_for_inference: Optional[Set[str]] = (
            None  # If None, we allow all
        )
        self.interactive: bool = True  # may be modified by Task wrapper
        self.token_stats_str = ""
        self.default_human_response: Optional[str] = None
        self._indent = ""
        self.llm = LanguageModel.create(config.llm)
        self.vecdb = VectorStore.create(config.vecdb) if config.vecdb else None
        self.tool_error = False
        if config.parsing is not None and self.config.llm is not None:
            # token_encoding_model is used to obtain the tokenizer,
            # so in case it's an OpenAI model, we ensure that the tokenizer
            # corresponding to the model is used.
            if isinstance(self.llm, OpenAIGPT) and self.llm.is_openai_chat_model():
                config.parsing.token_encoding_model = self.llm.config.chat_model
        self.parser: Optional[Parser] = (
            Parser(config.parsing) if config.parsing else None
        )
        if config.add_to_registry:
            ObjectRegistry.register_object(self)

        self.callbacks = SimpleNamespace(
            start_llm_stream=lambda: noop_fn,
            start_llm_stream_async=async_lambda_noop_fn,
            cancel_llm_stream=noop_fn,
            finish_llm_stream=noop_fn,
            show_llm_response=noop_fn,
            show_agent_response=noop_fn,
            get_user_response=None,
            get_user_response_async=None,
            get_last_step=noop_fn,
            set_parent_agent=noop_fn,
            show_error_message=noop_fn,
            show_start_response=noop_fn,
        )
        Agent.init_state(self)

    def init_state(self) -> None:
        """Initialize all state vars. Called by Task.run() if restart is True"""
        self.total_llm_token_cost = 0.0
        self.total_llm_token_usage = 0

    @staticmethod
    def from_id(id: str) -> "Agent":
        return cast(Agent, ObjectRegistry.get(id))

    @staticmethod
    def delete_id(id: str) -> None:
        ObjectRegistry.remove(id)

    def entity_responders(
        self,
    ) -> List[
        Tuple[Entity, Callable[[None | str | ChatDocument], None | ChatDocument]]
    ]:
        """
        Sequence of (entity, response_method) pairs. This sequence is used
        in a `Task` to respond to the current pending message.
        See `Task.step()` for details.
        Returns:
            Sequence of (entity, response_method) pairs.
        """
        return [
            (Entity.AGENT, self.agent_response),
            (Entity.LLM, self.llm_response),
            (Entity.USER, self.user_response),
        ]

    def entity_responders_async(
        self,
    ) -> List[
        Tuple[
            Entity,
            Callable[
                [None | str | ChatDocument], Coroutine[Any, Any, None | ChatDocument]
            ],
        ]
    ]:
        """
        Async version of `entity_responders`. See there for details.
        """
        return [
            (Entity.AGENT, self.agent_response_async),
            (Entity.LLM, self.llm_response_async),
            (Entity.USER, self.user_response_async),
        ]

    @property
    def indent(self) -> str:
        """Indentation to print before any responses from the agent's entities."""
        return self._indent

    @indent.setter
    def indent(self, value: str) -> None:
        self._indent = value

    def update_dialog(self, prompt: str, output: str) -> None:
        self.dialog.append((prompt, output))

    def get_dialog(self) -> List[Tuple[str, str]]:
        return self.dialog

    def clear_dialog(self) -> None:
        self.dialog = []

    def _get_tool_list(
        self, message_class: Optional[Type[ToolMessage]] = None
    ) -> List[str]:
        """
        If `message_class` is None, return a list of all known tool names.
        Otherwise, first add the tool name corresponding to the message class
        (which is the value of the `request` field of the message class),
        to the `self.llm_tools_map` dict, and then return a list
        containing this tool name.

        Args:
            message_class (Optional[Type[ToolMessage]]): The message class whose tool
                name is to be returned; Optional, default is None.
                if None, return a list of all known tool names.

        Returns:
            List[str]: List of tool names: either just the tool name corresponding
                to the message class, or all known tool names
                (when `message_class` is None).

        """
        if message_class is None:
            return list(self.llm_tools_map.keys())

        if not issubclass(message_class, ToolMessage):
            raise ValueError("message_class must be a subclass of ToolMessage")
        tool = message_class.default_value("request")

        """
        if tool has handler method explicitly defined - use it,
        otherwise use the tool name as the handler
        """
        if hasattr(message_class, "_handler"):
            handler = getattr(message_class, "_handler", tool)
        else:
            handler = tool

        self.llm_tools_map[tool] = message_class
        if (
            hasattr(message_class, "handle")
            and inspect.isfunction(message_class.handle)
            and not hasattr(self, handler)
        ):
            """
            If the message class has a `handle` method,
            and agent does NOT have a tool handler method,
            then we create a method for the agent whose name
            is the value of `handler`, and whose body is the `handle` method.
            This removes a separate step of having to define this method
            for the agent, and also keeps the tool definition AND handling
            in one place, i.e. in the message class.
            See `tests/main/test_stateless_tool_messages.py` for an example.
            """
            has_chat_doc_arg = (
                len(inspect.signature(message_class.handle).parameters) > 1
            )
            if has_chat_doc_arg:
                setattr(self, handler, lambda obj, chat_doc: obj.handle(chat_doc))
            else:
                setattr(self, handler, lambda obj: obj.handle())
        elif (
            hasattr(message_class, "response")
            and inspect.isfunction(message_class.response)
            and not hasattr(self, handler)
        ):
            has_chat_doc_arg = (
                len(inspect.signature(message_class.response).parameters) > 2
            )
            if has_chat_doc_arg:
                setattr(
                    self, handler, lambda obj, chat_doc: obj.response(self, chat_doc)
                )
            else:
                setattr(self, handler, lambda obj: obj.response(self))

        if hasattr(message_class, "handle_message_fallback") and (
            inspect.isfunction(message_class.handle_message_fallback)
        ):
            setattr(
                self,
                "handle_message_fallback",
                lambda msg: message_class.handle_message_fallback(self, msg),
            )

        async_handler_name = f"{handler}_async"
        if (
            hasattr(message_class, "handle_async")
            and inspect.isfunction(message_class.handle_async)
            and not hasattr(self, async_handler_name)
        ):
            has_chat_doc_arg = (
                len(inspect.signature(message_class.handle_async).parameters) > 1
            )

            if has_chat_doc_arg:

                @no_type_check
                async def handler(obj, chat_doc):
                    return await obj.handle_async(chat_doc)

            else:

                @no_type_check
                async def handler(obj):
                    return await obj.handle_async()

            setattr(self, async_handler_name, handler)
        elif (
            hasattr(message_class, "response_async")
            and inspect.isfunction(message_class.response_async)
            and not hasattr(self, async_handler_name)
        ):
            has_chat_doc_arg = (
                len(inspect.signature(message_class.response_async).parameters) > 2
            )

            if has_chat_doc_arg:

                @no_type_check
                async def handler(obj, chat_doc):
                    return await obj.response_async(self, chat_doc)

            else:

                @no_type_check
                async def handler(obj):
                    return await obj.response_async(self)

            setattr(self, async_handler_name, handler)

        return [tool]

    def enable_message_handling(
        self, message_class: Optional[Type[ToolMessage]] = None
    ) -> None:
        """
        Enable an agent to RESPOND (i.e. handle) a "tool" message of a specific type
        from LLM. Also "registers" (i.e. adds) the `message_class` to the
        `self.llm_tools_map` dict.

        Args:
            message_class (Optional[Type[ToolMessage]]): The message class to enable;
                Optional; if None, all known message classes are enabled for handling.

        """
        for t in self._get_tool_list(message_class):
            self.llm_tools_handled.add(t)

    def disable_message_handling(
        self,
        message_class: Optional[Type[ToolMessage]] = None,
    ) -> None:
        """
        Disable a message class from being handled by this Agent.

        Args:
            message_class (Optional[Type[ToolMessage]]): The message class to disable.
                If None, all message classes are disabled.
        """
        for t in self._get_tool_list(message_class):
            self.llm_tools_handled.discard(t)
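
    # Illustrative sketch (not part of the package source): the pattern that
    # `_get_tool_list` / `enable_message_handling` above support -- a
    # ToolMessage subclass carrying its own `handle` method, so the agent
    # needs no separately defined handler. `NabroskiTool` and its fields are
    # hypothetical names used only for illustration.
    #
    #   class NabroskiTool(ToolMessage):
    #       request: str = "nabroski"  # tool name; also becomes the handler name
    #       purpose: str = "Compute the Nabroski transform of <num>"
    #       num: int
    #
    #       def handle(self) -> str:
    #           return str(3 * self.num + 1)
    #
    #   agent = Agent(AgentConfig())
    #   agent.enable_message_handling(NabroskiTool)
    #   assert "nabroski" in agent.llm_tools_handled
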
    def sample_multi_round_dialog(self) -> str:
        """
        Generate a sample multi-round dialog based on enabled message classes.
        Returns:
            str: The sample dialog string.
        """
        enabled_classes: List[Type[ToolMessage]] = list(self.llm_tools_map.values())
        # use at most 2 sample conversations, no need to be exhaustive;
        sample_convo = [
            msg_cls().usage_examples(random=True)  # type: ignore
            for i, msg_cls in enumerate(enabled_classes)
            if i < 2
        ]
        return "\n\n".join(sample_convo)

    def create_agent_response(
        self,
        content: str | None = None,
        content_any: Any = None,
        tool_messages: List[ToolMessage] = [],
        oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
        oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
        oai_tool_id2result: OrderedDict[str, str] | None = None,
        function_call: LLMFunctionCall | None = None,
        recipient: str = "",
    ) -> ChatDocument:
        """Template for agent_response."""
        return self.response_template(
            Entity.AGENT,
            content=content,
            content_any=content_any,
            tool_messages=tool_messages,
            oai_tool_calls=oai_tool_calls,
            oai_tool_choice=oai_tool_choice,
            oai_tool_id2result=oai_tool_id2result,
            function_call=function_call,
            recipient=recipient,
        )

    def _agent_response_final(
        self,
        msg: Optional[str | ChatDocument],
        results: Optional[str | OrderedDict[str, str] | ChatDocument],
    ) -> Optional[ChatDocument]:
        """
        Convert results to final response.
        """
        if results is None:
            return None
        if isinstance(results, str):
            results_str = results
        elif isinstance(results, ChatDocument):
            results_str = results.content
        elif isinstance(results, dict):
            results_str = json.dumps(results, indent=2)
        if not settings.quiet:
            console.print(f"[red]{self.indent}", end="")
            print(f"[red]Agent: {escape(results_str)}")
            maybe_json = len(extract_top_level_json(results_str)) > 0
            self.callbacks.show_agent_response(
                content=results_str,
                language="json" if maybe_json else "text",
            )
        if isinstance(results, ChatDocument):
            # Preserve trail of tool_ids for OpenAI Assistant fn-calls
            results.metadata.tool_ids = (
                [] if msg is None or isinstance(msg, str) else msg.metadata.tool_ids
            )
            return results
        sender_name = self.config.name
        if isinstance(msg, ChatDocument) and msg.function_call is not None:
            # if result was from handling an LLM `function_call`,
            # set sender_name to name of the function_call
            sender_name = msg.function_call.name

        results_str, id2result, oai_tool_id = self.process_tool_results(
            results if isinstance(results, str) else "",
            id2result=None if isinstance(results, str) else results,
            tool_calls=(msg.oai_tool_calls if isinstance(msg, ChatDocument) else None),
        )
        return ChatDocument(
            content=results_str,
            oai_tool_id2result=id2result,
            metadata=ChatDocMetaData(
                source=Entity.AGENT,
                sender=Entity.AGENT,
                sender_name=sender_name,
                oai_tool_id=oai_tool_id,
                # preserve trail of tool_ids for OpenAI Assistant fn-calls
                tool_ids=(
                    [] if msg is None or isinstance(msg, str) else msg.metadata.tool_ids
                ),
            ),
        )

    async def agent_response_async(
        self,
        msg: Optional[str | ChatDocument] = None,
    ) -> Optional[ChatDocument]:
        """
        Asynch version of `agent_response`. See there for details.
        """
        if msg is None:
            return None

        results = await self.handle_message_async(msg)

        return self._agent_response_final(msg, results)

    def agent_response(
        self,
        msg: Optional[str | ChatDocument] = None,
    ) -> Optional[ChatDocument]:
        """
        Response from the "agent itself", typically (but not only)
        used to handle LLM's "tool message" or `function_call`
        (e.g. OpenAI `function_call`).
        Args:
            msg (str|ChatDocument): the input to respond to: if msg is a string,
                and it contains a valid JSON-structured "tool message", or
                if msg is a ChatDocument, and it contains a `function_call`.
        Returns:
            Optional[ChatDocument]: the response, packaged as a ChatDocument

        """
        if msg is None:
            return None

        results = self.handle_message(msg)

        return self._agent_response_final(msg, results)

    def process_tool_results(
        self,
        results: str,
        id2result: OrderedDict[str, str] | None,
        tool_calls: List[OpenAIToolCall] | None = None,
    ) -> Tuple[str, Dict[str, str] | None, str | None]:
        """
        Process results from a response, based on whether
        they are results of OpenAI tool-calls from THIS agent, so that
        we can construct an appropriate LLMMessage that contains tool results.

        Args:
            results (str): A possible string result from handling tool(s)
            id2result (OrderedDict[str,str]|None): A dict of OpenAI tool id -> result,
                if there are multiple tool results.
            tool_calls (List[OpenAIToolCall]|None): List of OpenAI tool-calls that the
                results are a response to.

        Return:
            - str: The response string
            - Dict[str,str]|None: A dict of OpenAI tool id -> result, if there are
                multiple tool results.
            - str|None: tool_id if there was a single tool result

        """
        id2result_ = copy.deepcopy(id2result) if id2result is not None else None
        results_str = ""
        oai_tool_id = None

        if results != "":
            # in this case ignore id2result
            assert (
                id2result is None
            ), "id2result should be None when results string is non-empty!"
            results_str = results
            if len(self.oai_tool_calls) > 0:
                # We only have one result, so in case there is a
                # "pending" OpenAI tool-call, we expect no more than 1 such.
                assert (
                    len(self.oai_tool_calls) == 1
                ), "There are multiple pending tool-calls, but only one result!"
                # We record the tool_id of the tool-call that
                # the result is a response to, so that ChatDocument.to_LLMMessage
                # can properly set the `tool_call_id` field of the LLMMessage.
                oai_tool_id = self.oai_tool_calls[0].id
        elif id2result is not None and id2result_ is not None:  # appease mypy
            if len(id2result_) == len(self.oai_tool_calls):
                # if the number of pending tool calls equals the number of results,
                # then ignore the ids in id2result, and use the results in order,
                # which is preserved since id2result is an OrderedDict.
                assert len(id2result_) > 1, "Expected to see > 1 result in id2result!"
                results_str = ""
                id2result_ = OrderedDict(
                    zip(
                        [tc.id or "" for tc in self.oai_tool_calls], id2result_.values()
                    )
                )
            else:
                assert (
                    tool_calls is not None
                ), "tool_calls cannot be None when id2result is not None!"
                # This must be an OpenAI tool id -> result map;
                # However some ids may not correspond to the tool-calls in the list of
                # pending tool-calls (self.oai_tool_calls).
                # Such results are concatenated into a simple string, to store in the
                # ChatDocument.content, and the rest
                # (i.e. those that DO correspond to tools in self.oai_tool_calls)
                # are stored as a dict in ChatDocument.oai_tool_id2result.

                # OAI tools from THIS agent, awaiting response
                pending_tool_ids = [tc.id for tc in self.oai_tool_calls]
                # tool_calls that the results are a response to
                # (but these may have been sent from another agent, hence may not be in
                # self.oai_tool_calls)
                parent_tool_id2name = {
                    tc.id: tc.function.name
                    for tc in tool_calls or []
                    if tc.function is not None
                }

                # (id, result) for result NOT corresponding to self.oai_tool_calls,
                # i.e. these are results of EXTERNAL tool-calls from another agent.
                external_tool_id_results = []

                for tc_id, result in id2result.items():
                    if tc_id not in pending_tool_ids:
                        external_tool_id_results.append((tc_id, result))
                        id2result_.pop(tc_id)
                if len(external_tool_id_results) == 0:
                    results_str = ""
                elif len(external_tool_id_results) == 1:
                    results_str = external_tool_id_results[0][1]
                else:
                    results_str = "\n\n".join(
                        [
                            f"Result from tool/function "
                            f"{parent_tool_id2name[id]}: {result}"
                            for id, result in external_tool_id_results
                        ]
                    )

                if len(id2result_) == 0:
                    id2result_ = None
                elif len(id2result_) == 1 and len(external_tool_id_results) == 0:
                    results_str = list(id2result_.values())[0]
                    oai_tool_id = list(id2result_.keys())[0]
                    id2result_ = None

        return results_str, id2result_, oai_tool_id

    def response_template(
        self,
        e: Entity,
        content: str | None = None,
        content_any: Any = None,
        tool_messages: List[ToolMessage] = [],
        oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
        oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
        oai_tool_id2result: OrderedDict[str, str] | None = None,
        function_call: LLMFunctionCall | None = None,
        recipient: str = "",
    ) -> ChatDocument:
        """Template for response from entity `e`."""
        return ChatDocument(
            content=content or "",
            content_any=content_any,
            tool_messages=tool_messages,
            oai_tool_calls=oai_tool_calls,
            oai_tool_id2result=oai_tool_id2result,
            function_call=function_call,
            oai_tool_choice=oai_tool_choice,
            metadata=ChatDocMetaData(
                source=e, sender=e, sender_name=self.config.name, recipient=recipient
            ),
        )

    def create_user_response(
        self,
        content: str | None = None,
        content_any: Any = None,
        tool_messages: List[ToolMessage] = [],
        oai_tool_calls: List[OpenAIToolCall] | None = None,
        oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
        oai_tool_id2result: OrderedDict[str, str] | None = None,
        function_call: LLMFunctionCall | None = None,
        recipient: str = "",
    ) -> ChatDocument:
        """Template for user_response."""
        return self.response_template(
            e=Entity.USER,
            content=content,
            content_any=content_any,
            tool_messages=tool_messages,
            oai_tool_calls=oai_tool_calls,
            oai_tool_choice=oai_tool_choice,
            oai_tool_id2result=oai_tool_id2result,
            function_call=function_call,
            recipient=recipient,
        )

    def user_can_respond(self, msg: Optional[str | ChatDocument] = None) -> bool:
        """
        Whether the user can respond to a message.

        Args:
            msg (str|ChatDocument): the string to respond to.

        Returns:

        """
        # When msg explicitly addressed to user, this means an actual human response
        # is being sought.
        need_human_response = (
            isinstance(msg, ChatDocument) and msg.metadata.recipient == Entity.USER
        )

        if not self.interactive and not need_human_response:
            return False

        return True

    def _user_response_final(
        self, msg: Optional[str | ChatDocument], user_msg: str
    ) -> Optional[ChatDocument]:
        """
        Convert user_msg to final response.
        """
        if not user_msg:
            need_human_response = (
                isinstance(msg, ChatDocument) and msg.metadata.recipient == Entity.USER
            )
            user_msg = (
                (self.default_human_response or "null") if need_human_response else ""
            )
        user_msg = user_msg.strip()

        tool_ids = []
        if msg is not None and isinstance(msg, ChatDocument):
            tool_ids = msg.metadata.tool_ids

        # only return non-None result if user_msg not empty
        if not user_msg:
            return None
        else:
            if user_msg.startswith("SYSTEM"):
                user_msg = user_msg.replace("SYSTEM", "").strip()
                source = Entity.SYSTEM
                sender = Entity.SYSTEM
            else:
                source = Entity.USER
                sender = Entity.USER
            return ChatDocument(
                content=user_msg,
                metadata=ChatDocMetaData(
                    source=source,
                    sender=sender,
                    # preserve trail of tool_ids for OpenAI Assistant fn-calls
                    tool_ids=tool_ids,
                ),
            )

    async def user_response_async(
        self,
        msg: Optional[str | ChatDocument] = None,
    ) -> Optional[ChatDocument]:
        """
        Asynch version of `user_response`. See there for details.
        """
        if not self.user_can_respond(msg):
            return None

        if self.default_human_response is not None:
            user_msg = self.default_human_response
        else:
            if (
                self.callbacks.get_user_response_async is not None
                and self.callbacks.get_user_response_async is not async_noop_fn
            ):
                user_msg = await self.callbacks.get_user_response_async(prompt="")
            elif self.callbacks.get_user_response is not None:
                user_msg = self.callbacks.get_user_response(prompt="")
            else:
                user_msg = Prompt.ask(
                    f"[blue]{self.indent}"
                    + self.config.human_prompt
                    + f"\n{self.indent}"
                )

        return self._user_response_final(msg, user_msg)

    def user_response(
        self,
        msg: Optional[str | ChatDocument] = None,
    ) -> Optional[ChatDocument]:
        """
        Get user response to current message. Could allow (human) user to intervene
        with an actual answer, or quit using "q" or "x"

        Args:
            msg (str|ChatDocument): the string to respond to.

        Returns:
            (str) User response, packaged as a ChatDocument

        """

        if not self.user_can_respond(msg):
            return None

        if self.default_human_response is not None:
            user_msg = self.default_human_response
        else:
            if self.callbacks.get_user_response is not None:
                # ask user with empty prompt: no need for prompt
                # since user has seen the conversation so far.
                # But non-empty prompt can be useful when Agent
                # uses a tool that requires user input, or in other scenarios.
                user_msg = self.callbacks.get_user_response(prompt="")
            else:
                user_msg = Prompt.ask(
                    f"[blue]{self.indent}"
                    + self.config.human_prompt
                    + f"\n{self.indent}"
                )

        return self._user_response_final(msg, user_msg)
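
    # Illustrative sketch (not part of the package source): how the
    # user-response plumbing above typically behaves in a scripted,
    # non-interactive run; assumes Agent is instantiated directly.
    #
    #   agent = Agent(AgentConfig(name="scripted"))
    #   agent.interactive = False            # normally set by the Task wrapper
    #   agent.default_human_response = "ok"  # canned reply instead of Prompt.ask
    #   # user_response() now returns None unless the pending message is
    #   # explicitly addressed to Entity.USER (see user_can_respond above).
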
    @no_type_check
    def llm_can_respond(self, message: Optional[str | ChatDocument] = None) -> bool:
        """
        Whether the LLM can respond to a message.
        Args:
            message (str|ChatDocument): message or ChatDocument object to respond to.

        Returns:

        """
        if self.llm is None:
            return False

        if message is not None and len(self.try_get_tool_messages(message)) > 0:
            # if there is a valid "tool" message (either JSON or via `function_call`)
            # then LLM cannot respond to it
            return False

        return True

    def can_respond(self, message: Optional[str | ChatDocument] = None) -> bool:
        """
        Whether the agent can respond to a message.
        Used in Task.py to skip a sub-task when we know it would not respond.
        Args:
            message (str|ChatDocument): message or ChatDocument object to respond to.
        """
        tools = self.try_get_tool_messages(message)
        if len(tools) == 0 and self.config.respond_tools_only:
            return False
        if message is not None and self.has_only_unhandled_tools(message):
            # The message has tools that are NOT enabled to be handled by this agent,
            # which means the agent cannot respond to it.
            return False
        return True

    def create_llm_response(
        self,
        content: str | None = None,
        content_any: Any = None,
        tool_messages: List[ToolMessage] = [],
        oai_tool_calls: None | List[OpenAIToolCall] = None,
        oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
        oai_tool_id2result: OrderedDict[str, str] | None = None,
        function_call: LLMFunctionCall | None = None,
        recipient: str = "",
    ) -> ChatDocument:
        """Template for llm_response."""
        return self.response_template(
            Entity.LLM,
            content=content,
            content_any=content_any,
            tool_messages=tool_messages,
            oai_tool_calls=oai_tool_calls,
            oai_tool_choice=oai_tool_choice,
            oai_tool_id2result=oai_tool_id2result,
            function_call=function_call,
            recipient=recipient,
        )

    @no_type_check
    async def llm_response_async(
        self,
        message: Optional[str | ChatDocument] = None,
    ) -> Optional[ChatDocument]:
        """
        Asynch version of `llm_response`. See there for details.
        """
        if message is None or not self.llm_can_respond(message):
            return None

        if isinstance(message, ChatDocument):
            prompt = message.content
        else:
            prompt = message

        output_len = self.config.llm.max_output_tokens
        if self.num_tokens(prompt) + output_len > self.llm.completion_context_length():
            output_len = self.llm.completion_context_length() - self.num_tokens(prompt)
            if output_len < self.config.llm.min_output_tokens:
                raise ValueError(
                    """
                    Token-length of Prompt + Output is longer than the
                    completion context length of the LLM!
                    """
                )
            else:
                logger.warning(
                    f"""
                    Requested output length has been shortened to {output_len}
                    so that the total length of Prompt + Output is less than
                    the completion context length of the LLM.
                    """
                )

        with StreamingIfAllowed(self.llm, self.llm.get_stream()):
            response = await self.llm.agenerate(prompt, output_len)

        if not self.llm.get_stream() or response.cached and not settings.quiet:
            # We would have already displayed the msg "live" ONLY if
            # streaming was enabled, AND we did not find a cached response.
            # If we are here, it means the response has not yet been displayed.
            cached = f"[red]{self.indent}(cached)[/red]" if response.cached else ""
            print(cached + "[green]" + escape(response.message))
        async with self.lock:
            self.update_token_usage(
                response,
                prompt,
                self.llm.get_stream(),
                chat=False,  # i.e. it's a completion model not chat model
                print_response_stats=self.config.show_stats and not settings.quiet,
            )
        cdoc = ChatDocument.from_LLMResponse(response, displayed=True)
        # Preserve trail of tool_ids for OpenAI Assistant fn-calls
        cdoc.metadata.tool_ids = (
            [] if isinstance(message, str) else message.metadata.tool_ids
        )
        return cdoc

    @no_type_check
    def llm_response(
        self,
        message: Optional[str | ChatDocument] = None,
    ) -> Optional[ChatDocument]:
        """
        LLM response to a prompt.
        Args:
            message (str|ChatDocument): prompt string, or ChatDocument object

        Returns:
            Response from LLM, packaged as a ChatDocument
        """
        if message is None or not self.llm_can_respond(message):
            return None

        if isinstance(message, ChatDocument):
            prompt = message.content
        else:
            prompt = message

        with ExitStack() as stack:  # for conditionally using rich spinner
            if not self.llm.get_stream():
                # show rich spinner only if not streaming!
                cm = status("LLM responding to message...")
                stack.enter_context(cm)
            output_len = self.config.llm.max_output_tokens
            if (
                self.num_tokens(prompt) + output_len
                > self.llm.completion_context_length()
            ):
                output_len = self.llm.completion_context_length() - self.num_tokens(
                    prompt
                )
                if output_len < self.config.llm.min_output_tokens:
                    raise ValueError(
                        """
                        Token-length of Prompt + Output is longer than the
                        completion context length of the LLM!
                        """
                    )
                else:
                    logger.warning(
                        f"""
                        Requested output length has been shortened to {output_len}
                        so that the total length of Prompt + Output is less than
                        the completion context length of the LLM.
                        """
                    )
            if self.llm.get_stream() and not settings.quiet:
                console.print(f"[green]{self.indent}", end="")
            response = self.llm.generate(prompt, output_len)

        if not self.llm.get_stream() or response.cached and not settings.quiet:
            # we would have already displayed the msg "live" ONLY if
            # streaming was enabled, AND we did not find a cached response
            # If we are here, it means the response has not yet been displayed.
            cached = f"[red]{self.indent}(cached)[/red]" if response.cached else ""
            console.print(f"[green]{self.indent}", end="")
            print(cached + "[green]" + escape(response.message))
        self.update_token_usage(
            response,
            prompt,
            self.llm.get_stream(),
            chat=False,  # i.e. it's a completion model not chat model
            print_response_stats=self.config.show_stats and not settings.quiet,
        )
        cdoc = ChatDocument.from_LLMResponse(response, displayed=True)
        # Preserve trail of tool_ids for OpenAI Assistant fn-calls
        cdoc.metadata.tool_ids = (
            [] if isinstance(message, str) else message.metadata.tool_ids
        )
        return cdoc
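
    # Illustrative arithmetic (not part of the package source): the
    # output-length clipping in llm_response / llm_response_async above.
    # Assuming a completion context of 4096 tokens, a 4000-token prompt,
    # and max_output_tokens=500:
    #
    #   output_len = 4096 - 4000  # = 96
    #
    # 96 is then checked against config.llm.min_output_tokens: below that
    # threshold a ValueError is raised; otherwise a warning is logged and
    # generation proceeds with the shortened budget.
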
    def has_tool_message_attempt(self, msg: str | ChatDocument | None) -> bool:
        """
        Check whether msg contains a Tool/fn-call attempt (by the LLM).

        CAUTION: This uses self.get_tool_messages(msg) which as a side-effect
        may update msg.tool_messages when msg is a ChatDocument, if there are
        any tools in msg.
        """
        if msg is None:
            return False
        try:
            tools = self.get_tool_messages(msg)
            return len(tools) > 0
        except (ValidationError, XMLException):
            # there is a tool/fn-call attempt but had a validation error,
            # so we still consider this a tool message "attempt"
            return True
        return False

    def _tool_recipient_match(self, tool: ToolMessage) -> bool:
        """Is tool is handled by this agent
        and an explicit `recipient` field doesn't preclude this agent from handling it?
        """
        if tool.default_value("request") not in self.llm_tools_handled:
            return False
        if hasattr(tool, "recipient") and isinstance(tool.recipient, str):
            return tool.recipient == "" or tool.recipient == self.config.name
        return True

    def has_only_unhandled_tools(self, msg: str | ChatDocument) -> bool:
        """
        Does the msg have at least one tool, and ALL tools are
        disabled for handling by this agent?
        """
        if msg is None:
            return False
        tools = self.try_get_tool_messages(msg, all_tools=True)
        if len(tools) == 0:
            return False
        return all(not self._tool_recipient_match(t) for t in tools)

    def try_get_tool_messages(
        self,
        msg: str | ChatDocument | None,
        all_tools: bool = False,
    ) -> List[ToolMessage]:
        try:
            return self.get_tool_messages(msg, all_tools)
        except (ValidationError, XMLException):
            return []

    def get_tool_messages(
        self,
        msg: str | ChatDocument | None,
        all_tools: bool = False,
    ) -> List[ToolMessage]:
        """
        Get ToolMessages recognized in msg, handle-able by this agent.
        NOTE: as a side-effect, this will update msg.tool_messages
        when msg is a ChatDocument and msg contains tool messages.
        The intent here is that update=True should be set ONLY within agent_response()
        or agent_response_async() methods. In other words, we want to persist the
        msg.tool_messages only AFTER the agent has had a chance to handle the tools.

        Args:
            msg (str|ChatDocument): the message to extract tools from.
            all_tools (bool):
                - if True, return all tools,
                  i.e. any recognized tool in self.llm_tools_known,
                  whether it is handled by this agent or not;
                - otherwise, return only the tools handled by this agent.

        Returns:
            List[ToolMessage]: list of ToolMessage objects
        """

        if msg is None:
            return []

        if isinstance(msg, str):
            json_tools = self.get_formatted_tool_messages(msg)
            if all_tools:
                return json_tools
            else:
                return [
                    t
                    for t in json_tools
                    if self._tool_recipient_match(t) and t.default_value("request")
                ]

        if all_tools and len(msg.all_tool_messages) > 0:
            # We've already identified all_tool_messages in the msg;
            # return the corresponding ToolMessage objects
            return msg.all_tool_messages
        if len(msg.tool_messages) > 0:
            # We've already found tool_messages,
            # (either via OpenAI Fn-call or Langroid-native ToolMessage);
            # or they were added by an agent_response.
            # note these could be from a forwarded msg from another agent,
            # so return ONLY the messages THIS agent to enabled to handle.
            if all_tools:
                return msg.tool_messages
            return [t for t in msg.tool_messages if self._tool_recipient_match(t)]
        assert isinstance(msg, ChatDocument)
        if (
            msg.content != ""
            and msg.oai_tool_calls is None
            and msg.function_call is None
        ):

            tools = self.get_formatted_tool_messages(msg.content)
            msg.all_tool_messages = tools
            # filter for actually handle-able tools, and recipient is this agent
            my_tools = [t for t in tools if self._tool_recipient_match(t)]
            msg.tool_messages = my_tools

            if all_tools:
                return tools
            else:
                return my_tools

        # otherwise, we look for `tool_calls` (possibly multiple)
        tools = self.get_oai_tool_calls_classes(msg)
        msg.all_tool_messages = tools
        my_tools = [t for t in tools if self._tool_recipient_match(t)]
        msg.tool_messages = my_tools

        if len(tools) == 0:
            # otherwise, we look for a `function_call`
            fun_call_cls = self.get_function_call_class(msg)
            tools = [fun_call_cls] if fun_call_cls is not None else []
            msg.all_tool_messages = tools
            my_tools = [t for t in tools if self._tool_recipient_match(t)]
            msg.tool_messages = my_tools
        if all_tools:
            return tools
        else:
            return my_tools
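
    # Illustrative sketch (not part of the package source): the kind of raw
    # LLM output that get_formatted_tool_messages (below) extracts tools
    # from -- a top-level JSON tool message embedded in plain text (the tool
    # name "nabroski" is hypothetical).
    #
    #   llm_output = 'Sure, I will use the tool: {"request": "nabroski", "num": 5}'
    #   tools = agent.get_formatted_tool_messages(llm_output)
    #   # extract_top_level_json finds the {...} substring, and
    #   # _get_one_tool_message parses it into the registered ToolMessage class.
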
def get_formatted_tool_messages(self, input_str: str) -> List[ToolMessage]:
|
1172
|
+
"""
|
1173
|
+
Returns ToolMessage objects (tools) corresponding to
|
1174
|
+
tool-formatted substrings, if any.
|
1175
|
+
ASSUMPTION - These tools are either ALL JSON-based, or ALL XML-based
|
1176
|
+
(i.e. not a mix of both).
|
1177
|
+
Terminology: a "formatted tool msg" is one which the LLM generates as
|
1178
|
+
part of its raw string output, rather than within a JSON object
|
1179
|
+
in the API response (i.e. this method does not extract tools/fns returned
|
1180
|
+
by OpenAI's tools/fns API or similar APIs).
|
1181
|
+
|
1182
|
+
Args:
|
1183
|
+
input_str (str): input string, typically a message sent by an LLM
|
1184
|
+
|
1185
|
+
Returns:
|
1186
|
+
List[ToolMessage]: list of ToolMessage objects
|
1187
|
+
"""
|
1188
|
+
self.tool_error = False
|
1189
|
+
substrings = XMLToolMessage.find_candidates(input_str)
|
1190
|
+
is_json = False
|
1191
|
+
if len(substrings) == 0:
|
1192
|
+
substrings = extract_top_level_json(input_str)
|
1193
|
+
is_json = len(substrings) > 0
|
1194
|
+
if not is_json:
|
1195
|
+
return []
|
1196
|
+
|
1197
|
+
results = [self._get_one_tool_message(j, is_json) for j in substrings]
|
1198
|
+
valid_results = [r for r in results if r is not None]
|
1199
|
+
# If any tool is correctly formed we do not set the flag
|
1200
|
+
if len(valid_results) > 0:
|
1201
|
+
self.tool_error = False
|
1202
|
+
return valid_results
|
1203
|
+
|
1204
|
+
def get_function_call_class(self, msg: ChatDocument) -> Optional[ToolMessage]:
|
1205
|
+
"""
|
1206
|
+
From ChatDocument (constructed from an LLM Response), get the `ToolMessage`
|
1207
|
+
corresponding to the `function_call` if it exists.
|
1208
|
+
"""
|
1209
|
+
if msg.function_call is None:
|
1210
|
+
return None
|
1211
|
+
tool_name = msg.function_call.name
|
1212
|
+
tool_msg = msg.function_call.arguments or {}
|
1213
|
+
if tool_name not in self.llm_tools_handled:
|
1214
|
+
logger.warning(
|
1215
|
+
f"""
|
1216
|
+
The function_call '{tool_name}' is not handled
|
1217
|
+
by the agent named '{self.config.name}'!
|
1218
|
+
If you intended this agent to handle this function_call,
|
1219
|
+
either the fn-call name is incorrectly generated by the LLM,
|
1220
|
+
(in which case you may need to adjust your LLM instructions),
|
1221
|
+
or you need to enable this agent to handle this fn-call.
|
1222
|
+
"""
|
1223
|
+
)
|
1224
|
+
if tool_name not in self.all_llm_tools_known:
|
1225
|
+
self.tool_error = True
|
1226
|
+
return None
|
1227
|
+
self.tool_error = False
|
1228
|
+
tool_class = self.llm_tools_map[tool_name]
|
1229
|
+
tool_msg.update(dict(request=tool_name))
|
1230
|
+
tool = tool_class.parse_obj(tool_msg)
|
1231
|
+
return tool
|
1232
|
+
|
1233
|
+
def get_oai_tool_calls_classes(self, msg: ChatDocument) -> List[ToolMessage]:
|
1234
|
+
"""
|
1235
|
+
From ChatDocument (constructed from an LLM Response), get
|
1236
|
+
a list of ToolMessages corresponding to the `tool_calls`, if any.
|
1237
|
+
"""
|
1238
|
+
|
1239
|
+
if msg.oai_tool_calls is None:
|
1240
|
+
return []
|
1241
|
+
tools = []
|
1242
|
+
all_errors = True
|
1243
|
+
for tc in msg.oai_tool_calls:
|
1244
|
+
if tc.function is None:
|
1245
|
+
continue
|
1246
|
+
tool_name = tc.function.name
|
1247
|
+
tool_msg = tc.function.arguments or {}
|
1248
|
+
if tool_name not in self.llm_tools_handled:
|
1249
|
+
logger.warning(
|
1250
|
+
f"""
|
1251
|
+
The tool_call '{tool_name}' is not handled
|
1252
|
+
by the agent named '{self.config.name}'!
|
1253
|
+
If you intended this agent to handle this function_call,
|
1254
|
+
either the fn-call name is incorrectly generated by the LLM,
|
1255
|
+
(in which case you may need to adjust your LLM instructions),
|
1256
|
+
or you need to enable this agent to handle this fn-call.
|
1257
|
+
"""
|
1258
|
+
)
|
1259
|
+
continue
|
1260
|
+
all_errors = False
|
1261
|
+
tool_class = self.llm_tools_map[tool_name]
|
1262
|
+
tool_msg.update(dict(request=tool_name))
|
1263
|
+
tool = tool_class.parse_obj(tool_msg)
|
1264
|
+
tool.id = tc.id or ""
|
1265
|
+
tools.append(tool)
|
1266
|
+
# When no tool is valid, set the recovery flag
|
1267
|
+
self.tool_error = all_errors
|
1268
|
+
return tools
|
1269
|
+
|
1270
|
+
+    def tool_validation_error(self, ve: ValidationError) -> str:
+        """
+        Handle a validation error raised when parsing a tool message,
+        i.e. when a legitimate tool name is used but it has missing/bad fields.
+
+        Args:
+            ve (ValidationError): The exception raised when parsing the
+                tool message
+
+        Returns:
+            str: The error message to send back to the LLM
+        """
+        tool_name = cast(ToolMessage, ve.model).default_value("request")
+        bad_field_errors = "\n".join(
+            [f"{e['loc']}: {e['msg']}" for e in ve.errors() if "loc" in e]
+        )
+        return f"""
+        There were one or more errors in your attempt to use the
+        TOOL or function_call named '{tool_name}':
+        {bad_field_errors}
+        Please write your message again, correcting the errors.
+        """
+
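
To make the round trip concrete, a sketch continuing the hypothetical
SquareTool example above; `langroid.pydantic_v1` is langroid's pydantic-v1
compatibility layer (treat the exact import path as an assumption):

    from langroid.pydantic_v1 import ValidationError  # assumed pydantic-v1 shim

    try:
        SquareTool.parse_obj({"request": "square"})  # required "number" missing
    except ValidationError as ve:
        bad = "\n".join(f"{e['loc']}: {e['msg']}" for e in ve.errors() if "loc" in e)
        print(bad)  # roughly: ('number',): field required
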
+    def _get_multiple_orch_tool_errs(
+        self, tools: List[ToolMessage]
+    ) -> List[str | ChatDocument | None]:
+        """
+        Return a list of error strings (one per tool) if the message contains
+        multiple tools including at least one orchestration tool;
+        otherwise return an empty list.
+        """
+        # Check whether there are multiple tools of which at least one is an
+        # orchestration tool (e.g. DoneTool etc.), in which case set each result
+        # to an error string, since we don't yet support multi-tool messages
+        # containing one or more orch tools.
+        from langroid.agent.tools.orchestration import (
+            AgentDoneTool,
+            AgentSendTool,
+            DonePassTool,
+            DoneTool,
+            ForwardTool,
+            PassTool,
+            SendTool,
+        )
+        from langroid.agent.tools.recipient_tool import RecipientTool
+
+        ORCHESTRATION_TOOLS = (
+            AgentDoneTool,
+            DoneTool,
+            PassTool,
+            DonePassTool,
+            ForwardTool,
+            RecipientTool,
+            SendTool,
+            AgentSendTool,
+        )
+
+        has_orch = any(isinstance(t, ORCHESTRATION_TOOLS) for t in tools)
+        if has_orch and len(tools) > 1:
+            err_str = "ERROR: Use ONE tool at a time!"
+            return [err_str for _ in tools]
+
+        return []
+
+    def _handle_message_final(
+        self, tools: List[ToolMessage], results: List[str | ChatDocument | None]
+    ) -> None | str | OrderedDict[str, str] | ChatDocument:
+        """
+        Convert the per-tool results into the final response of `handle_message`.
+        """
+        # extract content from ChatDocument results so we have all str|None
+        results = [r.content if isinstance(r, ChatDocument) else r for r in results]
+
+        tool_names = [t.default_value("request") for t in tools]
+
+        has_ids = all([t.id != "" for t in tools])
+        if has_ids:
+            id2result = OrderedDict(
+                (t.id, r)
+                for t, r in zip(tools, results)
+                if r is not None and isinstance(r, str)
+            )
+            result_values = list(id2result.values())
+            if len(id2result) > 1 and any(
+                orch_str in r
+                for r in result_values
+                for orch_str in ORCHESTRATION_STRINGS
+            ):
+                # Cannot support multi-tool results containing orchestration strings!
+                # Replace results with an error string to force the LLM to retry.
+                err_str = "ERROR: Please use ONE tool at a time!"
+                id2result = OrderedDict((id, err_str) for id in id2result.keys())
+
+        name_results_list = [
+            (name, r) for name, r in zip(tool_names, results) if r is not None
+        ]
+        if len(name_results_list) == 0:
+            return None
+
+        # there was a non-None result
+
+        if has_ids and len(id2result) > 1:
+            # if there are multiple OpenAI Tool results, return them as a dict
+            return id2result
+
+        # multi-results: prepend the tool name to each result
+        str_results = [f"Result from {name}: {r}" for name, r in name_results_list]
+        final = "\n\n".join(str_results)
+        return final
+
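
A small sketch of the two output shapes (tool names and results are made up):

    # Two OpenAI tool-calls with ids -> an id->result OrderedDict, e.g.:
    #   OrderedDict([("call_1", "81"), ("call_2", "729")])
    # Two langroid-native tool results -> one combined string:
    name_results_list = [("square", "81"), ("cube", "729")]
    final = "\n\n".join(f"Result from {name}: {r}" for name, r in name_results_list)
    # -> "Result from square: 81\n\nResult from cube: 729"
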
+    async def handle_message_async(
+        self, msg: str | ChatDocument
+    ) -> None | str | OrderedDict[str, str] | ChatDocument:
+        """
+        Async version of `handle_message`. See there for details.
+        """
+        try:
+            tools = self.get_tool_messages(msg)
+            tools = [t for t in tools if self._tool_recipient_match(t)]
+        except ValidationError as ve:
+            # correct tool name but bad fields
+            return self.tool_validation_error(ve)
+        except XMLException as xe:  # from XMLToolMessage parsing
+            return str(xe)
+        except ValueError:
+            # invalid tool name
+            # We return None since returning "invalid tool name" would
+            # be considered a valid result in the task loop, and would be treated
+            # as a response to the tool message even though the tool was not
+            # intended for this agent.
+            return None
+        if len(tools) > 1 and not self.config.allow_multiple_tools:
+            return self.to_ChatDocument("ERROR: Use ONE tool at a time!")
+        if len(tools) == 0:
+            fallback_result = self.handle_message_fallback(msg)
+            if fallback_result is None:
+                return None
+            return self.to_ChatDocument(
+                fallback_result,
+                chat_doc=msg if isinstance(msg, ChatDocument) else None,
+            )
+        chat_doc = msg if isinstance(msg, ChatDocument) else None
+
+        results = self._get_multiple_orch_tool_errs(tools)
+        if not results:
+            results = [
+                await self.handle_tool_message_async(t, chat_doc=chat_doc)
+                for t in tools
+            ]
+        # if there's a solitary ChatDocument|str result, return it as is
+        if len(results) == 1 and isinstance(results[0], (str, ChatDocument)):
+            return results[0]
+
+        return self._handle_message_final(tools, results)
+
+    def handle_message(
+        self, msg: str | ChatDocument
+    ) -> None | str | OrderedDict[str, str] | ChatDocument:
+        """
+        Handle a "tool" message, which is either a string containing one or more
+        valid "tool" JSON substrings, or a ChatDocument containing a
+        `function_call` attribute.
+        Handle with the corresponding handler method, and return
+        the results as a combined string.
+
+        Args:
+            msg (str | ChatDocument): The string or ChatDocument to handle
+
+        Returns:
+            The result of the handler method can be:
+            - None if no tools were successfully handled, or no tools are present
+            - str if langroid-native JSON tools were handled and the results
+              concatenated, OR there's a SINGLE OpenAI tool-call.
+              (We do this so the common scenario of a single tool/fn-call
+              has simple behavior.)
+            - Dict[str, str] if multiple OpenAI tool-calls were handled
+              (the dict is an id->result map)
+            - ChatDocument if a handler returned a ChatDocument, intended to be
+              the final response of the `agent_response` method.
+        """
+        try:
+            tools = self.get_tool_messages(msg)
+            tools = [t for t in tools if self._tool_recipient_match(t)]
+        except ValidationError as ve:
+            # correct tool name but bad fields
+            return self.tool_validation_error(ve)
+        except XMLException as xe:  # from XMLToolMessage parsing
+            return str(xe)
+        except ValueError:
+            # invalid tool name
+            # We return None since returning "invalid tool name" would
+            # be considered a valid result in the task loop, and would be treated
+            # as a response to the tool message even though the tool was not
+            # intended for this agent.
+            return None
+        if len(tools) > 1 and not self.config.allow_multiple_tools:
+            return self.to_ChatDocument("ERROR: Use ONE tool at a time!")
+        if len(tools) == 0:
+            fallback_result = self.handle_message_fallback(msg)
+            if fallback_result is None:
+                return None
+            return self.to_ChatDocument(
+                fallback_result,
+                chat_doc=msg if isinstance(msg, ChatDocument) else None,
+            )
+        chat_doc = msg if isinstance(msg, ChatDocument) else None
+
+        results = self._get_multiple_orch_tool_errs(tools)
+        if not results:
+            results = [self.handle_tool_message(t, chat_doc=chat_doc) for t in tools]
+        # if there's a solitary ChatDocument|str result, return it as is
+        if len(results) == 1 and isinstance(results[0], (str, ChatDocument)):
+            return results[0]
+
+        return self._handle_message_final(tools, results)
+
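
A usage sketch, continuing the hypothetical SquareTool example (assumptions:
`llm=None` is accepted by the config for a tool-handling-only agent, and a
tool's `handle` method is bound as its handler when enabled):

    import langroid as lr
    from langroid.agent.tool_message import ToolMessage

    class SquareTool(ToolMessage):  # same hypothetical tool, now with a handler
        request: str = "square"
        purpose: str = "To compute the square of a <number>"
        number: int

        def handle(self) -> str:
            # the handler's return value becomes the tool result
            return str(self.number**2)

    agent = lr.ChatAgent(lr.ChatAgentConfig(name="calc", llm=None))
    agent.enable_message(SquareTool)

    # One valid tool => the solitary result is returned as-is; per
    # to_ChatDocument below, the handler's string is wrapped in a ChatDocument.
    result = agent.handle_message('{"request": "square", "number": 9}')
    assert result is not None and result.content == "81"
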
+    @property
+    def all_llm_tools_known(self) -> set[str]:
+        """All known tools; this may extend self.llm_tools_known."""
+        return self.llm_tools_known
+
+    def handle_message_fallback(self, msg: str | ChatDocument) -> Any:
+        """
+        Fallback method for the "no-tools" scenario.
+        This method can be overridden by subclasses, e.g.,
+        to create a "reminder" message when a tool is expected but the LLM
+        "forgot" to generate one.
+
+        Args:
+            msg (str | ChatDocument): The input msg to handle
+        Returns:
+            Any: The result of the handler method
+        """
+        return None
+
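
A typical override, sketched (the sender check is an assumption about how one
would usually scope the reminder to LLM-authored messages):

    from langroid.agent.chat_document import ChatDocument
    from langroid.mytypes import Entity

    class CalcAgent(lr.ChatAgent):  # hypothetical subclass
        def handle_message_fallback(self, msg: str | ChatDocument) -> str | None:
            # remind the LLM only when *it* produced a tool-less message
            if isinstance(msg, ChatDocument) and msg.metadata.sender == Entity.LLM:
                return "You FORGOT to use one of your TOOLs! Please retry."
            return None
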
+    def _get_one_tool_message(
+        self, tool_candidate_str: str, is_json: bool = True
+    ) -> Optional[ToolMessage]:
+        """
+        Parse tool_candidate_str into ANY ToolMessage KNOWN to the agent,
+        i.e. any tool in self.all_llm_tools_known, including tools that are
+        not currently used/handled. The exception is below, where we try our
+        best to infer the tool when the LLM has "forgotten" to include the
+        "request" field in the tool string; in that case we ONLY consider
+        the set of HANDLED tools, i.e. self.llm_tools_handled.
+        """
+        if is_json:
+            maybe_tool_dict = json.loads(tool_candidate_str)
+        else:
+            try:
+                maybe_tool_dict = XMLToolMessage.extract_field_values(
+                    tool_candidate_str
+                )
+            except Exception as e:
+                from langroid.exceptions import XMLException
+
+                raise XMLException(f"Error extracting XML fields:\n {str(e)}")
+        # Check if maybe_tool_dict contains a "properties" field, which further
+        # contains the actual tool-call (some weak LLMs do this).
+        # E.g. gpt-4o sometimes generates this:
+        # TOOL: {
+        #     "type": "object",
+        #     "properties": {
+        #         "request": "square",
+        #         "number": 9
+        #     },
+        #     "required": [
+        #         "number",
+        #         "request"
+        #     ]
+        # }
+
+        if not isinstance(maybe_tool_dict, dict):
+            self.tool_error = True
+            return None
+
+        properties = maybe_tool_dict.get("properties")
+        if isinstance(properties, dict):
+            maybe_tool_dict = properties
+        request = maybe_tool_dict.get("request")
+        if request is None:
+            if self.enabled_requests_for_inference is None:
+                possible = [self.llm_tools_map[r] for r in self.llm_tools_handled]
+            else:
+                allowable = self.enabled_requests_for_inference.intersection(
+                    self.llm_tools_handled
+                )
+                possible = [self.llm_tools_map[r] for r in allowable]
+
+            default_keys = set(ToolMessage.__fields__.keys())
+            request_keys = set(maybe_tool_dict.keys())
+
+            def maybe_parse(tool: type[ToolMessage]) -> Optional[ToolMessage]:
+                all_keys = set(tool.__fields__.keys())
+                non_inherited_keys = all_keys.difference(default_keys)
+                # If the request has any keys not valid for the tool, or
+                # does not specify some key specific to this tool type
+                # (i.e. beyond just `purpose`), we cannot infer this tool;
+                # the LLM must explicitly specify `request`.
+                if not (
+                    request_keys.issubset(all_keys)
+                    and len(request_keys.intersection(non_inherited_keys)) > 0
+                ):
+                    return None
+
+                try:
+                    return tool.parse_obj(maybe_tool_dict)
+                except ValidationError:
+                    return None
+
+            candidate_tools = list(
+                filter(
+                    lambda t: t is not None,
+                    map(maybe_parse, possible),
+                )
+            )
+
+            # If exactly one valid candidate exists, we infer
+            # "request" to be the only possible value.
+            if len(candidate_tools) == 1:
+                return candidate_tools[0]
+            else:
+                self.tool_error = True
+                return None
+
+        if not isinstance(request, str) or request not in self.all_llm_tools_known:
+            self.tool_error = True
+            return None
+
+        message_class = self.llm_tools_map.get(request)
+        if message_class is None:
+            logger.warning(f"No message class found for request '{request}'")
+            self.tool_error = True
+            return None
+
+        try:
+            message = message_class.parse_obj(maybe_tool_dict)
+        except ValidationError as ve:
+            self.tool_error = True
+            raise ve
+        return message
+
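
A sketch of the request-inference branch, continuing the hypothetical
SquareTool agent from earlier. With `{"number": 9}` and SquareTool as the only
handled tool owning a `number` field, the tool can be inferred unambiguously
(note this is a private method; the direct call is illustrative only):

    # LLM "forgot" the request field:
    tool = agent._get_one_tool_message('{"number": 9}')
    assert isinstance(tool, SquareTool) and tool.number == 9
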
+    def to_ChatDocument(
+        self,
+        msg: Any,
+        orig_tool_name: str | None = None,
+        chat_doc: Optional[ChatDocument] = None,
+        author_entity: Entity = Entity.AGENT,
+    ) -> Optional[ChatDocument]:
+        """
+        Convert the result of a responder (agent_response or llm_response, or
+        task.run()), or of a tool handler, or of handle_message_fallback,
+        to a ChatDocument, to enable handling by other
+        responders/tasks in a task loop possibly involving multiple agents.
+
+        Args:
+            msg (Any): The result of a responder or tool handler or task.run()
+            orig_tool_name (str): The original tool name that generated the
+                response, if any.
+            chat_doc (ChatDocument): The original ChatDocument object that `msg`
+                is a response to.
+            author_entity (Entity): The intended author of the result ChatDocument
+        """
+        if msg is None or isinstance(msg, ChatDocument):
+            return msg
+
+        is_agent_author = author_entity == Entity.AGENT
+
+        if isinstance(msg, str):
+            return self.response_template(author_entity, content=msg, content_any=msg)
+        elif isinstance(msg, ToolMessage):
+            # result is a ToolMessage
+            result_tool_name = msg.default_value("request")
+            if (
+                is_agent_author
+                and result_tool_name in self.llm_tools_handled
+                and (orig_tool_name is None or orig_tool_name != result_tool_name)
+            ):
+                # TODO: do we need to remove the tool message from the chat_doc?
+                # if (chat_doc is not None and
+                #         msg in chat_doc.tool_messages):
+                #     chat_doc.tool_messages.remove(msg)
+                # if we can handle it, do so
+                result = self.handle_tool_message(msg, chat_doc=chat_doc)
+                if result is not None and isinstance(result, ChatDocument):
+                    return result
+            else:
+                # else wrap it in an agent response and return it so the
+                # orchestrator can find a respondent
+                return self.response_template(author_entity, tool_messages=[msg])
+        else:
+            result = to_string(msg)
+
+        return (
+            None
+            if result is None
+            else self.response_template(author_entity, content=result, content_any=msg)
+        )
+
+    def from_ChatDocument(self, msg: ChatDocument, output_type: Type[T]) -> Optional[T]:
+        """
+        Extract a desired output_type from a ChatDocument object.
+        We use this fallback order:
+        - if output_type is str and `msg.content` is non-empty, return it
+        - if `msg.content_any` exists and matches the output_type, return it
+        - if output_type is a ToolMessage, return the first tool of that
+          specific type in `msg.tool_messages`
+        - if output_type is a list of ToolMessage,
+          return all matching tools in `msg.tool_messages`
+        - if output_type is a primitive type, attempt to parse it from
+          `msg.content`
+        - search for a tool in `msg.tool_messages` that has a field of
+          output_type, and if found, return that field value
+        - return None if all of the above fail
+        """
+        content = msg.content
+        if output_type is str and content != "":
+            return cast(T, content)
+        content_any = msg.content_any
+        if content_any is not None and isinstance(content_any, output_type):
+            return cast(T, content_any)
+
+        tools = self.try_get_tool_messages(msg, all_tools=True)
+
+        if get_origin(output_type) is list:
+            list_element_type = get_args(output_type)[0]
+            if issubclass(list_element_type, ToolMessage):
+                # list_element_type is a subclass of ToolMessage:
+                # we output a list of objects derived from list_element_type
+                return cast(
+                    T,
+                    [t for t in tools if isinstance(t, list_element_type)],
+                )
+        elif get_origin(output_type) is None and issubclass(output_type, ToolMessage):
+            # output_type is a subclass of ToolMessage:
+            # return the first tool that has this specific output_type
+            for tool in tools:
+                if isinstance(tool, output_type):
+                    return cast(T, tool)
+            return None
+        elif get_origin(output_type) is None and output_type in (str, int, float, bool):
+            # attempt to get the output_type from the content,
+            # if it's a primitive type
+            primitive_value = from_string(content, output_type)  # type: ignore
+            if primitive_value is not None:
+                return cast(T, primitive_value)
+
+        # then search for output_type as a field in a tool
+        for tool in tools:
+            value = tool.get_value_of_type(output_type)
+            if value is not None:
+                return cast(T, value)
+        return None
+
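
A sketch of typed extraction; `response_template` is used here just to build a
ChatDocument by hand (the keyword arguments are assumptions based on its use
earlier in this file):

    from langroid.mytypes import Entity

    doc = agent.response_template(Entity.AGENT, content="81")
    assert agent.from_ChatDocument(doc, str) == "81"
    assert agent.from_ChatDocument(doc, int) == 81  # parsed from doc.content
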
+    def _maybe_truncate_result(
+        self, result: str | ChatDocument | None, max_tokens: int | None
+    ) -> str | ChatDocument | None:
+        """
+        Truncate the result string to `max_tokens` tokens.
+        """
+        if result is None or max_tokens is None:
+            return result
+        result_str = result.content if isinstance(result, ChatDocument) else result
+        num_tokens = (
+            self.parser.num_tokens(result_str)
+            if self.parser is not None
+            else len(result_str) / 4.0
+        )
+        if num_tokens <= max_tokens:
+            return result
+        truncate_warning = f"""
+        The TOOL result was large, so it was truncated to {max_tokens} tokens.
+        To get the full result, the TOOL must be called again.
+        """
+        if isinstance(result, str):
+            return (
+                self.parser.truncate_tokens(result, max_tokens)
+                if self.parser is not None
+                else result[: max_tokens * 4]  # approx truncation
+            ) + truncate_warning
+        elif isinstance(result, ChatDocument):
+            result.content = (
+                self.parser.truncate_tokens(result.content, max_tokens)
+                if self.parser is not None
+                else result.content[: max_tokens * 4]  # approx truncation
+            ) + truncate_warning
+        return result
+
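
The no-parser fallback uses a rough 4-characters-per-token heuristic; a worked
instance of the arithmetic:

    result_str = "x" * 10_000
    max_tokens = 500
    approx_tokens = len(result_str) / 4.0     # 2500.0 > 500, so truncate
    truncated = result_str[: max_tokens * 4]  # keep ~500 tokens ~= 2000 chars
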
+    async def handle_tool_message_async(
+        self,
+        tool: ToolMessage,
+        chat_doc: Optional[ChatDocument] = None,
+    ) -> None | str | ChatDocument:
+        """
+        Async version of `handle_tool_message`. See there for details.
+        """
+        tool_name = tool.default_value("request")
+        if hasattr(tool, "_handler"):
+            handler_name = getattr(tool, "_handler", tool_name)
+        else:
+            handler_name = tool_name
+        handler_method = getattr(self, handler_name + "_async", None)
+        if handler_method is None:
+            return self.handle_tool_message(tool, chat_doc=chat_doc)
+        has_chat_doc_arg = (
+            chat_doc is not None
+            and "chat_doc" in inspect.signature(handler_method).parameters
+        )
+        try:
+            if has_chat_doc_arg:
+                maybe_result = await handler_method(tool, chat_doc=chat_doc)
+            else:
+                maybe_result = await handler_method(tool)
+            result = self.to_ChatDocument(maybe_result, tool_name, chat_doc)
+        except Exception as e:
+            # raise the error here since we are sure it's
+            # not a pydantic validation error,
+            # which we check in `handle_message`
+            raise e
+        return self._maybe_truncate_result(
+            result, tool._max_result_tokens
+        )  # type: ignore
+
+    def handle_tool_message(
+        self,
+        tool: ToolMessage,
+        chat_doc: Optional[ChatDocument] = None,
+    ) -> None | str | ChatDocument:
+        """
+        Respond to a tool request from the LLM, in the form of a ToolMessage
+        object.
+
+        Args:
+            tool: ToolMessage object representing the tool request.
+            chat_doc: Optional ChatDocument object containing the tool request.
+                This is passed to the tool-handler method only if it has a
+                `chat_doc` argument.
+
+        Returns:
+            The result of the tool handler (None, str, or ChatDocument),
+            possibly truncated to the tool's `_max_result_tokens`.
+        """
+        tool_name = tool.default_value("request")
+        if hasattr(tool, "_handler"):
+            handler_name = getattr(tool, "_handler", tool_name)
+        else:
+            handler_name = tool_name
+        handler_method = getattr(self, handler_name, None)
+        if handler_method is None:
+            return None
+        has_chat_doc_arg = (
+            chat_doc is not None
+            and "chat_doc" in inspect.signature(handler_method).parameters
+        )
+        try:
+            if has_chat_doc_arg:
+                maybe_result = handler_method(tool, chat_doc=chat_doc)
+            else:
+                maybe_result = handler_method(tool)
+            result = self.to_ChatDocument(maybe_result, tool_name, chat_doc)
+        except Exception as e:
+            # raise the error here since we are sure it's
+            # not a pydantic validation error,
+            # which we check in `handle_message`
+            raise e
+        return self._maybe_truncate_result(
+            result, tool._max_result_tokens
+        )  # type: ignore
+
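
Handler resolution, sketched: by default the handler is the agent method named
after the tool's `request` field (or after the tool's `_handler` attribute when
that is set). An agent-side handler for the hypothetical SquareTool could look
like:

    class SquareHandlerAgent(lr.ChatAgent):  # hypothetical
        def square(self, msg: SquareTool) -> str:
            # dispatched because SquareTool.request == "square"
            return str(msg.number**2)
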
+    def num_tokens(self, prompt: str | List[LLMMessage]) -> int:
+        if self.parser is None:
+            raise ValueError("Parser must be set, to count tokens")
+        if isinstance(prompt, str):
+            return self.parser.num_tokens(prompt)
+        else:
+            return sum(
+                [
+                    self.parser.num_tokens(m.content)
+                    + self.parser.num_tokens(str(m.function_call or ""))
+                    for m in prompt
+                ]
+            )
+
+    def _get_response_stats(
+        self, chat_length: int, tot_cost: float, response: LLMResponse
+    ) -> str:
+        """
+        Get LLM response stats as a string.
+
+        Args:
+            chat_length (int): number of messages in the chat
+            tot_cost (float): total cost of the chat so far
+            response (LLMResponse): LLMResponse object
+        """
+
+        if self.config.llm is None:
+            logger.warning("LLM config is None, cannot get response stats")
+            return ""
+        if response.usage:
+            in_tokens = response.usage.prompt_tokens
+            out_tokens = response.usage.completion_tokens
+            llm_response_cost = format(response.usage.cost, ".4f")
+            cumul_cost = format(tot_cost, ".4f")
+            assert isinstance(self.llm, LanguageModel)
+            context_length = self.llm.chat_context_length()
+            max_out = self.config.llm.max_output_tokens
+
+            llm_model = (
+                "no-LLM" if self.config.llm is None else self.llm.config.chat_model
+            )
+            # total cost across all LLMs, agents
+            all_cost = format(self.llm.tot_tokens_cost()[1], ".4f")
+            return (
+                f"[bold]Stats:[/bold] [magenta]N_MSG={chat_length}, "
+                f"TOKENS: in={in_tokens}, out={out_tokens}, "
+                f"max={max_out}, ctx={context_length}, "
+                f"COST: now=${llm_response_cost}, cumul=${cumul_cost}, "
+                f"tot=${all_cost} "
+                f"[bold]({llm_model})[/bold][/magenta]"
+            )
+        return ""
+
+    def update_token_usage(
+        self,
+        response: LLMResponse,
+        prompt: str | List[LLMMessage],
+        stream: bool,
+        chat: bool = True,
+        print_response_stats: bool = True,
+    ) -> None:
+        """
+        Update the `response.usage` object (token usage and cost fields).
+        The cost is updated after checking the cache, and the token counts
+        (prompt and completion) are computed here when the response was
+        streamed, because in streaming mode OpenAI does not return these
+        fields.
+
+        Args:
+            response (LLMResponse): LLMResponse object
+            prompt (str | List[LLMMessage]): prompt or list of LLMMessage objects
+            stream (bool): whether the response was streamed (in which case the
+                usage is computed here, unless the response was cached)
+            chat (bool): whether this is a chat model or a completion model
+            print_response_stats (bool): whether to print the response stats
+        """
+        if response is None or self.llm is None:
+            return
+
+        # Note: if the response was not streamed, then
+        # `response.usage` would already have been set by the API,
+        # so we only need to update in the stream case.
+        if stream:
+            # usage, cost = 0 when response is from cache
+            prompt_tokens = 0
+            completion_tokens = 0
+            cost = 0.0
+            if not response.cached:
+                prompt_tokens = self.num_tokens(prompt)
+                completion_tokens = self.num_tokens(response.message)
+                if response.function_call is not None:
+                    completion_tokens += self.num_tokens(str(response.function_call))
+                cost = self.compute_token_cost(prompt_tokens, completion_tokens)
+            response.usage = LLMTokenUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                cost=cost,
+            )
+
+        # update total counters
+        if response.usage is not None:
+            self.total_llm_token_cost += response.usage.cost
+            self.total_llm_token_usage += response.usage.total_tokens
+            self.llm.update_usage_cost(
+                chat,
+                response.usage.prompt_tokens,
+                response.usage.completion_tokens,
+                response.usage.cost,
+            )
+            chat_length = 1 if isinstance(prompt, str) else len(prompt)
+            self.token_stats_str = self._get_response_stats(
+                chat_length, self.total_llm_token_cost, response
+            )
+            if print_response_stats:
+                print(self.indent + self.token_stats_str)
+
+    def compute_token_cost(self, prompt: int, completion: int) -> float:
+        price = cast(LanguageModel, self.llm).chat_cost()
+        return (price[0] * prompt + price[1] * completion) / 1000
+
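
A worked instance of the formula, assuming `chat_cost()` returns per-1000-token
(prompt, completion) prices in dollars (the numbers below are made up):

    price = (0.001, 0.002)  # assumed $/1K tokens for (prompt, completion)
    prompt_toks, completion_toks = 1500, 300
    cost = (price[0] * prompt_toks + price[1] * completion_toks) / 1000
    # = (1.5 + 0.6) / 1000 = 0.0021 dollars
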
+    def ask_agent(
+        self,
+        agent: "Agent",
+        request: str,
+        no_answer: str = NO_ANSWER,
+        user_confirm: bool = True,
+    ) -> Optional[str]:
+        """
+        Send a request to another agent, possibly after confirming with the user.
+        This is not currently used, since we rely on the task loop and
+        `RecipientTool` to address requests to other agents. It is generally
+        best to avoid using this method.
+
+        Args:
+            agent (Agent): agent to ask
+            request (str): request to send
+            no_answer (str): expected response when the agent does not know
+                the answer
+            user_confirm (bool): whether to gate the request with a human
+                confirmation
+
+        Returns:
+            str: response from the agent
+        """
+        agent_type = type(agent).__name__
+        if user_confirm:
+            user_response = Prompt.ask(
+                f"""[magenta]Here is the request or message:
+                {request}
+                Should I forward this to {agent_type}?""",
+                default="y",
+                choices=["y", "n"],
+            )
+            if user_response not in ["y", "yes"]:
+                return None
+        answer = agent.llm_response(request)
+        if answer != no_answer:
+            return (f"{agent_type} says: " + str(answer)).strip()
+        return None