langroid 0.33.6__py3-none-any.whl → 0.33.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +106 -0
- langroid/agent/__init__.py +41 -0
- langroid/agent/base.py +1983 -0
- langroid/agent/batch.py +398 -0
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +598 -0
- langroid/agent/chat_agent.py +1899 -0
- langroid/agent/chat_document.py +454 -0
- langroid/agent/openai_assistant.py +882 -0
- langroid/agent/special/__init__.py +59 -0
- langroid/agent/special/arangodb/__init__.py +0 -0
- langroid/agent/special/arangodb/arangodb_agent.py +656 -0
- langroid/agent/special/arangodb/system_messages.py +186 -0
- langroid/agent/special/arangodb/tools.py +107 -0
- langroid/agent/special/arangodb/utils.py +36 -0
- langroid/agent/special/doc_chat_agent.py +1466 -0
- langroid/agent/special/lance_doc_chat_agent.py +262 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +198 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +82 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +260 -0
- langroid/agent/special/lance_tools.py +61 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +433 -0
- langroid/agent/special/neo4j/system_messages.py +120 -0
- langroid/agent/special/neo4j/tools.py +32 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +56 -0
- langroid/agent/special/sql/__init__.py +17 -0
- langroid/agent/special/sql/sql_chat_agent.py +654 -0
- langroid/agent/special/sql/utils/__init__.py +21 -0
- langroid/agent/special/sql/utils/description_extractors.py +190 -0
- langroid/agent/special/sql/utils/populate_metadata.py +85 -0
- langroid/agent/special/sql/utils/system_message.py +35 -0
- langroid/agent/special/sql/utils/tools.py +64 -0
- langroid/agent/special/table_chat_agent.py +263 -0
- langroid/agent/task.py +2095 -0
- langroid/agent/tool_message.py +393 -0
- langroid/agent/tools/__init__.py +38 -0
- langroid/agent/tools/duckduckgo_search_tool.py +50 -0
- langroid/agent/tools/file_tools.py +234 -0
- langroid/agent/tools/google_search_tool.py +39 -0
- langroid/agent/tools/metaphor_search_tool.py +68 -0
- langroid/agent/tools/orchestration.py +303 -0
- langroid/agent/tools/recipient_tool.py +235 -0
- langroid/agent/tools/retrieval_tool.py +32 -0
- langroid/agent/tools/rewind_tool.py +137 -0
- langroid/agent/tools/segment_extract_tool.py +41 -0
- langroid/agent/xml_tool_message.py +382 -0
- langroid/cachedb/__init__.py +17 -0
- langroid/cachedb/base.py +58 -0
- langroid/cachedb/momento_cachedb.py +108 -0
- langroid/cachedb/redis_cachedb.py +153 -0
- langroid/embedding_models/__init__.py +39 -0
- langroid/embedding_models/base.py +74 -0
- langroid/embedding_models/models.py +461 -0
- langroid/embedding_models/protoc/__init__.py +0 -0
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/exceptions.py +71 -0
- langroid/language_models/__init__.py +53 -0
- langroid/language_models/azure_openai.py +153 -0
- langroid/language_models/base.py +678 -0
- langroid/language_models/config.py +18 -0
- langroid/language_models/mock_lm.py +124 -0
- langroid/language_models/openai_gpt.py +1964 -0
- langroid/language_models/prompt_formatter/__init__.py +16 -0
- langroid/language_models/prompt_formatter/base.py +40 -0
- langroid/language_models/prompt_formatter/hf_formatter.py +132 -0
- langroid/language_models/prompt_formatter/llama2_formatter.py +75 -0
- langroid/language_models/utils.py +151 -0
- langroid/mytypes.py +84 -0
- langroid/parsing/__init__.py +52 -0
- langroid/parsing/agent_chats.py +38 -0
- langroid/parsing/code_parser.py +121 -0
- langroid/parsing/document_parser.py +718 -0
- langroid/parsing/para_sentence_split.py +62 -0
- langroid/parsing/parse_json.py +155 -0
- langroid/parsing/parser.py +313 -0
- langroid/parsing/repo_loader.py +790 -0
- langroid/parsing/routing.py +36 -0
- langroid/parsing/search.py +275 -0
- langroid/parsing/spider.py +102 -0
- langroid/parsing/table_loader.py +94 -0
- langroid/parsing/url_loader.py +111 -0
- langroid/parsing/urls.py +273 -0
- langroid/parsing/utils.py +373 -0
- langroid/parsing/web_search.py +156 -0
- langroid/prompts/__init__.py +9 -0
- langroid/prompts/dialog.py +17 -0
- langroid/prompts/prompts_config.py +5 -0
- langroid/prompts/templates.py +141 -0
- langroid/pydantic_v1/__init__.py +10 -0
- langroid/pydantic_v1/main.py +4 -0
- langroid/utils/__init__.py +19 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +98 -0
- langroid/utils/constants.py +30 -0
- langroid/utils/git_utils.py +252 -0
- langroid/utils/globals.py +49 -0
- langroid/utils/logging.py +135 -0
- langroid/utils/object_registry.py +66 -0
- langroid/utils/output/__init__.py +20 -0
- langroid/utils/output/citations.py +41 -0
- langroid/utils/output/printing.py +99 -0
- langroid/utils/output/status.py +40 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +602 -0
- langroid/utils/system.py +286 -0
- langroid/utils/types.py +93 -0
- langroid/vector_store/__init__.py +50 -0
- langroid/vector_store/base.py +359 -0
- langroid/vector_store/chromadb.py +214 -0
- langroid/vector_store/lancedb.py +406 -0
- langroid/vector_store/meilisearch.py +299 -0
- langroid/vector_store/momento.py +278 -0
- langroid/vector_store/qdrantdb.py +468 -0
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/METADATA +95 -94
- langroid-0.33.7.dist-info/RECORD +127 -0
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/WHEEL +1 -1
- langroid-0.33.6.dist-info/RECORD +0 -7
- langroid-0.33.6.dist-info/entry_points.txt +0 -4
- pyproject.toml +0 -356
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1899 @@
|
|
1
|
+
import copy
|
2
|
+
import inspect
|
3
|
+
import json
|
4
|
+
import logging
|
5
|
+
import textwrap
|
6
|
+
from contextlib import ExitStack
|
7
|
+
from inspect import isclass
|
8
|
+
from typing import Dict, List, Optional, Self, Set, Tuple, Type, Union, cast
|
9
|
+
|
10
|
+
import openai
|
11
|
+
from rich import print
|
12
|
+
from rich.console import Console
|
13
|
+
from rich.markup import escape
|
14
|
+
|
15
|
+
from langroid.agent.base import Agent, AgentConfig, async_noop_fn, noop_fn
|
16
|
+
from langroid.agent.chat_document import ChatDocument
|
17
|
+
from langroid.agent.tool_message import (
|
18
|
+
ToolMessage,
|
19
|
+
format_schema_for_strict,
|
20
|
+
)
|
21
|
+
from langroid.agent.xml_tool_message import XMLToolMessage
|
22
|
+
from langroid.language_models.base import (
|
23
|
+
LLMFunctionCall,
|
24
|
+
LLMFunctionSpec,
|
25
|
+
LLMMessage,
|
26
|
+
LLMResponse,
|
27
|
+
OpenAIJsonSchemaSpec,
|
28
|
+
OpenAIToolSpec,
|
29
|
+
Role,
|
30
|
+
StreamingIfAllowed,
|
31
|
+
ToolChoiceTypes,
|
32
|
+
)
|
33
|
+
from langroid.language_models.openai_gpt import OpenAIGPT
|
34
|
+
from langroid.pydantic_v1 import BaseModel, ValidationError
|
35
|
+
from langroid.utils.configuration import settings
|
36
|
+
from langroid.utils.object_registry import ObjectRegistry
|
37
|
+
from langroid.utils.output import status
|
38
|
+
from langroid.utils.pydantic_utils import PydanticWrapper, get_pydantic_wrapper
|
39
|
+
|
40
|
+
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# A rich `Console` instance shared at module level.
console = Console()
|
43
|
+
|
44
|
+
|
45
|
+
class ChatAgentConfig(AgentConfig):
    """
    Configuration for `ChatAgent`.

    Attributes:
        system_message: system message to include in the message sequence
            (typically defines the role and task of the agent). Overridden when
            the `task` passed to `ChatAgent` starts with a system message.
        user_message: optional user message to include in the message sequence.
            Overridden when the `task` passed to `ChatAgent` has a user message
            in second position.
        use_tools: whether to use Langroid's own ToolMessage mechanism.
        use_functions_api: whether to use the function/tool-calling mechanism
            native to the LLM API (e.g. OpenAI's `function_call` or `tool_call`).
        use_tools_api: when `use_functions_api` is True and this is also True,
            the OpenAI tool-call API is used instead of the older (deprecated)
            function-call API. The tool-call API has some tricky aspects, so
            this defaults to False.
        strict_recovery: whether to enable strict schema recovery when there is
            a tool-generation error.
        enable_orchestration_tool_handling: whether to enable handling of
            orchestration tools, e.g. ForwardTool, DoneTool, PassTool, etc.
        output_format: when supported by the LLM (certain OpenAI LLMs and local
            LLMs served by providers such as vLLM), ensures that the output is
            JSON matching the corresponding schema, via grammar-based decoding.
        handle_output_format: when `output_format` is a `ToolMessage` T,
            controls whether T is "enabled for handling".
        use_output_format: when `output_format` is a `ToolMessage` T, controls
            whether T is "enabled for use" (by the LLM) and whether instructions
            on using T are added to the system message.
        instructions_output_format: controls whether instructions for
            `output_format` are generated in the system message.
        output_format_include_defaults: whether to include fields with default
            values in the output schema.
        use_tools_on_output_format: controls whether to automatically switch to
            the Langroid-native tools mechanism when `output_format` is set.
            LLMs may generate tool calls outside `output_format` even when
            strict JSON mode is enabled, so enable this when such tool calls
            are not desired.
    """

    system_message: str = "You are a helpful assistant."
    user_message: Optional[str] = None
    use_tools: bool = False
    use_functions_api: bool = True
    use_tools_api: bool = False
    strict_recovery: bool = True
    enable_orchestration_tool_handling: bool = True
    output_format: Optional[type] = None
    handle_output_format: bool = True
    use_output_format: bool = True
    instructions_output_format: bool = True
    output_format_include_defaults: bool = True
    use_tools_on_output_format: bool = True

    def _set_fn_or_tools(self, fn_available: bool) -> None:
        """
        Reconcile `use_functions_api` and `use_tools` with what the LLM supports.

        - If native fn-calling was requested but is unavailable, fall back to
          Langroid TOOLs (`use_tools=True`, `use_functions_api=False`).
        - If both mechanisms end up enabled, keep only native fn-calling.

        Args:
            fn_available: whether the LLM supports function/tool calling.
        """
        if self.use_functions_api and not fn_available:
            logger.debug(
                "\nYou have enabled `use_functions_api` but the LLM does not "
                "support it.\nSo we will enable `use_tools` instead, so we can "
                "use\nLangroid's ToolMessage mechanism.\n"
            )
            self.use_functions_api = False
            self.use_tools = True

        both_enabled = self.use_functions_api and self.use_tools
        if not both_enabled:
            return
        logger.debug(
            "\nYou have enabled both `use_tools` and `use_functions_api`.\n"
            "Turning off `use_tools`, since the LLM supports function-calling.\n"
        )
        self.use_tools = False
        self.use_functions_api = True
|
126
|
+
|
127
|
+
|
128
|
+
class ChatAgent(Agent):
|
129
|
+
"""
|
130
|
+
Chat Agent interacting with external env
|
131
|
+
(could be human, or external tools).
|
132
|
+
The agent (the LLM actually) is provided with an optional "Task Spec",
|
133
|
+
which is a sequence of `LLMMessage`s. These are used to initialize
|
134
|
+
the `task_messages` of the agent.
|
135
|
+
In most applications we will use a `ChatAgent` rather than a bare `Agent`.
|
136
|
+
The `Agent` class mainly exists to hold various common methods and attributes.
|
137
|
+
One difference between `ChatAgent` and `Agent` is that `ChatAgent`'s
|
138
|
+
`llm_response` method uses "chat mode" API (i.e. one that takes a
|
139
|
+
message sequence rather than a single message),
|
140
|
+
whereas the same method in the `Agent` class uses "completion mode" API (i.e. one
|
141
|
+
that takes a single message).
|
142
|
+
"""
|
143
|
+
|
144
|
+
def __init__(
    self,
    config: Optional[ChatAgentConfig] = None,
    task: Optional[List[LLMMessage]] = None,
):
    """
    Chat-mode agent initialized with a task spec as the initial message sequence.

    Args:
        config: settings for the agent. If None, a fresh `ChatAgentConfig()` is
            created for this agent. A default instance in the signature would
            be shared by every agent built without a config, and
            `_set_fn_or_tools` below mutates the config, so one agent's
            tools/fn-calling decision would leak into later agents.
        task: optional "priming" messages. If `task[0]` is a SYSTEM message it
            overrides `config.system_message`; if `task[1]` is a USER message
            it overrides `config.user_message`.
    """
    if config is None:
        config = ChatAgentConfig()
    super().__init__(config)
    self.config: ChatAgentConfig = config
    # Pick Langroid TOOLs vs native fn-calling based on LLM support
    # (this mutates `self.config`).
    self.config._set_fn_or_tools(self._fn_call_available())
    self.message_history: List[LLMMessage] = []
    self.init_state()
    # An agent's "task" is defined by a system msg and an optional user msg;
    # these are "priming" messages that kick off the agent's conversation.
    self.system_message: str = self.config.system_message
    self.user_message: str | None = self.config.user_message

    if task is not None:
        # if task contains a system msg, we override the config system msg
        if len(task) > 0 and task[0].role == Role.SYSTEM:
            self.system_message = task[0].content
        # if task contains a user msg, we override the config user msg
        if len(task) > 1 and task[1].role == Role.USER:
            self.user_message = task[1].content

    # System-level instructions for using tools/functions. They are maintained
    # as tools/functions are enabled/disabled, and used to recreate the system
    # message (via `_create_system_and_tools_message`) whenever an LLM
    # response is sought, so it reflects the currently enabled tools.
    # (a) general instructions on using certain tools/functions, if a
    # ToolMessage class specifies them as a classmethod `instructions`
    self.system_tool_instructions: str = ""
    # (b) only for the builtin Langroid TOOLS mechanism
    self.system_tool_format_instructions: str = ""

    self.llm_functions_map: Dict[str, LLMFunctionSpec] = {}
    self.llm_functions_handled: Set[str] = set()
    self.llm_functions_usable: Set[str] = set()
    self.llm_function_force: Optional[Dict[str, str]] = None

    # Initialize this state BEFORE `set_output_format` is called below.
    # Previously it was (re)initialized afterwards, which would clobber
    # anything that call sets on these attributes, and leave them undefined
    # while it runs. (`set_output_format` is defined elsewhere; TODO confirm
    # exactly which of these it reads/writes.)
    # Instructions specifically related to enforcing `output_format`:
    self.output_format_instructions = ""
    # Whether to disable strict schemas for this agent if strict mode causes
    # an exception:
    self.disable_strict = False
    # Whether any strict tool is enabled; used to decide whether to set
    # `self.disable_strict` on an exception:
    self.any_strict = False
    # The set of tools on which we force-disable strict decoding:
    self.disable_strict_tools_set: set[str] = set()

    self.output_format: Optional[type[ToolMessage | BaseModel]] = None

    self.saved_requests_and_tool_setings = self._requests_and_tool_settings()
    # This variable is not None and equals a `ToolMessage` T, if and only if:
    # (a) T has been set as the output_format of this agent, AND
    # (b) T has been "enabled for use" ONLY for enforcing this output format, AND
    # (c) T has NOT been explicitly "enabled for use" by this Agent.
    self.enabled_use_output_format: Optional[type[ToolMessage]] = None
    # As above but deals with "enabled for handling" instead of "enabled for use".
    self.enabled_handling_output_format: Optional[type[ToolMessage]] = None
    if config.output_format is not None:
        self.set_output_format(config.output_format)

    if self.config.enable_orchestration_tool_handling:
        # Only enable HANDLING by `agent_response`, NOT LLM generation of these.
        # This is useful where tool-handlers or agent_response generate these
        # tools, and need to be handled.
        # We don't enable orch tool GENERATION by default, since that
        # might clutter up the LLM system message unnecessarily.
        from langroid.agent.tools.orchestration import (
            AgentDoneTool,
            AgentSendTool,
            DonePassTool,
            DoneTool,
            ForwardTool,
            PassTool,
            ResultTool,
            SendTool,
        )

        self.enable_message(ForwardTool, use=False, handle=True)
        self.enable_message(DoneTool, use=False, handle=True)
        self.enable_message(AgentDoneTool, use=False, handle=True)
        self.enable_message(PassTool, use=False, handle=True)
        self.enable_message(DonePassTool, use=False, handle=True)
        self.enable_message(SendTool, use=False, handle=True)
        self.enable_message(AgentSendTool, use=False, handle=True)
        self.enable_message(ResultTool, use=False, handle=True)
|
238
|
+
|
239
|
+
def init_state(self) -> None:
    """
    Reset this agent's conversation state.

    Runs the base-class `init_state`, then empties the whole message history
    (start index 0) and clears the dialog. Subclasses may override this to
    reset additional state.
    """
    super().init_state()
    self.clear_history(0)  # index 0 => drop every message in the history
    self.clear_dialog()
|
247
|
+
|
248
|
+
@staticmethod
def from_id(id: str) -> "ChatAgent":
    """
    Look up an agent by its ID.

    Args:
        id (str): ID of the agent

    Returns:
        ChatAgent: the agent found by `Agent.from_id`, statically typed as a
            `ChatAgent` (`typing.cast` performs no runtime check).
    """
    agent = Agent.from_id(id)
    return cast(ChatAgent, agent)
|
258
|
+
|
259
|
+
def clone(self, i: int = 0) -> "ChatAgent":
    """
    Create the i'th clone of this agent, ensuring tool use/handling is cloned.

    The clone is constructed (same concrete class as `self`) from a deep copy
    of `self.config` whose `name` gets the suffix `-{i}`. This agent's
    tool/function bookkeeping is then copied onto it. The containers are
    shallow-copied, so enabling or disabling a tool on the clone does not
    silently change this agent (and vice versa); tool classes and function
    specs themselves are shared.

    Important: we assume all member variables are set in the `__init__` method
    here and in the `Agent` class.
    TODO: We are attempting to clone an agent after its state has been changed
    in possibly many ways. This is an imperfect solution; caution advised.
    Revisit later.

    Args:
        i: index of the clone; used to suffix the clone's name.

    Returns:
        ChatAgent: the new agent, given a fresh ID (and registered in the
            ObjectRegistry when `config.add_to_registry` is set).
    """
    agent_cls = type(self)
    config_copy = copy.deepcopy(self.config)
    config_copy.name = f"{config_copy.name}-{i}"
    new_agent = agent_cls(config_copy)
    new_agent.system_tool_instructions = self.system_tool_instructions
    new_agent.system_tool_format_instructions = self.system_tool_format_instructions
    # Copy which tools are known/usable/handled as well as the tool classes;
    # otherwise the clone keeps its fresh-agent defaults and e.g. does not
    # handle tools this agent handles.
    new_agent.llm_tools_map = copy.copy(self.llm_tools_map)
    new_agent.llm_tools_known = copy.copy(self.llm_tools_known)
    new_agent.llm_tools_handled = copy.copy(self.llm_tools_handled)
    new_agent.llm_tools_usable = copy.copy(self.llm_tools_usable)
    new_agent.llm_functions_map = copy.copy(self.llm_functions_map)
    new_agent.llm_functions_handled = copy.copy(self.llm_functions_handled)
    new_agent.llm_functions_usable = copy.copy(self.llm_functions_usable)
    # copy.copy(None) is None, so an unset force stays unset
    new_agent.llm_function_force = copy.copy(self.llm_function_force)
    # Caution - we are sharing the vector-db, maybe we don't always want this?
    new_agent.vecdb = self.vecdb
    new_agent.id = ObjectRegistry.new_id()
    if self.config.add_to_registry:
        ObjectRegistry.register_object(new_agent)
    return new_agent
|
284
|
+
|
285
|
+
def _strict_mode_for_tool(self, tool: str | type[ToolMessage]) -> bool:
    """
    Decide whether strict (schema-enforced) decoding should apply to `tool`.

    Args:
        tool: either a tool's request name (looked up in `self.llm_tools_map`)
            or the ToolMessage class itself.

    Returns:
        bool: False if strict mode is disabled agent-wide or for this tool;
            otherwise the tool's own `strict` default when it sets one, else
            whether the LLM supports strict tools.
    """
    tool_class = self.llm_tools_map[tool] if isinstance(tool, str) else tool
    request_name = tool_class.default_value("request")
    force_disabled = (
        request_name in self.disable_strict_tools_set or self.disable_strict
    )
    if force_disabled:
        return False
    tool_strict: Optional[bool] = tool_class.default_value("strict")
    return self._strict_tools_available() if tool_strict is None else tool_strict
|
299
|
+
|
300
|
+
def _fn_call_available(self) -> bool:
|
301
|
+
"""Does this agent's LLM support function calling?"""
|
302
|
+
return (
|
303
|
+
self.llm is not None
|
304
|
+
and isinstance(self.llm, OpenAIGPT)
|
305
|
+
and self.llm.is_openai_chat_model()
|
306
|
+
and self.llm.supports_functions_or_tools()
|
307
|
+
)
|
308
|
+
|
309
|
+
def _strict_tools_available(self) -> bool:
|
310
|
+
"""Does this agent's LLM support strict tools?"""
|
311
|
+
return (
|
312
|
+
not self.disable_strict
|
313
|
+
and self.llm is not None
|
314
|
+
and isinstance(self.llm, OpenAIGPT)
|
315
|
+
and self.llm.config.parallel_tool_calls is False
|
316
|
+
and self.llm.supports_strict_tools
|
317
|
+
)
|
318
|
+
|
319
|
+
def _json_schema_available(self) -> bool:
|
320
|
+
"""Does this agent's LLM support strict JSON schema output format?"""
|
321
|
+
return (
|
322
|
+
not self.disable_strict
|
323
|
+
and self.llm is not None
|
324
|
+
and isinstance(self.llm, OpenAIGPT)
|
325
|
+
and self.llm.supports_json_schema
|
326
|
+
)
|
327
|
+
|
328
|
+
def set_system_message(self, msg: str) -> None:
    """
    Replace the agent's system message with `msg`.

    If the conversation has started (non-empty `message_history`), the content
    of the first history entry is overwritten as well, on the assumption that
    it is the system message. NOTE(review): this replaces that entry's whole
    content, including any tool instructions that may have been appended to it.
    """
    self.system_message = msg
    if not self.message_history:
        return
    self.message_history[0].content = msg
|
333
|
+
|
334
|
+
def set_user_message(self, msg: str) -> None:
    """Replace the agent's priming user message (used by `task_messages`)."""
    self.user_message = msg
|
336
|
+
|
337
|
+
@property
def task_messages(self) -> List[LLMMessage]:
    """
    The initial messages that define the agent's task.

    Always starts with a freshly built system message (see
    `_create_system_and_tools_message`), followed by a USER message when
    `self.user_message` is non-empty.

    Returns:
        List[LLMMessage]: one or two messages.
    """
    system_msg = self._create_system_and_tools_message()
    if not self.user_message:
        return [system_msg]
    return [system_msg, LLMMessage(role=Role.USER, content=self.user_message)]
|
349
|
+
|
350
|
+
def _drop_msg_update_tool_calls(self, msg: LLMMessage) -> None:
    """
    Keep the pending OpenAI tool-call bookkeeping consistent when `msg` is
    dropped from the message history.

    - If `msg` is a TOOL result, the tool call it answered becomes pending
      again, so that call is re-appended to `self.oai_tool_calls` (when its id
      is known in `self.oai_tool_id2call`).
    - Otherwise, if `msg` carries `tool_calls`, those calls are removed from
      `self.oai_tool_calls` and from the `self.oai_tool_id2call` map.

    Args:
        msg: the message being dropped.
    """
    if msg.role == Role.TOOL:
        call_id = msg.tool_call_id
        if call_id in self.oai_tool_id2call:
            self.oai_tool_calls.append(self.oai_tool_id2call[call_id])
    elif msg.tool_calls is not None:
        dropped_ids = {tool_call.id for tool_call in msg.tool_calls}
        # Filter by id rather than popping precomputed indices: each pop
        # shifts the positions of the later entries, so index-based pops
        # removed the wrong call (or raised IndexError) when a message had
        # several pending tool calls. Slice-assign to mutate in place.
        self.oai_tool_calls[:] = [
            pending for pending in self.oai_tool_calls if pending.id not in dropped_ids
        ]
        for call_id in dropped_ids:
            self.oai_tool_id2call.pop(call_id, None)
|
366
|
+
|
367
|
+
def clear_history(self, start: int = -2) -> None:
    """
    Drop messages from `message_history`, beginning at index `start`.

    Args:
        start (int): index of the first message to drop. Negative values count
            from the end and are clamped at 0; the default -2 drops the last
            two messages (typically the latest user and assistant messages).

    The dropped messages are visited newest-first, so the pending tool-call
    bookkeeping (`_drop_msg_update_tool_calls`) is unwound in reverse order.
    Each one's chat document is also removed from the ObjectRegistry via
    `ChatDocument.delete_id`.
    """
    if start < 0:
        start = max(0, len(self.message_history) + start)
    for msg in self.message_history[start:][::-1]:
        self._drop_msg_update_tool_calls(msg)
        ChatDocument.delete_id(msg.chat_document_id)
    self.message_history = self.message_history[:start]
|
387
|
+
|
388
|
+
def update_history(self, message: str, response: str) -> None:
    """
    Append one exchange to the message history: a USER message holding
    `message`, then an ASSISTANT message holding `response`.

    Args:
        message (str): user message
        response (str): LLM response
    """
    user_msg = LLMMessage(role=Role.USER, content=message)
    assistant_msg = LLMMessage(role=Role.ASSISTANT, content=response)
    self.message_history.append(user_msg)
    self.message_history.append(assistant_msg)
|
401
|
+
|
402
|
+
def tool_format_rules(self) -> str:
    """
    Tool formatting rules (typically JSON-based, but possibly non-JSON, e.g.
    XMLToolMessage) for the currently enabled tools that the LLM may USE.

    The per-tool `format_instructions(tool=config.use_tools)` texts are joined
    by blank lines and substituted for `{format_instructions}` in a group
    template. That template comes from the first usable class that provides a
    callable `group_format_instructions` (a subclass may override it);
    otherwise `ToolMessage.group_format_instructions()` is used.

    Returns:
        str: the formatting rules, or "You can ask questions in natural
            language." when no tool is usable.
    """
    # ONLY usable tools (i.e. LLM generation allowed)
    usable_tool_classes: List[Type[ToolMessage]] = [
        t
        for t in self.llm_tools_map.values()
        if t.default_value("request") in self.llm_tools_usable
    ]
    if not usable_tool_classes:
        return "You can ask questions in natural language."

    format_instructions = "\n\n".join(
        msg_cls.format_instructions(tool=self.config.use_tools)
        for msg_cls in usable_tool_classes
    )
    # Probe for the SAME method we call. The check previously looked for
    # `json_group_instructions` while calling `group_format_instructions`, so
    # a class-specific group template was ignored unless that stale
    # attribute happened to exist.
    for msg_cls in usable_tool_classes:
        group_template_fn = getattr(msg_cls, "group_format_instructions", None)
        if callable(group_template_fn):
            return group_template_fn().format(
                format_instructions=format_instructions
            )
    return ToolMessage.group_format_instructions().format(
        format_instructions=format_instructions
    )
|
438
|
+
|
439
|
+
def tool_instructions(self) -> str:
    """
    Build the "GUIDELINES ON SOME TOOLS/FUNCTIONS USAGE" section of the system
    prompt. It is included whether tools go through Langroid's ToolMessage
    mechanism or the LLM's native function-call mechanism.

    Only classes in `self.llm_tools_map` whose request name is in
    `self.llm_tools_usable` are considered. For each, the entry may contain:
      - GUIDANCE: the class's `instructions()` (when it is a bound method,
        e.g. a classmethod), plus `langroid_tools_instructions()` when
        `config.use_tools` is set;
      - EXAMPLES: the class's `usage_examples()`, only when `config.use_tools`
        is NOT set (with Langroid TOOLs, `tool_format_rules()` shows them).
    Classes that contribute neither are skipped.

    Returns:
        str: the assembled section, or "" if no class contributes anything.
    """
    entries = []
    for tool_cls in list(self.llm_tools_map.values()):
        request = tool_cls.default_value("request")
        if request not in self.llm_tools_usable:
            continue

        guidance_text = ""
        if hasattr(tool_cls, "instructions") and inspect.ismethod(
            tool_cls.instructions
        ):
            guidance_text = tool_cls.instructions()
        if (
            self.config.use_tools
            and hasattr(tool_cls, "langroid_tools_instructions")
            and inspect.ismethod(tool_cls.langroid_tools_instructions)
        ):
            guidance_text += tool_cls.langroid_tools_instructions()

        examples = "" if self.config.use_tools else tool_cls.usage_examples()
        example_block = "" if examples == "" else "EXAMPLES:\n" + examples
        guidance = "" if guidance_text == "" else "GUIDANCE: " + guidance_text
        if not guidance and not example_block:
            continue

        entries.append(
            textwrap.dedent(f"TOOL: {request}:\n{guidance}\n{example_block}\n")
        )

    if not entries:
        return ""
    joined = "\n\n".join(entries)
    return textwrap.dedent(
        f"=== GUIDELINES ON SOME TOOLS/FUNCTIONS USAGE ===\n{joined}\n"
    )
|
495
|
+
|
496
|
+
def augment_system_message(self, message: str) -> None:
    """
    Append `message` to `self.system_message`, separated by a blank line.

    Args:
        message (str): text to append to the system message
    """
    separator = "\n\n"
    self.system_message = self.system_message + (separator + message)
|
503
|
+
|
504
|
+
def last_message_with_role(self, role: Role) -> LLMMessage | None:
    """
    Return the most recent message in `message_history` whose role is `role`,
    or None if there is none.
    """
    role_count = sum(1 for m in self.message_history if m.role == role)
    if not role_count:
        return None
    last_idx = self.nth_message_idx_with_role(role, role_count)
    return self.message_history[last_idx]
|
511
|
+
|
512
|
+
def nth_message_idx_with_role(self, role: Role, n: int) -> int:
    """
    Index (into `message_history`) of the `n`-th message with role `role`.

    Args:
        role: the role to match.
        n: 1-based rank among messages with that role (1 = the first one).

    Returns:
        int: the 0-based index in `message_history`, or -1 if there are fewer
            than `n` such messages or if `n` < 1.
    """
    if n < 1:
        # Previously n == 0 indexed [-1] (returning the LAST match, or raising
        # IndexError when there was none) and negative n counted from the end,
        # contradicting the 1-based / -1-if-not-found contract.
        return -1
    seen = 0
    for idx, msg in enumerate(self.message_history):
        if msg.role == role:
            seen += 1
            if seen == n:
                return idx
    return -1
|
524
|
+
|
525
|
+
def update_last_message(self, message: str, role: str = Role.USER) -> None:
    """
    Overwrite the content of the most recent message with role `role`.

    Useful e.g. to replace a long user prompt (context documents plus a
    question) with just the question. Does nothing if no message matches.

    Args:
        message (str): new content for that message
        role (str): role of the message to replace
    """
    for msg in reversed(self.message_history):
        if msg.role == role:
            msg.content = message
            return
|
541
|
+
|
542
|
+
def _create_system_and_tools_message(self) -> LLMMessage:
    """
    (Re-)create the system message for the agent's LLM, reflecting the tool
    instructions and output-format instructions current at call time.

    The content consists of these sections, in order, each trimmed of
    surrounding whitespace and separated by one blank line; empty sections
    are omitted:
        (a) `self.system_message` (initially from the constructor's `task` if
            it had a system message, otherwise from the config),
        (b) `self.system_tool_instructions`,
        (c) `self.system_tool_format_instructions`,
        (d) `self.output_format_instructions`.

    Returns:
        LLMMessage: a SYSTEM-role message with that content.
    """
    sections = (
        self.system_message,
        self.system_tool_instructions,
        self.system_tool_format_instructions,
        self.output_format_instructions,
    )
    # Join explicitly instead of dedenting an indented f-string template:
    # `textwrap.dedent` runs AFTER interpolation, so any multi-line section
    # (whose continuation lines are unindented) defeats it and the template's
    # indentation leaks into the prompt; empty sections also left runs of
    # blank lines.
    content = "\n\n".join(
        trimmed for trimmed in (section.strip() for section in sections) if trimmed
    )
    return LLMMessage(role=Role.SYSTEM, content=content)
|
570
|
+
|
571
|
+
def unhandled_tools(self) -> set[str]:
    """
    Names of tools this agent knows about but does not handle.

    Useful in task flow: an agent can refuse an incoming message that contains
    only such unhandled tools.
    """
    known, handled = self.llm_tools_known, self.llm_tools_handled
    return known - handled
|
577
|
+
|
578
|
+
def enable_message(
    self,
    message_class: Optional[Type[ToolMessage] | List[Type[ToolMessage]]],
    use: bool = True,
    handle: bool = True,
    force: bool = False,
    require_recipient: bool = False,
    include_defaults: bool = True,
) -> None:
    """
    Add the tool (message class) to the agent, and enable either
    - tool USE (i.e. the LLM can generate JSON to use this tool),
    - tool HANDLING (i.e. the agent can handle JSON from this tool),

    Args:
        message_class: The ToolMessage class OR List of such classes to enable,
            for USE, or HANDLING, or both.
            If this is a list of ToolMessage classes, then the remaining args are
            applied to all classes.
            Optional; if None, then apply the enabling to all tools in the
            agent's toolset that have been enabled so far.
        use: IF True, allow the agent (LLM) to use this tool (or all tools),
            else disallow
        handle: if True, allow the agent (LLM) to handle (i.e. respond to) this
            tool (or all tools)
        force: whether to FORCE the agent (LLM) to USE the specific
            tool represented by `message_class`.
            `force` is ignored if `message_class` is None.
        require_recipient: whether to require that recipient be specified
            when using the tool message (only applies if `use` is True).
        include_defaults: whether to include fields that have default values,
            in the "properties" section of the JSON format instructions.
            (Normally the OpenAI completion API ignores these fields,
            but the Assistant fn-calling seems to pay attn to these,
            and if we don't want this, we should set this to False.)

    Raises:
        ValueError: if `message_class` has an empty default `request` value.
    """
    if isinstance(message_class, list):
        # Apply the same settings to each class in turn.
        for mc in message_class:
            self.enable_message(
                mc,
                use=use,
                handle=handle,
                force=force,
                require_recipient=require_recipient,
                include_defaults=include_defaults,
            )
        return None
    if require_recipient and message_class is not None:
        message_class = message_class.require_recipient()
    # `message_class` is a class, not an instance, so this needs `issubclass`:
    # an `isinstance(message_class, XMLToolMessage)` check can never match.
    if (
        message_class is not None
        and isinstance(message_class, type)
        and issubclass(message_class, XMLToolMessage)
    ):
        # XMLToolMessage is not compatible with OpenAI's Tools/functions API,
        # so we disable use of functions API, enable langroid-native Tools,
        # which are prompt-based.
        self.config.use_functions_api = False
        self.config.use_tools = True
    super().enable_message_handling(message_class)  # enables handling only
    tools = self._get_tool_list(message_class)
    if message_class is not None:
        request = message_class.default_value("request")
        if request == "":
            raise ValueError(
                f"""
                ToolMessage class {message_class} must have a non-empty
                'request' field if it is to be enabled as a tool.
                """
            )
        llm_function = message_class.llm_function_schema(defaults=include_defaults)
        self.llm_functions_map[request] = llm_function
        if force:
            self.llm_function_force = dict(name=request)
        else:
            self.llm_function_force = None

    for t in tools:
        self.llm_tools_known.add(t)

        if handle:
            self.llm_tools_handled.add(t)
            self.llm_functions_handled.add(t)

            if (
                self.enabled_handling_output_format is not None
                and self.enabled_handling_output_format.name() == t
            ):
                # `t` was designated as "enabled for handling" ONLY for
                # output_format enforcement, but we are explicitly
                # enabling it for handling here, so we set the variable to None.
                self.enabled_handling_output_format = None
        else:
            self.llm_tools_handled.discard(t)
            self.llm_functions_handled.discard(t)

        if use:
            tool_class = self.llm_tools_map[t]
            if tool_class._allow_llm_use:
                self.llm_tools_usable.add(t)
                self.llm_functions_usable.add(t)
            else:
                logger.warning(
                    f"""
                    ToolMessage class {tool_class} does not allow LLM use,
                    because `_allow_llm_use=False` either in the Tool or a
                    parent class of this tool;
                    so not enabling LLM use for this tool!
                    If you intended an LLM to use this tool,
                    set `_allow_llm_use=True` when you define the tool.
                    """
                )
            if (
                self.enabled_use_output_format is not None
                and self.enabled_use_output_format.default_value("request") == t
            ):
                # `t` was designated as "enabled for use" ONLY for output_format
                # enforcement, but we are explicitly enabling it for use here,
                # so we set the variable to None.
                self.enabled_use_output_format = None
        else:
            self.llm_tools_usable.discard(t)
            self.llm_functions_usable.discard(t)

    # Set tool instructions and JSON format instructions
    if self.config.use_tools:
        self.system_tool_format_instructions = self.tool_format_rules()
        self.system_tool_instructions = self.tool_instructions()
|
702
|
+
|
703
|
+
def _requests_and_tool_settings(self) -> tuple[Optional[set[str]], bool, bool]:
    """
    Snapshot the settings that `set_output_format` may override.

    Returns:
        A 3-tuple `(enabled_requests_for_inference, config.use_functions_api,
        config.use_tools)`, later used to restore those settings.
    """
    enabled = self.enabled_requests_for_inference
    cfg = self.config
    snapshot = (enabled, cfg.use_functions_api, cfg.use_tools)
    return snapshot
|
713
|
+
|
714
|
+
@property
def all_llm_tools_known(self) -> set[str]:
    """
    Names of all tools known to the agent.

    If `output_format` is a `ToolMessage` subclass, its `request` name is
    added, and the result is a new set. Otherwise `llm_tools_known` itself
    is returned.
    """
    fmt = self.output_format
    if fmt is None or not issubclass(fmt, ToolMessage):
        return self.llm_tools_known
    return self.llm_tools_known.union({fmt.default_value("request")})
|
725
|
+
|
726
|
+
def set_output_format(
    self,
    output_type: Optional[type],
    force_tools: Optional[bool] = None,
    use: Optional[bool] = None,
    handle: Optional[bool] = None,
    instructions: Optional[bool] = None,
    is_copy: bool = False,
) -> None:
    """
    Sets `output_format` to `output_type` and, if `force_tools` is enabled,
    switches to the native Langroid tools mechanism to ensure that no tool
    calls not of `output_type` are generated. By default, `force_tools`
    follows the `use_tools_on_output_format` parameter in the config.

    If `output_type` is None, restores to the state prior to setting
    `output_format`.

    If `use`, we enable use of `output_type` when it is a subclass
    of `ToolMessage`. Note that this primarily controls instruction
    generation: the model will always generate `output_type` regardless
    of whether `use` is set. Defaults to the `use_output_format`
    parameter in the config. Similarly, handling of `output_type` is
    controlled by `handle`, which defaults to the
    `handle_output_format` parameter in the config.

    `instructions` controls whether we generate instructions specifying
    the output format schema. Defaults to the `instructions_output_format`
    parameter in the config.

    `is_copy` is set when called via `__getitem__`. In that case, we must
    copy certain fields to ensure that we do not overwrite the main agent's
    settings.

    Args:
        output_type: the desired output type, or None to clear it. A type
            that is neither a `ToolMessage` nor a `BaseModel` subclass is
            wrapped via `get_pydantic_wrapper`.
        force_tools: see above; None means use the config default.
        use: see above; None means use the config default.
        handle: see above; None means use the config default.
        instructions: see above; None means use the config default.
        is_copy: True when `self` is a shallow copy (see `__getitem__`).
    """
    # Disable usage of an output format which was not specifically enabled
    # by `enable_message`
    if self.enabled_use_output_format is not None:
        self.disable_message_use(self.enabled_use_output_format)
        self.enabled_use_output_format = None

    # Disable handling of an output format which did not specifically have
    # handling enabled via `enable_message`
    if self.enabled_handling_output_format is not None:
        self.disable_message_handling(self.enabled_handling_output_format)
        self.enabled_handling_output_format = None

    # Reset any previous instructions
    self.output_format_instructions = ""

    if output_type is None:
        self.output_format = None
        # Restore the most recent snapshot taken below (it is only taken
        # when an output format is set with `force_tools` while none was set).
        (
            requests_for_inference,
            use_functions_api,
            use_tools,
        ) = self.saved_requests_and_tool_setings
        # Mutate a copy of the config: after a shallow `copy.copy` (as in
        # `__getitem__`) the config object may be shared with another agent.
        self.config = self.config.copy()
        self.enabled_requests_for_inference = requests_for_inference
        self.config.use_functions_api = use_functions_api
        self.config.use_tools = use_tools
    else:
        if force_tools is None:
            force_tools = self.config.use_tools_on_output_format

        # Wrap arbitrary types (e.g. int, list[str]) in a pydantic model so
        # that they can be described by a JSON schema and parsed.
        if not any(
            (isclass(output_type) and issubclass(output_type, t))
            for t in [ToolMessage, BaseModel]
        ):
            output_type = get_pydantic_wrapper(output_type)

        # Snapshot settings only on the first (forced) output format, so that
        # `set_output_format(None)` restores the pre-output-format state.
        if self.output_format is None and force_tools:
            self.saved_requests_and_tool_setings = (
                self._requests_and_tool_settings()
            )

        self.output_format = output_type
        if issubclass(output_type, ToolMessage):
            name = output_type.default_value("request")
            if use is None:
                use = self.config.use_output_format

            if handle is None:
                handle = self.config.handle_output_format

            if use or handle:
                # Was the tool already usable/handled before this call?
                is_usable = name in self.llm_tools_usable.union(
                    self.llm_functions_usable
                )
                is_handled = name in self.llm_tools_handled.union(
                    self.llm_functions_handled
                )

                if is_copy:
                    if use:
                        # We must copy `llm_tools_usable` so the base agent
                        # is unmodified
                        self.llm_tools_usable = copy.copy(self.llm_tools_usable)
                        self.llm_functions_usable = copy.copy(
                            self.llm_functions_usable
                        )
                    if handle:
                        # If handling the tool, do the same for `llm_tools_handled`
                        self.llm_tools_handled = copy.copy(self.llm_tools_handled)
                        self.llm_functions_handled = copy.copy(
                            self.llm_functions_handled
                        )
                # Enable `output_type`
                self.enable_message(
                    output_type,
                    # Do not override existing settings
                    use=use or is_usable,
                    handle=handle or is_handled,
                )

                # If the `output_type` ToolMessage was not already enabled for
                # use, this means we are ONLY enabling it for use specifically
                # for enforcing this output format, so we set the
                # `enabled_use_output_format` to this output_type, to
                # record that it should be disabled when `output_format` is changed
                if not is_usable:
                    self.enabled_use_output_format = output_type

                # (same reasoning as for use-enabling)
                if not is_handled:
                    self.enabled_handling_output_format = output_type

            # If the tool is usable, its instructions are already part of the
            # tool instructions, so only a short reminder is needed below.
            generated_tool_instructions = name in self.llm_tools_usable.union(
                self.llm_functions_usable
            )
        else:
            generated_tool_instructions = False

        if instructions is None:
            instructions = self.config.instructions_output_format
        if issubclass(output_type, BaseModel) and instructions:
            if generated_tool_instructions:
                # Already generated tool instructions as part of "enabling for use",
                # so only need to generate a reminder to use this tool.
                name = cast(ToolMessage, output_type).default_value("request")
                self.output_format_instructions = textwrap.dedent(
                    f"""
                    === OUTPUT FORMAT INSTRUCTIONS ===

                    Please provide output using the `{name}` tool/function.
                    """
                )
            else:
                if issubclass(output_type, ToolMessage):
                    output_format_schema = output_type.llm_function_schema(
                        request=True,
                        defaults=self.config.output_format_include_defaults,
                    ).parameters
                else:
                    output_format_schema = output_type.schema()

                # NOTE(review): appears to adjust the schema in place for
                # strict-mode compatibility; its return value is unused here.
                format_schema_for_strict(output_format_schema)

                self.output_format_instructions = textwrap.dedent(
                    f"""
                    === OUTPUT FORMAT INSTRUCTIONS ===
                    Please provide output as JSON with the following schema:

                    {output_format_schema}
                    """
                )

        if force_tools:
            # Restrict inference to this single tool, and switch from the
            # functions API to langroid-native (prompt-based) tools.
            if issubclass(output_type, ToolMessage):
                self.enabled_requests_for_inference = {
                    output_type.default_value("request")
                }
            if self.config.use_functions_api:
                self.config = self.config.copy()
                self.config.use_functions_api = False
                self.config.use_tools = True
|
901
|
+
|
902
|
+
def __getitem__(self, output_type: type) -> Self:
    """
    Return a shallow copy of this agent whose output format is `output_type`.

    The copy is configured with `set_output_format(..., is_copy=True)`, which
    copies the tool sets it modifies instead of mutating the shared ones.
    """
    forced_agent = copy.copy(self)
    forced_agent.set_output_format(output_type, is_copy=True)
    return forced_agent
|
909
|
+
|
910
|
+
def disable_message_handling(
    self,
    message_class: Optional[Type[ToolMessage]] = None,
) -> None:
    """
    Stop this agent from RESPONDING to (handling) a tool.

    Args:
        message_class: the ToolMessage class to disable handling for. If None,
            handling of ALL tools is disabled.
    """
    super().disable_message_handling(message_class)
    tool_names = self._get_tool_list(message_class)
    for tool_name in tool_names:
        for handled_set in (self.llm_tools_handled, self.llm_functions_handled):
            handled_set.discard(tool_name)
|
924
|
+
|
925
|
+
def disable_message_use(
    self,
    message_class: Optional[Type[ToolMessage]],
) -> None:
    """
    Stop this agent (its LLM) from USING a tool.

    Args:
        message_class: the ToolMessage class to disable. If None, use of ALL
            tools is disabled.
    """
    tool_names = self._get_tool_list(message_class)
    for tool_name in tool_names:
        for usable_set in (self.llm_tools_usable, self.llm_functions_usable):
            usable_set.discard(tool_name)
|
939
|
+
|
940
|
+
def disable_message_use_except(self, message_class: Type[ToolMessage]) -> None:
    """
    Disable USE of every currently-usable tool except `message_class`.

    Args:
        message_class: the only ToolMessage class whose use stays enabled.
    """
    keep = message_class.__fields__["request"].default
    to_disable = {name for name in self.llm_tools_usable if name != keep}
    for name in to_disable:
        self.llm_tools_usable.discard(name)
        self.llm_functions_usable.discard(name)
|
951
|
+
|
952
|
+
def _load_output_format(self, message: ChatDocument) -> None:
    """
    Try to parse a value of type `self.output_format` out of `message`.

    Does nothing if `output_format` is None. Otherwise the candidates are
    tried in order, and the first one that validates wins:
    1. `message.content`, decoded as JSON;
    2. `message.function_call`, if present;
    3. the `function` of each entry of `message.oai_tool_calls`, if present.

    A tool/function-call candidate is only considered when `output_format` is
    a `ToolMessage` subclass whose `request` value equals the call's `name`.
    On success the parsed value is stored in `message.content_any`; it is
    unwrapped via `.value` when `output_format` is a `PydanticWrapper`.
    If nothing validates, `self.disable_strict` is set to True and a warning
    is logged.

    Args:
        message: the response to parse.
    """
    output_format = self.output_format
    if output_format is None:
        return

    candidates: list[str | LLMFunctionCall] = [message.content]
    if message.function_call is not None:
        candidates.append(message.function_call)
    if message.oai_tool_calls is not None:
        candidates.extend(
            [call.function for call in message.oai_tool_calls if call.function is not None]
        )

    for candidate in candidates:
        try:
            if isinstance(candidate, str):
                payload = json.loads(candidate)
            elif issubclass(output_format, ToolMessage) and (
                candidate.name == output_format.default_value("request")
            ):
                payload = candidate.arguments
            else:
                # A call to some other tool/function: not our output format.
                continue

            parsed = output_format.parse_obj(payload)
            if issubclass(output_format, PydanticWrapper):
                message.content_any = parsed.value  # type: ignore
            else:
                message.content_any = parsed
            return
        except (ValidationError, json.JSONDecodeError):
            continue

    # No candidate could be parsed into the requested output format.
    self.disable_strict = True
    logging.warning(
        """
        Validation error occured with strict output format enabled.
        Disabling strict mode.
        """
    )
|
1008
|
+
|
1009
|
+
def get_tool_messages(
    self,
    msg: str | ChatDocument | None,
    all_tools: bool = False,
) -> List[ToolMessage]:
    """
    Extract tool messages from `msg`, tracking whether a tool failed to parse.

    On a pydantic `ValidationError` for a `ToolMessage` class:
    - if strict (schema-enforced) output was in effect for that tool, strict
      mode is disabled for it (its name is added to `disable_strict_tools_set`);
    - otherwise `self.tool_error` is set, which triggers strict recovery on the
      next LLM response.
    The error is re-raised in either case.
    """
    self.tool_error = False
    try:
        return super().get_tool_messages(msg, all_tools)
    except ValidationError as ve:
        failed_cls = ve.model
        if issubclass(failed_cls, ToolMessage):
            cfg = self.config
            strict_was_used = (
                cfg.use_functions_api
                and cfg.use_tools_api
                and self._strict_mode_for_tool(failed_cls)
            )
            if not strict_was_used:
                # Ask the LLM to re-emit the call in strict JSON mode so that
                # it can be parsed.
                self.tool_error = True
            else:
                # Strict output still failed to validate: the schema edits
                # needed for strict compatibility presumably broke adherence
                # to the ToolMessage schema, so stop using strict mode for it.
                tool_name = failed_cls.default_value("request")
                self.disable_strict_tools_set.add(tool_name)
                logging.warning(
                    f"""
                    Validation error occured with strict tool format.
                    Disabling strict mode for the {tool_name} tool.
                    """
                )
        raise ve
|
1051
|
+
|
1052
|
+
def _get_any_tool_message(self, optional: bool = True) -> type[ToolMessage]:
    """
    Build a `ToolMessage` class that wraps every usable tool, except those in
    `disable_strict_tools_set`. Used for strict recovery.

    Args:
        optional: if True, the wrapped `tool` field may be null.

    Returns:
        The dynamically created `AnyTool` class (request `tool_or_function`).
    """
    eligible_tools = tuple(
        self.llm_tools_map[name]
        for name in self.llm_tools_usable
        if name not in self.disable_strict_tools_set
    )
    # `Union[(A, B)]` is the same as `Union[A, B]`, and a 1-tuple collapses to
    # the single type.
    any_tool_type = Union[eligible_tools]  # type: ignore
    maybe_optional_type = Optional[any_tool_type] if optional else any_tool_type

    class AnyTool(ToolMessage):
        purpose: str = "To call a tool/function."
        request: str = "tool_or_function"
        tool: maybe_optional_type  # type: ignore

        def response(self, agent: ChatAgent) -> None | str | ChatDocument:
            # Strict recovery is one-shot: clear the forced output format first.
            agent.set_output_format(None)
            inner = self.tool
            if inner is None or inner.request not in agent.llm_tools_map:
                return None
            # The wrapper schema accepts arbitrary `request` values, so
            # re-parse the payload with the concrete tool class it names.
            concrete = agent.llm_tools_map[inner.request].parse_raw(inner.to_json())
            return agent.handle_tool_message(concrete)

        async def response_async(
            self, agent: ChatAgent
        ) -> None | str | ChatDocument:
            # Strict recovery is one-shot: clear the forced output format first.
            agent.set_output_format(None)
            inner = self.tool
            if inner is None or inner.request not in agent.llm_tools_map:
                return None
            # The wrapper schema accepts arbitrary `request` values, so
            # re-parse the payload with the concrete tool class it names.
            concrete = agent.llm_tools_map[inner.request].parse_raw(inner.to_json())
            return await agent.handle_tool_message_async(concrete)

    return AnyTool
|
1108
|
+
|
1109
|
+
def _strict_recovery_instructions(
    self,
    tool_type: Optional[type[ToolMessage]] = None,
    optional: bool = True,
) -> str:
    """
    Build the prompt text asking the LLM to redo a failed tool call.

    Args:
        tool_type: the wrapper tool class (e.g. `AnyTool`). If given, its
            parameter schema is included in the text.
        optional: if True, the wording allows the LLM to decline (`tool`
            null); otherwise it demands a corrected call.

    Returns:
        The dedented instruction text.
    """
    if optional:
        response_prefix = "If you intended to make such a call, r"
        instruction_prefix = "If you do so, b"
        optional_instructions = (
            "\n"
            + """
        If you did NOT intend to do so, `tool` should be null.
        """
        )
    else:
        response_prefix = "R"
        instruction_prefix = "B"
        optional_instructions = ""

    schema_instructions = ""
    if tool_type:
        schema_instructions = f"""
        The schema for `tool_or_function` is as follows:
        {tool_type.llm_function_schema(defaults=True, request=True).parameters}
        """

    return textwrap.dedent(
        f"""
        Your previous attempt to make a tool/function call appears to have failed.
        {response_prefix}espond with your desired tool/function. Do so with the
        `tool_or_function` tool/function where `tool` is set to your intended call.
        {schema_instructions}

        {instruction_prefix}e sure that your corrected call matches your intention
        in your previous request. For any field with a default value which
        you did not intend to override in your previous attempt, be sure
        to set that field to its default value. {optional_instructions}
        """
    )
|
1150
|
+
|
1151
|
+
def truncate_message(
    self,
    idx: int,
    tokens: int = 5,
    warning: str = "...[Contents truncated!]",
) -> LLMMessage:
    """
    Truncate the content of `self.message_history[idx]` to `tokens` tokens.

    The message is modified in place: its content becomes the truncated text,
    a newline, and `warning`. Without a parser, truncation approximates
    4 characters per token.

    Returns:
        The (mutated) message.
    """
    target = self.message_history[idx]
    original = target.content
    if self.parser is None:
        shortened = original[: tokens * 4]  # approx truncation
    else:
        shortened = self.parser.truncate_tokens(original, tokens)
    target.content = shortened + "\n" + warning
    return target
|
1167
|
+
|
1168
|
+
def _reduce_raw_tool_results(self, message: ChatDocument) -> None:
    """
    Truncate the stored result of a tool that limits its retained size.

    If `message.parent` carries tool messages, the first one with a non-None
    `_max_retained_tokens` decides the limit. The history entry at
    `message.metadata.msg_idx` is then truncated to that many tokens, with a
    note explaining how to obtain the full result.
    """
    parent = message.parent
    if parent is None:
        return
    limiting_tool = next(
        (tool for tool in parent.tool_messages if tool._max_retained_tokens is not None),
        None,
    )
    if limiting_tool is None:
        return
    tool_name = limiting_tool.default_value("request")
    max_tokens: int = limiting_tool._max_retained_tokens
    truncation_warning = f"""
        The result of the {tool_name} tool were too large,
        and has been truncated to {max_tokens} tokens.
        To obtain the full result, the tool needs to be re-used.
        """
    self.truncate_message(message.metadata.msg_idx, max_tokens, truncation_warning)
|
1189
|
+
|
1190
|
+
def llm_response(
    self, message: Optional[str | ChatDocument] = None
) -> Optional[ChatDocument]:
    """
    Respond to a single user message, appended to the message history,
    in "chat" mode.

    If the previous response had an unparseable tool call (`tool_error`),
    strict recovery is enabled in the config, and a JSON schema is available,
    the agent first forces the `AnyTool` wrapper output format. It then
    appends recovery instructions to `message` and retries. If the API
    rejects a strict-mode request, strict mode is disabled and the call is
    retried.

    Args:
        message (str|ChatDocument): message or ChatDocument object to respond to.
            If None, use the self.task_messages

    Returns:
        LLM response as a ChatDocument object (None if there is no LLM or
        nothing to send).
    """
    if self.llm is None:
        return None

    needs_recovery = (
        self.tool_error
        and self.output_format is None
        and self._json_schema_available()
        and self.config.strict_recovery
    )
    if needs_recovery:
        any_tool = self._get_any_tool_message()
        self.set_output_format(
            any_tool,
            force_tools=True,
            use=True,
            handle=True,
            instructions=True,
        )
        recovery_message = self._strict_recovery_instructions(any_tool)
        if message is None:
            message = recovery_message
        elif isinstance(message, str):
            message += recovery_message
        else:
            message.content += recovery_message
        return self.llm_response(message)

    history, output_len = self._prep_llm_messages(message)
    if not history:
        return None

    if message is None or isinstance(message, str):
        tool_choice = "auto"
    else:
        tool_choice = message.oai_tool_choice

    with StreamingIfAllowed(self.llm, self.llm.get_stream()):
        try:
            response = self.llm_response_messages(history, output_len, tool_choice)
        except openai.BadRequestError as e:
            if not self.any_strict:
                raise e
            self.disable_strict = True
            self.set_output_format(None)
            logging.warning(
                f"""
                OpenAI BadRequestError raised with strict mode enabled.
                Message: {e.message}
                Disabling strict mode and retrying.
                """
            )
            return self.llm_response(message)

    self.message_history.extend(ChatDocument.to_LLMMessage(response))
    response.metadata.msg_idx = len(self.message_history) - 1
    response.metadata.agent_id = self.id
    if isinstance(message, ChatDocument):
        self._reduce_raw_tool_results(message)
    # Preserve trail of tool_ids for OpenAI Assistant fn-calls
    if message is None or isinstance(message, str):
        response.metadata.tool_ids = []
    else:
        response.metadata.tool_ids = message.metadata.tool_ids

    return response
|
1270
|
+
|
1271
|
+
async def llm_response_async(
    self, message: Optional[str | ChatDocument] = None
) -> Optional[ChatDocument]:
    """
    Async version of `llm_response`. See there for details.

    Both retry paths here (strict recovery after a tool error, and the
    retry after a strict-mode `BadRequestError`) re-enter this async method.
    They do not call the blocking sync `llm_response`.
    """
    if self.llm is None:
        return None

    # If enabled and a tool error occurred, we recover by generating the tool in
    # strict json mode
    if (
        self.tool_error
        and self.output_format is None
        and self._json_schema_available()
        and self.config.strict_recovery
    ):
        AnyTool = self._get_any_tool_message()
        self.set_output_format(
            AnyTool,
            force_tools=True,
            use=True,
            handle=True,
            instructions=True,
        )
        recovery_message = self._strict_recovery_instructions(AnyTool)

        if message is None:
            message = recovery_message
        elif isinstance(message, str):
            message = message + recovery_message
        else:
            message.content = message.content + recovery_message

        # Must stay async: calling the sync `llm_response` here would block
        # the event loop with a synchronous LLM request.
        return await self.llm_response_async(message)

    hist, output_len = self._prep_llm_messages(message)
    if len(hist) == 0:
        return None
    tool_choice = (
        "auto"
        if isinstance(message, str)
        else (message.oai_tool_choice if message is not None else "auto")
    )
    with StreamingIfAllowed(self.llm, self.llm.get_stream()):
        try:
            response = await self.llm_response_messages_async(
                hist, output_len, tool_choice
            )
        except openai.BadRequestError as e:
            if self.any_strict:
                self.disable_strict = True
                self.set_output_format(None)
                logging.warning(
                    f"""
                    OpenAI BadRequestError raised with strict mode enabled.
                    Message: {e.message}
                    Disabling strict mode and retrying.
                    """
                )
                return await self.llm_response_async(message)
            else:
                raise e
    self.message_history.extend(ChatDocument.to_LLMMessage(response))
    response.metadata.msg_idx = len(self.message_history) - 1
    response.metadata.agent_id = self.id
    if isinstance(message, ChatDocument):
        self._reduce_raw_tool_results(message)
    # Preserve trail of tool_ids for OpenAI Assistant fn-calls
    response.metadata.tool_ids = (
        []
        if isinstance(message, str)
        else message.metadata.tool_ids if message is not None else []
    )

    return response
|
1347
|
+
|
1348
|
+
def init_message_history(self) -> None:
    """
    Reset the message history to the freshly built system message,
    followed by a USER message with `self.user_message` when that is
    non-empty.
    """
    self.message_history = [self._create_system_and_tools_message()]
    first_user_text = self.user_message
    if not first_user_text:
        return
    self.message_history.append(LLMMessage(role=Role.USER, content=first_user_text))
|
1357
|
+
|
1358
|
+
def _prep_llm_messages(
    self,
    message: Optional[str | ChatDocument] = None,
    truncate: bool = True,
) -> Tuple[List[LLMMessage], int]:
    """
    Prepare messages to be sent to self.llm_response_messages,
    which is the main method that calls the LLM API to get a response.

    Side effects on the agent and on `message`:
    - If the history is empty, the initial messages are loaded.
    - Otherwise the system message is refreshed.
    - `message` is appended to `self.message_history` when it is a str, or
      a ChatDocument that is not already the last history entry.
    - Any calls answered by TOOL messages in `message` are dropped from
      `self.oai_tool_calls`.
    - For a ChatDocument `message`, its `msg_idx` and `agent_id` are
      recorded in its metadata.

    Args:
        message: the incoming message. If None and the history is non-empty,
            the last message is removed from the history so that the
            assistant response can be re-generated.
        truncate: if True and history + max output would exceed the LLM's
            context length, first shrink the requested output length; if
            that falls below `min_output_tokens`, drop early messages
            (always keeping the first one) from the returned list.

    Returns:
        Tuple[List[LLMMessage], int]: (messages, output_len)
            messages = Full list of messages to send
            output_len = max expected number of tokens in response
        Returns ([], 0) when the LLM cannot or should not respond.

    Raises:
        ValueError: if the history cannot be shortened enough to fit the
            model's context length.
    """

    if (
        not self.llm_can_respond(message)
        or self.config.llm is None
        or self.llm is None
    ):
        return [], 0

    if message is None and len(self.message_history) > 0:
        # this means agent has been used to get LLM response already,
        # and so the last message is an "assistant" response.
        # We delete this last assistant response and re-generate it.
        self.clear_history(-1)
        logger.warning(
            "Re-generating the last assistant response since message is None"
        )

    if len(self.message_history) == 0:
        # initial messages have not yet been loaded, so load them
        self.init_message_history()

        # for debugging, show the initial message history
        if settings.debug:
            print(
                f"""
            [grey37]LLM Initial Msg History:
            {escape(self.message_history_str())}
            [/grey37]
            """
            )
    else:
        assert self.message_history[0].role == Role.SYSTEM
        # update the system message with the latest tool instructions
        self.message_history[0] = self._create_system_and_tools_message()

    if message is not None:
        if (
            isinstance(message, str)
            or message.id() != self.message_history[-1].chat_document_id
        ):
            # either the message is a str, or it is a fresh ChatDocument
            # different from the last message in the history
            llm_msgs = ChatDocument.to_LLMMessage(message, self.oai_tool_calls)
            # LLM only responds to the content, so only those msgs with
            # non-empty content should be kept
            llm_msgs = [m for m in llm_msgs if m.content.strip() != ""]
            if len(llm_msgs) == 0:
                return [], 0
            # process tools if any: tool calls answered by these TOOL
            # messages are no longer pending
            done_tools = [m.tool_call_id for m in llm_msgs if m.role == Role.TOOL]
            self.oai_tool_calls = [
                t for t in self.oai_tool_calls if t.id not in done_tools
            ]
            self.message_history.extend(llm_msgs)

    hist = self.message_history
    output_len = self.config.llm.max_output_tokens
    if (
        truncate
        and self.chat_num_tokens(hist)
        > self.llm.chat_context_length() - self.config.llm.max_output_tokens
    ):
        # chat + output > max context length,
        # so first try to shorten requested output len to fit.
        output_len = self.llm.chat_context_length() - self.chat_num_tokens(hist)
        if output_len < self.config.llm.min_output_tokens:
            # unacceptably small output len, so drop early parts of conv history
            # TODO we should really be doing summarization or other types of
            #   prompt-size reduction
            while (
                self.chat_num_tokens(hist)
                > self.llm.chat_context_length() - self.config.llm.min_output_tokens
            ):
                if len(hist) <= 2:
                    # We want to preserve the first message (typically system msg)
                    # and last message (user msg).
                    raise ValueError(
                        """
                    The message history is longer than the max chat context
                    length allowed, and we have run out of messages to drop.
                    HINT: In your `OpenAIGPTConfig` object, try increasing
                    `chat_context_length` or decreasing `max_output_tokens`.
                    """
                    )
                # drop the second message, i.e. first msg after the sys msg
                # (typically user msg). Note `hist` becomes a new list, so
                # self.message_history itself is not shortened here.
                ChatDocument.delete_id(hist[1].chat_document_id)
                hist = hist[:1] + hist[2:]

            # Dropping messages freed up room, so recompute the feasible
            # output length from the shortened history (never more than the
            # configured max). Without this, the stale pre-drop value (which
            # may be negative) would trigger a spurious error/warning below.
            output_len = min(
                self.config.llm.max_output_tokens,
                self.llm.chat_context_length() - self.chat_num_tokens(hist),
            )

            if len(hist) < len(self.message_history):
                msg_tokens = self.chat_num_tokens()
                logger.warning(
                    f"""
                Chat Model context length is {self.llm.chat_context_length()}
                tokens, but the current message history is {msg_tokens} tokens long.
                Dropped the {len(self.message_history) - len(hist)} messages
                from early in the conversation history so that history token
                length is {self.chat_num_tokens(hist)}.
                This may still not be low enough to allow minimum output length of
                {self.config.llm.min_output_tokens} tokens.
                """
                )

    if output_len < 0:
        raise ValueError(
            f"""
        Tried to shorten prompt history for chat mode
        but even after dropping all messages except system msg and last (
        user) msg, the history token len {self.chat_num_tokens(hist)} is longer
        than the model's max context length {self.llm.chat_context_length()}.
        Please try shortening the system msg or user prompts.
        """
        )
    if output_len < self.config.llm.min_output_tokens:
        logger.warning(
            f"""
        Tried to shorten prompt history for chat mode
        but the feasible output length {output_len} is still
        less than the minimum output length {self.config.llm.min_output_tokens}.
        Your chat history is too long for this model,
        and the response may be truncated.
        """
        )
    if isinstance(message, ChatDocument):
        # record the position of the corresponding LLMMessage in
        # the message_history. Use message_history (not `hist`), since
        # `hist` may be shorter if early messages were dropped above.
        message.metadata.msg_idx = len(self.message_history) - 1
        message.metadata.agent_id = self.id

    return hist, output_len
|
1506
|
+
|
1507
|
+
def _function_args(
|
1508
|
+
self,
|
1509
|
+
) -> Tuple[
|
1510
|
+
Optional[List[LLMFunctionSpec]],
|
1511
|
+
str | Dict[str, str],
|
1512
|
+
Optional[List[OpenAIToolSpec]],
|
1513
|
+
Optional[Dict[str, Dict[str, str] | str]],
|
1514
|
+
Optional[OpenAIJsonSchemaSpec],
|
1515
|
+
]:
|
1516
|
+
"""
|
1517
|
+
Get function/tool spec/output format arguments for
|
1518
|
+
OpenAI-compatible LLM API call
|
1519
|
+
"""
|
1520
|
+
functions: Optional[List[LLMFunctionSpec]] = None
|
1521
|
+
fun_call: str | Dict[str, str] = "none"
|
1522
|
+
tools: Optional[List[OpenAIToolSpec]] = None
|
1523
|
+
force_tool: Optional[Dict[str, Dict[str, str] | str]] = None
|
1524
|
+
self.any_strict = False
|
1525
|
+
if self.config.use_functions_api and len(self.llm_functions_usable) > 0:
|
1526
|
+
if not self.config.use_tools_api:
|
1527
|
+
functions = [
|
1528
|
+
self.llm_functions_map[f] for f in self.llm_functions_usable
|
1529
|
+
]
|
1530
|
+
fun_call = (
|
1531
|
+
"auto"
|
1532
|
+
if self.llm_function_force is None
|
1533
|
+
else self.llm_function_force
|
1534
|
+
)
|
1535
|
+
else:
|
1536
|
+
|
1537
|
+
def to_maybe_strict_spec(function: str) -> OpenAIToolSpec:
|
1538
|
+
spec = self.llm_functions_map[function]
|
1539
|
+
strict = self._strict_mode_for_tool(function)
|
1540
|
+
if strict:
|
1541
|
+
self.any_strict = True
|
1542
|
+
strict_spec = copy.deepcopy(spec)
|
1543
|
+
format_schema_for_strict(strict_spec.parameters)
|
1544
|
+
else:
|
1545
|
+
strict_spec = spec
|
1546
|
+
|
1547
|
+
return OpenAIToolSpec(
|
1548
|
+
type="function",
|
1549
|
+
strict=strict,
|
1550
|
+
function=strict_spec,
|
1551
|
+
)
|
1552
|
+
|
1553
|
+
tools = [to_maybe_strict_spec(f) for f in self.llm_functions_usable]
|
1554
|
+
force_tool = (
|
1555
|
+
None
|
1556
|
+
if self.llm_function_force is None
|
1557
|
+
else {
|
1558
|
+
"type": "function",
|
1559
|
+
"function": {"name": self.llm_function_force["name"]},
|
1560
|
+
}
|
1561
|
+
)
|
1562
|
+
output_format = None
|
1563
|
+
if self.output_format is not None and self._json_schema_available():
|
1564
|
+
self.any_strict = True
|
1565
|
+
if issubclass(self.output_format, ToolMessage) and not issubclass(
|
1566
|
+
self.output_format, XMLToolMessage
|
1567
|
+
):
|
1568
|
+
spec = self.output_format.llm_function_schema(
|
1569
|
+
request=True,
|
1570
|
+
defaults=self.config.output_format_include_defaults,
|
1571
|
+
)
|
1572
|
+
format_schema_for_strict(spec.parameters)
|
1573
|
+
|
1574
|
+
output_format = OpenAIJsonSchemaSpec(
|
1575
|
+
# We always require that outputs strictly match the schema
|
1576
|
+
strict=True,
|
1577
|
+
function=spec,
|
1578
|
+
)
|
1579
|
+
elif issubclass(self.output_format, BaseModel):
|
1580
|
+
param_spec = self.output_format.schema()
|
1581
|
+
format_schema_for_strict(param_spec)
|
1582
|
+
|
1583
|
+
output_format = OpenAIJsonSchemaSpec(
|
1584
|
+
# We always require that outputs strictly match the schema
|
1585
|
+
strict=True,
|
1586
|
+
function=LLMFunctionSpec(
|
1587
|
+
name="json_output",
|
1588
|
+
description="Strict Json output format.",
|
1589
|
+
parameters=param_spec,
|
1590
|
+
),
|
1591
|
+
)
|
1592
|
+
|
1593
|
+
return functions, fun_call, tools, force_tool, output_format
|
1594
|
+
|
1595
|
+
def llm_response_messages(
    self,
    messages: List[LLMMessage],
    output_len: Optional[int] = None,
    tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
) -> ChatDocument:
    """
    Send a sequence of messages to the LLM (e.g. OpenAI ChatCompletion) and
    wrap the reply in a ChatDocument.

    Args:
        messages: seq of messages (with role, content fields) sent to LLM
        output_len: max number of tokens expected in response.
            If None (or 0), the LLM config's `max_output_tokens` is used.
        tool_choice: tool-choice setting for the LLM call; a forced tool
            returned by `_function_args` takes precedence over it.

    Returns:
        ChatDocument built from the LLM response (fields "content", "metadata").
    """
    assert self.config.llm is not None and self.llm is not None
    output_len = output_len or self.config.llm.max_output_tokens
    self.llm.config.streamer = (
        self.callbacks.start_llm_stream() if self.llm.get_stream() else noop_fn
    )
    # ExitStack lets us enter the rich spinner only conditionally.
    with ExitStack() as stack:
        if not self.llm.get_stream() and not settings.quiet:
            # The spinner exists to "show progress"; when streaming, the
            # streamed output already does that, so only spin otherwise.
            stack.enter_context(
                status("LLM responding to messages...", log_if_quiet=False)
            )
        if self.llm.get_stream() and not settings.quiet:
            console.print(f"[green]{self.indent}", end="")
        fn_specs, fn_call, tool_specs, forced_tool, resp_format = (
            self._function_args()
        )
        assert self.llm is not None
        chat_kwargs = dict(
            tools=tool_specs,
            tool_choice=forced_tool or tool_choice,
            functions=fn_specs,
            function_call=fn_call,
            response_format=resp_format,
        )
        response = self.llm.chat(messages, output_len, **chat_kwargs)

    if self.llm.get_stream():
        streamed_doc = ChatDocument.from_LLMResponse(response, displayed=True)
        self.callbacks.finish_llm_stream(
            content=str(response),
            is_tool=self.has_tool_message_attempt(streamed_doc),
        )
    self.llm.config.streamer = noop_fn
    if response.cached:
        self.callbacks.cancel_llm_stream()
    self._render_llm_response(response)
    # NOTE: update_token_usage updates response.usage in place
    self.update_token_usage(
        response,
        messages,
        self.llm.get_stream(),
        chat=True,
        print_response_stats=self.config.show_stats and not settings.quiet,
    )
    reply_doc = ChatDocument.from_LLMResponse(response, displayed=True)
    new_calls = response.oai_tool_calls or []
    self.oai_tool_calls = new_calls
    self.oai_tool_id2call.update({c.id: c for c in new_calls if c.id is not None})

    # If using strict output format, parse the output JSON
    self._load_output_format(reply_doc)
    return reply_doc
|
1670
|
+
|
1671
|
+
async def llm_response_messages_async(
    self,
    messages: List[LLMMessage],
    output_len: Optional[int] = None,
    tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
) -> ChatDocument:
    """
    Async version of `llm_response_messages`. See there for details.

    Sends `messages` to the LLM via `achat` and wraps the reply in a
    ChatDocument. When `output_len` is None (or 0), the LLM config's
    `max_output_tokens` is used.
    """
    assert self.config.llm is not None and self.llm is not None
    output_len = output_len or self.config.llm.max_output_tokens
    fn_specs, fn_call, tool_specs, forced_tool, resp_format = self._function_args()
    assert self.llm is not None

    if self.llm.get_stream():
        async_streamer = await self.callbacks.start_llm_stream_async()
    else:
        async_streamer = async_noop_fn
    self.llm.config.streamer_async = async_streamer

    chat_kwargs = dict(
        tools=tool_specs,
        tool_choice=forced_tool or tool_choice,
        functions=fn_specs,
        function_call=fn_call,
        response_format=resp_format,
    )
    response = await self.llm.achat(messages, output_len, **chat_kwargs)

    if self.llm.get_stream():
        streamed_doc = ChatDocument.from_LLMResponse(response, displayed=True)
        self.callbacks.finish_llm_stream(
            content=str(response),
            is_tool=self.has_tool_message_attempt(streamed_doc),
        )
    self.llm.config.streamer_async = async_noop_fn
    if response.cached:
        self.callbacks.cancel_llm_stream()
    self._render_llm_response(response)
    # NOTE: update_token_usage updates response.usage in place
    self.update_token_usage(
        response,
        messages,
        self.llm.get_stream(),
        chat=True,
        print_response_stats=self.config.show_stats and not settings.quiet,
    )
    reply_doc = ChatDocument.from_LLMResponse(response, displayed=True)
    new_calls = response.oai_tool_calls or []
    self.oai_tool_calls = new_calls
    self.oai_tool_id2call.update({c.id: c for c in new_calls if c.id is not None})

    # If using strict output format, parse the output JSON
    self._load_output_format(reply_doc)
    return reply_doc
|
1727
|
+
|
1728
|
+
def _render_llm_response(
    self, response: ChatDocument | LLMResponse, citation_only: bool = False
) -> None:
    """
    Show an LLM response on the console and via the UI callbacks, followed
    by its citation sources when `response` is a ChatDocument that has them.

    Args:
        response: the raw LLMResponse, or the ChatDocument built from it.
        citation_only: if True, skip showing the response content and only
            show the sources (if any).
    """
    is_raw = isinstance(response, LLMResponse)
    is_cached = response.cached if is_raw else response.metadata.cached
    if self.llm is None:
        return

    # The content was already shown "live" only if streaming was on AND the
    # response was not cached; otherwise it has not been displayed yet.
    if not citation_only and (not self.llm.get_stream() or is_cached):
        prefix = f"[red]{self.indent}(cached)[/red]" if is_cached else ""
        if isinstance(response, ChatDocument):
            doc = response
        else:
            doc = ChatDocument.from_LLMResponse(response, displayed=True)
        # TODO: prepend TOOL: or OAI-TOOL: if it's a tool-call
        if not settings.quiet:
            print(prefix + "[green]" + escape(str(response)))
        self.callbacks.show_llm_response(
            content=str(response),
            is_tool=self.has_tool_message_attempt(doc),
            cached=is_cached,
        )

    # Right after an LLM reply (raw LLMResponse) there are no citations yet.
    if is_raw or not response.metadata.has_citation:
        return
    sources = response.metadata.source
    if not settings.quiet:
        print("[grey37]SOURCES:\n" + escape(sources) + "[/grey37]")
    self.callbacks.show_llm_response(
        content=str(sources),
        is_tool=False,
        cached=False,
        language="text",
    )
|
1773
|
+
|
1774
|
+
def _llm_response_temp_context(self, message: str, prompt: str) -> ChatDocument:
    """
    Get LLM response to `prompt` (which presumably includes the `message`
    somewhere, along with possible large "context" passages),
    but only include `message` as the USER message, and not the
    full `prompt`, in the message history.

    Args:
        message: the original, relatively short, user request or query
        prompt: the full prompt potentially containing `message` plus context

    Returns:
        Document object containing the response. If the LLM did not respond
        (None), the message history is left untouched.
    """
    # we explicitly call THIS class's respond method,
    # not a derived class's (or else there would be infinite recursion!)
    with StreamingIfAllowed(self.llm, self.llm.get_stream()):  # type: ignore
        answer_doc = cast(ChatDocument, ChatAgent.llm_response(self, prompt))
    # Only when there is a response was `prompt` appended as the latest USER
    # message; otherwise this would overwrite an unrelated earlier USER msg.
    if answer_doc is not None:
        self.update_last_message(message, role=Role.USER)
    return answer_doc
|
1793
|
+
|
1794
|
+
async def _llm_response_temp_context_async(
    self, message: str, prompt: str
) -> ChatDocument:
    """
    Async version of `_llm_response_temp_context`. See there for details.

    Gets the LLM response to `prompt`, then (only if there was a response)
    replaces the last USER message in the history with the shorter `message`.
    """
    # we explicitly call THIS class's respond method,
    # not a derived class's (or else there would be infinite recursion!)
    with StreamingIfAllowed(self.llm, self.llm.get_stream()):  # type: ignore
        answer_doc = cast(
            ChatDocument,
            await ChatAgent.llm_response_async(self, prompt),
        )
    # Only when there is a response was `prompt` appended as the latest USER
    # message; otherwise this would overwrite an unrelated earlier USER msg.
    if answer_doc is not None:
        self.update_last_message(message, role=Role.USER)
    return answer_doc
|
1809
|
+
|
1810
|
+
def llm_response_forget(self, message: str) -> ChatDocument:
    """
    Get a one-off LLM response to a single message, then restore the
    message history so the agent's history state is left intact.

    Args:
        message (str): user message

    Returns:
        A ChatDocument with the response.
    """
    history_len_before = len(self.message_history)
    # explicitly call THIS class's respond method, not a derived class's
    # (or else there would be infinite recursion!)
    with StreamingIfAllowed(self.llm, self.llm.get_stream()):  # type: ignore
        response = cast(ChatDocument, ChatAgent.llm_response(self, message))
    # A response adds up to two messages (the user message and the assistant
    # reply). Pop at most two, newest first, keeping tool-call state in sync.
    for _ in range(2):
        if len(self.message_history) <= history_len_before:
            break
        dropped = self.message_history.pop()
        self._drop_msg_update_tool_calls(dropped)

    # If using strict output format, parse the output JSON
    self._load_output_format(response)
    return response
|
1843
|
+
|
1844
|
+
async def llm_response_forget_async(self, message: str) -> ChatDocument:
    """
    Async version of `llm_response_forget`. See there for details.

    Gets a one-off LLM response to `message`, then pops the messages it added
    so the history returns to its prior length.
    """
    history_len_before = len(self.message_history)
    # explicitly call THIS class's respond method, not a derived class's
    # (or else there would be infinite recursion!)
    with StreamingIfAllowed(self.llm, self.llm.get_stream()):  # type: ignore
        response = cast(
            ChatDocument, await ChatAgent.llm_response_async(self, message)
        )
    # A response adds up to two messages (the user message and the assistant
    # reply). Pop at most two, newest first, keeping tool-call state in sync.
    for _ in range(2):
        if len(self.message_history) <= history_len_before:
            break
        dropped = self.message_history.pop()
        self._drop_msg_update_tool_calls(dropped)
    return response
|
1866
|
+
|
1867
|
+
def chat_num_tokens(self, messages: Optional[List[LLMMessage]] = None) -> int:
    """
    Count the tokens in the contents of a list of messages.

    Args:
        messages: the messages to count; if None, the agent's current
            `message_history` is used.

    Returns:
        int: total number of tokens across the messages' contents.

    Raises:
        ValueError: if `self.parser` is None.
    """
    parser = self.parser
    if parser is None:
        raise ValueError(
            "ChatAgent.parser is None. "
            "You must set ChatAgent.parser "
            "before calling chat_num_tokens()."
        )
    if messages is None:
        messages = self.message_history
    return sum(parser.num_tokens(msg.content) for msg in messages)
|
1885
|
+
|
1886
|
+
def message_history_str(self, i: Optional[int] = None) -> str:
    """
    Return a string representation of the message history.

    Args:
        i: if None, return all messages (newline-separated); if i >= 0,
            return only the i-th message; if i = -k (k > 0), return the
            last k messages (newline-separated).

    Returns:
        str: the selected message(s), each rendered with `str()`.
    """
    if i is None:
        return "\n".join(str(m) for m in self.message_history)
    if i >= 0:
        # i == 0 must select the first message, not fall through to the
        # negative-slice branch (where history[0:] would return everything).
        return str(self.message_history[i])
    return "\n".join(str(m) for m in self.message_history[i:])
|