langroid 0.33.6__py3-none-any.whl → 0.33.7__py3-none-any.whl
This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only, and reflects the package contents as they appear in their respective public registries.
- langroid/__init__.py +106 -0
- langroid/agent/__init__.py +41 -0
- langroid/agent/base.py +1983 -0
- langroid/agent/batch.py +398 -0
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +598 -0
- langroid/agent/chat_agent.py +1899 -0
- langroid/agent/chat_document.py +454 -0
- langroid/agent/openai_assistant.py +882 -0
- langroid/agent/special/__init__.py +59 -0
- langroid/agent/special/arangodb/__init__.py +0 -0
- langroid/agent/special/arangodb/arangodb_agent.py +656 -0
- langroid/agent/special/arangodb/system_messages.py +186 -0
- langroid/agent/special/arangodb/tools.py +107 -0
- langroid/agent/special/arangodb/utils.py +36 -0
- langroid/agent/special/doc_chat_agent.py +1466 -0
- langroid/agent/special/lance_doc_chat_agent.py +262 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +198 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +82 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +260 -0
- langroid/agent/special/lance_tools.py +61 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +433 -0
- langroid/agent/special/neo4j/system_messages.py +120 -0
- langroid/agent/special/neo4j/tools.py +32 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +56 -0
- langroid/agent/special/sql/__init__.py +17 -0
- langroid/agent/special/sql/sql_chat_agent.py +654 -0
- langroid/agent/special/sql/utils/__init__.py +21 -0
- langroid/agent/special/sql/utils/description_extractors.py +190 -0
- langroid/agent/special/sql/utils/populate_metadata.py +85 -0
- langroid/agent/special/sql/utils/system_message.py +35 -0
- langroid/agent/special/sql/utils/tools.py +64 -0
- langroid/agent/special/table_chat_agent.py +263 -0
- langroid/agent/task.py +2095 -0
- langroid/agent/tool_message.py +393 -0
- langroid/agent/tools/__init__.py +38 -0
- langroid/agent/tools/duckduckgo_search_tool.py +50 -0
- langroid/agent/tools/file_tools.py +234 -0
- langroid/agent/tools/google_search_tool.py +39 -0
- langroid/agent/tools/metaphor_search_tool.py +68 -0
- langroid/agent/tools/orchestration.py +303 -0
- langroid/agent/tools/recipient_tool.py +235 -0
- langroid/agent/tools/retrieval_tool.py +32 -0
- langroid/agent/tools/rewind_tool.py +137 -0
- langroid/agent/tools/segment_extract_tool.py +41 -0
- langroid/agent/xml_tool_message.py +382 -0
- langroid/cachedb/__init__.py +17 -0
- langroid/cachedb/base.py +58 -0
- langroid/cachedb/momento_cachedb.py +108 -0
- langroid/cachedb/redis_cachedb.py +153 -0
- langroid/embedding_models/__init__.py +39 -0
- langroid/embedding_models/base.py +74 -0
- langroid/embedding_models/models.py +461 -0
- langroid/embedding_models/protoc/__init__.py +0 -0
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/exceptions.py +71 -0
- langroid/language_models/__init__.py +53 -0
- langroid/language_models/azure_openai.py +153 -0
- langroid/language_models/base.py +678 -0
- langroid/language_models/config.py +18 -0
- langroid/language_models/mock_lm.py +124 -0
- langroid/language_models/openai_gpt.py +1964 -0
- langroid/language_models/prompt_formatter/__init__.py +16 -0
- langroid/language_models/prompt_formatter/base.py +40 -0
- langroid/language_models/prompt_formatter/hf_formatter.py +132 -0
- langroid/language_models/prompt_formatter/llama2_formatter.py +75 -0
- langroid/language_models/utils.py +151 -0
- langroid/mytypes.py +84 -0
- langroid/parsing/__init__.py +52 -0
- langroid/parsing/agent_chats.py +38 -0
- langroid/parsing/code_parser.py +121 -0
- langroid/parsing/document_parser.py +718 -0
- langroid/parsing/para_sentence_split.py +62 -0
- langroid/parsing/parse_json.py +155 -0
- langroid/parsing/parser.py +313 -0
- langroid/parsing/repo_loader.py +790 -0
- langroid/parsing/routing.py +36 -0
- langroid/parsing/search.py +275 -0
- langroid/parsing/spider.py +102 -0
- langroid/parsing/table_loader.py +94 -0
- langroid/parsing/url_loader.py +111 -0
- langroid/parsing/urls.py +273 -0
- langroid/parsing/utils.py +373 -0
- langroid/parsing/web_search.py +156 -0
- langroid/prompts/__init__.py +9 -0
- langroid/prompts/dialog.py +17 -0
- langroid/prompts/prompts_config.py +5 -0
- langroid/prompts/templates.py +141 -0
- langroid/pydantic_v1/__init__.py +10 -0
- langroid/pydantic_v1/main.py +4 -0
- langroid/utils/__init__.py +19 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +98 -0
- langroid/utils/constants.py +30 -0
- langroid/utils/git_utils.py +252 -0
- langroid/utils/globals.py +49 -0
- langroid/utils/logging.py +135 -0
- langroid/utils/object_registry.py +66 -0
- langroid/utils/output/__init__.py +20 -0
- langroid/utils/output/citations.py +41 -0
- langroid/utils/output/printing.py +99 -0
- langroid/utils/output/status.py +40 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +602 -0
- langroid/utils/system.py +286 -0
- langroid/utils/types.py +93 -0
- langroid/vector_store/__init__.py +50 -0
- langroid/vector_store/base.py +359 -0
- langroid/vector_store/chromadb.py +214 -0
- langroid/vector_store/lancedb.py +406 -0
- langroid/vector_store/meilisearch.py +299 -0
- langroid/vector_store/momento.py +278 -0
- langroid/vector_store/qdrantdb.py +468 -0
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/METADATA +95 -94
- langroid-0.33.7.dist-info/RECORD +127 -0
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/WHEEL +1 -1
- langroid-0.33.6.dist-info/RECORD +0 -7
- langroid-0.33.6.dist-info/entry_points.txt +0 -4
- pyproject.toml +0 -356
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/licenses/LICENSE +0 -0
langroid/agent/openai_assistant.py (new file, 882 lines):

@@ -0,0 +1,882 @@
import asyncio
import json

# setup logger
import logging
import time
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type, cast, no_type_check

import openai
from openai.types.beta import Assistant, Thread
from openai.types.beta.assistant_update_params import (
    ToolResources,
    ToolResourcesCodeInterpreter,
)
from openai.types.beta.threads import Message, Run
from openai.types.beta.threads.runs import RunStep
from rich import print

from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
from langroid.agent.chat_document import ChatDocument
from langroid.agent.tool_message import ToolMessage
from langroid.language_models.base import LLMFunctionCall, LLMMessage, LLMResponse, Role
from langroid.language_models.openai_gpt import (
    OpenAIChatModel,
    OpenAIGPT,
    OpenAIGPTConfig,
)
from langroid.pydantic_v1 import BaseModel
from langroid.utils.configuration import settings
from langroid.utils.system import generate_user_id, update_hash

logger = logging.getLogger(__name__)


class ToolType(str, Enum):
    RETRIEVAL = "file_search"
    CODE_INTERPRETER = "code_interpreter"
    FUNCTION = "function"


class AssistantTool(BaseModel):
    type: ToolType
    function: Dict[str, Any] | None = None

    def dct(self) -> Dict[str, Any]:
        d = super().dict()
        d["type"] = d["type"].value
        if self.type != ToolType.FUNCTION:
            d.pop("function")
        return d


class AssistantToolCall(BaseModel):
    id: str
    type: ToolType
    function: LLMFunctionCall


class RunStatus(str, Enum):
    QUEUED = "queued"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    REQUIRES_ACTION = "requires_action"
    EXPIRED = "expired"
    CANCELLING = "cancelling"
    CANCELLED = "cancelled"
    FAILED = "failed"
    TIMEOUT = "timeout"


class OpenAIAssistantConfig(ChatAgentConfig):
    use_cached_assistant: bool = False  # set in script via user dialog
    assistant_id: str | None = None
    use_tools = False
    use_functions_api = True
    use_cached_thread: bool = False  # set in script via user dialog
    thread_id: str | None = None
    # set to True once we can add Assistant msgs in threads
    cache_responses: bool = True
    timeout: int = 30  # can be different from llm.timeout
    llm = OpenAIGPTConfig(chat_model=OpenAIChatModel.GPT4o)
    tools: List[AssistantTool] = []
    files: List[str] = []

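Before the class definition that follows, a minimal usage sketch of this config. It is illustrative only: it assumes OPENAI_API_KEY is set in the environment and that a cache backend (e.g. Redis) is configured, since the class's constructor asserts `llm.cache is not None`; the name and system message are placeholders.

    # Hypothetical usage sketch, not part of the packaged file.
    from langroid.agent.openai_assistant import (
        OpenAIAssistant,
        OpenAIAssistantConfig,
    )

    cfg = OpenAIAssistantConfig(
        name="Demo",  # placeholder; also used in the cache keys below
        system_message="You are a helpful assistant.",  # placeholder
        use_cached_assistant=False,  # create a fresh Assistant
        use_cached_thread=False,  # create a fresh Thread
    )
    agent = OpenAIAssistant(cfg)
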
class OpenAIAssistant(ChatAgent):
    """
    A ChatAgent powered by OpenAI Assistant API:
    mainly, in `llm_response` method, we avoid maintaining conversation state,
    and instead let the Assistant API do it for us.
    Also handles persistent storage of Assistant and Threads:
    stores their ids (for given user, org) in a cache, and
    reuses them based on config.use_cached_assistant and config.use_cached_thread.

    This class can be used as a drop-in replacement for ChatAgent.
    """

    def __init__(self, config: OpenAIAssistantConfig):
        super().__init__(config)
        self.config: OpenAIAssistantConfig = config
        self.llm: OpenAIGPT = OpenAIGPT(self.config.llm)
        assert (
            self.llm.cache is not None
        ), "OpenAIAssistant requires a cache to store Assistant and Thread ids"

        if not isinstance(self.llm.client, openai.OpenAI):
            raise ValueError("Client must be OpenAI")
        # handles for various entities and methods
        self.client: openai.OpenAI = self.llm.client
        self.runs = self.client.beta.threads.runs
        self.threads = self.client.beta.threads
        self.thread_messages = self.client.beta.threads.messages
        self.assistants = self.client.beta.assistants
        # which tool_ids are awaiting output submissions
        self.pending_tool_ids: List[str] = []
        self.cached_tool_ids: List[str] = []

        self.thread: Thread | None = None
        self.assistant: Assistant | None = None
        self.run: Run | None = None

        self._maybe_create_assistant(self.config.assistant_id)
        self._maybe_create_thread(self.config.thread_id)
        self._cache_store()

        self.add_assistant_files(self.config.files)
        self.add_assistant_tools(self.config.tools)

    def add_assistant_files(self, files: List[str]) -> None:
        """Add file_ids to assistant"""
        if self.assistant is None:
            raise ValueError("Assistant is None")
        self.files = [
            self.client.files.create(file=open(f, "rb"), purpose="assistants")
            for f in files
        ]
        self.config.files = list(set(self.config.files + files))
        self.assistant = self.assistants.update(
            self.assistant.id,
            tool_resources=ToolResources(
                code_interpreter=ToolResourcesCodeInterpreter(
                    file_ids=[f.id for f in self.files],
                ),
            ),
        )

    def add_assistant_tools(self, tools: List[AssistantTool]) -> None:
        """Add tools to assistant"""
        if self.assistant is None:
            raise ValueError("Assistant is None")
        all_tool_dicts = [t.dct() for t in self.config.tools]
        for t in tools:
            if t.dct() not in all_tool_dicts:
                self.config.tools.append(t)
        self.assistant = self.assistants.update(
            self.assistant.id,
            tools=[tool.dct() for tool in self.config.tools],  # type: ignore
        )

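The two methods above take plain file paths and `AssistantTool` values; a hedged sketch of attaching the code-interpreter tool plus a data file to the `agent` from the earlier sketch:

    # Illustrative only; the file path is a placeholder.
    from langroid.agent.openai_assistant import AssistantTool, ToolType

    agent.add_assistant_tools([AssistantTool(type=ToolType.CODE_INTERPRETER)])
    agent.add_assistant_files(["data/sales.csv"])

The same effect can be had declaratively, since `__init__` forwards `config.files` and `config.tools` to these two methods.
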
    def enable_message(
        self,
        message_class: Optional[Type[ToolMessage] | List[Type[ToolMessage]]],
        use: bool = True,
        handle: bool = True,
        force: bool = False,
        require_recipient: bool = False,
        include_defaults: bool = True,
    ) -> None:
        """Override ChatAgent's method: extract the function-related args.
        See that method for details. But specifically about the `include_defaults` arg:
        Normally the OpenAI completion API ignores these fields, but the Assistant
        fn-calling seems to pay attn to these, and if we don't want this,
        we should set this to False.
        """
        if message_class is not None and isinstance(message_class, list):
            for msg_class in message_class:
                self.enable_message(
                    msg_class,
                    use=use,
                    handle=handle,
                    force=force,
                    require_recipient=require_recipient,
                    include_defaults=include_defaults,
                )
            return
        super().enable_message(
            message_class,
            use=use,
            handle=handle,
            force=force,
            require_recipient=require_recipient,
            include_defaults=include_defaults,
        )
        if message_class is None or not use:
            # no specific msg class, or
            # we are not enabling USAGE/GENERATION of this tool/fn,
            # then there's no need to attach the fn to the assistant
            # (HANDLING the fn will still work via self.agent_response)
            return
        if self.config.use_tools:
            sys_msg = self._create_system_and_tools_message()
            self.set_system_message(sys_msg.content)
        if not self.config.use_functions_api:
            return
        functions, _, _, _, _ = self._function_args()
        if functions is None:
            return
        # add the functions to the assistant:
        if self.assistant is None:
            raise ValueError("Assistant is None")
        tools = self.assistant.tools
        tools.extend(
            [
                {
                    "type": "function",  # type: ignore
                    "function": f.dict(),
                }
                for f in functions
            ]
        )
        self.assistant = self.assistants.update(
            self.assistant.id,
            tools=tools,  # type: ignore
        )

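Because `use_functions_api` defaults to True, enabling a tool this way also attaches its function schema to the remote Assistant via `assistants.update`. A hedged sketch, assuming the usual langroid `ToolMessage` pattern of a `request` name plus a `purpose` string; the tool itself is hypothetical:

    # Hypothetical tool, for illustration only.
    from langroid.agent.tool_message import ToolMessage

    class CityTempTool(ToolMessage):
        request: str = "city_temp"
        purpose: str = "Get the current temperature in <city>"
        city: str

    agent.enable_message(CityTempTool, use=True, handle=True)
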
    def _cache_thread_key(self) -> str:
        """Key to use for caching or retrieving thread id"""
        org = self.client.organization or ""
        uid = generate_user_id(org)
        name = self.config.name
        return "Thread:" + name + ":" + uid

    def _cache_assistant_key(self) -> str:
        """Key to use for caching or retrieving assistant id"""
        org = self.client.organization or ""
        uid = generate_user_id(org)
        name = self.config.name
        return "Assistant:" + name + ":" + uid

    @no_type_check
    def _cache_messages_key(self) -> str:
        """Key to use when caching or retrieving thread messages"""
        if self.thread is None:
            raise ValueError("Thread is None")
        return "Messages:" + self.thread.metadata["hash"]

    @no_type_check
    def _cache_thread_lookup(self) -> str | None:
        """Try to retrieve cached thread_id associated with
        this user + machine + organization"""
        key = self._cache_thread_key()
        if self.llm.cache is None:
            return None
        return self.llm.cache.retrieve(key)

    @no_type_check
    def _cache_assistant_lookup(self) -> str | None:
        """Try to retrieve cached assistant_id associated with
        this user + machine + organization"""
        if self.llm.cache is None:
            return None
        key = self._cache_assistant_key()
        return self.llm.cache.retrieve(key)

    @no_type_check
    def _cache_messages_lookup(self) -> LLMResponse | None:
        """Try to retrieve cached response for the message-list-hash"""
        if not settings.cache or self.llm.cache is None:
            return None
        key = self._cache_messages_key()
        cached_dict = self.llm.cache.retrieve(key)
        if cached_dict is None:
            return None
        return LLMResponse.parse_obj(cached_dict)

    def _cache_store(self) -> None:
        """
        Cache the assistant_id, thread_id associated with
        this user + machine + organization
        """
        if self.llm.cache is None:
            return
        if self.thread is None or self.assistant is None:
            raise ValueError("Thread or Assistant is None")
        thread_key = self._cache_thread_key()
        self.llm.cache.store(thread_key, self.thread.id)

        assistant_key = self._cache_assistant_key()
        self.llm.cache.store(assistant_key, self.assistant.id)

    @staticmethod
    def thread_msg_to_llm_msg(msg: Message) -> LLMMessage:
        """
        Convert a Message to an LLMMessage
        """
        return LLMMessage(
            content=msg.content[0].text.value,  # type: ignore
            role=Role(msg.role),
        )

    def _update_messages_hash(self, msg: Message | LLMMessage) -> None:
        """
        Update the hash-state in the thread with the given message.
        """
        if self.thread is None:
            raise ValueError("Thread is None")
        if isinstance(msg, Message):
            llm_msg = self.thread_msg_to_llm_msg(msg)
        else:
            llm_msg = msg
        hash = self.thread.metadata["hash"]  # type: ignore
        most_recent_msg = llm_msg.content
        most_recent_role = llm_msg.role
        hash = update_hash(hash, f"{most_recent_role}:{most_recent_msg}")
        # TODO is this inplace?
        self.thread = self.threads.update(
            self.thread.id,
            metadata={
                "hash": hash,
            },
        )
        assert self.thread.metadata["hash"] == hash  # type: ignore

    def _maybe_create_thread(self, id: str | None = None) -> None:
        """Retrieve or create a thread if one does not exist,
        or retrieve it from cache"""
        if id is not None:
            try:
                self.thread = self.threads.retrieve(thread_id=id)
            except Exception:
                logger.warning(
                    f"""
                    Could not retrieve thread with id {id},
                    so creating a new one.
                    """
                )
                self.thread = None
        if self.thread is not None:
            return
        cached = self._cache_thread_lookup()
        if cached is not None:
            if self.config.use_cached_thread:
                self.thread = self.client.beta.threads.retrieve(thread_id=cached)
            else:
                logger.warning(
                    f"""
                    Found cached thread id {cached},
                    but config.use_cached_thread = False, so deleting it.
                    """
                )
                try:
                    self.client.beta.threads.delete(thread_id=cached)
                except Exception:
                    logger.warning(
                        f"""
                        Could not delete thread with id {cached}, ignoring.
                        """
                    )
                if self.llm.cache is not None:
                    self.llm.cache.delete_keys([self._cache_thread_key()])
        if self.thread is None:
            if self.assistant is None:
                raise ValueError("Assistant is None")
            self.thread = self.client.beta.threads.create()
            hash_key_str = (
                (self.assistant.instructions or "")
                + str(self.config.use_tools)
                + str(self.config.use_functions_api)
            )
            hash_hex = update_hash(None, s=hash_key_str)
            self.thread = self.threads.update(
                self.thread.id,
                metadata={
                    "hash": hash_hex,
                },
            )
            assert self.thread.metadata["hash"] == hash_hex  # type: ignore

    def _maybe_create_assistant(self, id: str | None = None) -> None:
        """Retrieve or create an assistant if one does not exist,
        or retrieve it from cache"""
        if id is not None:
            try:
                self.assistant = self.assistants.retrieve(assistant_id=id)
            except Exception:
                logger.warning(
                    f"""
                    Could not retrieve assistant with id {id},
                    so creating a new one.
                    """
                )
                self.assistant = None
        if self.assistant is not None:
            return
        cached = self._cache_assistant_lookup()
        if cached is not None:
            if self.config.use_cached_assistant:
                self.assistant = self.client.beta.assistants.retrieve(
                    assistant_id=cached
                )
            else:
                logger.warning(
                    f"""
                    Found cached assistant id {cached},
                    but config.use_cached_assistant = False, so deleting it.
                    """
                )
                try:
                    self.client.beta.assistants.delete(assistant_id=cached)
                except Exception:
                    logger.warning(
                        f"""
                        Could not delete assistant with id {cached}, ignoring.
                        """
                    )
                if self.llm.cache is not None:
                    self.llm.cache.delete_keys([self._cache_assistant_key()])
        if self.assistant is None:
            self.assistant = self.client.beta.assistants.create(
                name=self.config.name,
                instructions=self.config.system_message,
                tools=[],
                model=self.config.llm.chat_model,
            )

    def _get_run(self) -> Run:
        """Retrieve the run object associated with this thread and run,
        to see its latest status.
        """
        if self.thread is None or self.run is None:
            raise ValueError("Thread or Run is None")
        return self.runs.retrieve(thread_id=self.thread.id, run_id=self.run.id)

    def _get_run_steps(self) -> List[RunStep]:
        if self.thread is None or self.run is None:
            raise ValueError("Thread or Run is None")
        result = self.runs.steps.list(thread_id=self.thread.id, run_id=self.run.id)
        if result is None:
            return []
        return result.data

    def _get_code_logs(self) -> List[Tuple[str, str]]:
        """
        Get list of input, output strings from code logs
        """
        run_steps = self._get_run_steps()
        # each step may have multiple tool-calls,
        # each tool-call may have multiple outputs
        tool_calls = [  # list of list of tool-calls
            s.step_details.tool_calls
            for s in run_steps
            if s.step_details is not None and hasattr(s.step_details, "tool_calls")
        ]
        code_logs = []
        for tcl in tool_calls:  # each tool-call-list
            for tc in tcl:
                if tc is None or tc.type != ToolType.CODE_INTERPRETER:
                    continue
                io = tc.code_interpreter  # type: ignore
                input = io.input
                # TODO for CodeInterpreterOutputImage, there is no "logs"
                # revisit when we handle images.
                outputs = "\n\n".join(
                    o.logs
                    for o in io.outputs
                    if o.type == "logs" and hasattr(o, "logs")
                )
                code_logs.append((input, outputs))
        # return the reversed list, since they are stored in reverse chron order
        return code_logs[::-1]

    def _get_code_logs_str(self) -> str:
        """
        Get string representation of code logs
        """
        code_logs = self._get_code_logs()
        return "\n\n".join(
            f"INPUT:\n{input}\n\nOUTPUT:\n{output}" for input, output in code_logs
        )

    def _add_thread_message(self, msg: str, role: Role) -> None:
        """
        Add a message with the given role to the thread.
        Args:
            msg (str): message to add
            role (Role): role of the message
        """
        if self.thread is None:
            raise ValueError("Thread is None")
        # CACHING TRICK! Since the API only allows inserting USER messages,
        # we prepend the role to the message, so that we can store ASSISTANT msgs
        # as well! When the LLM sees the thread messages, they will contain
        # the right sequence of alternating roles, so that it has no trouble
        # responding when it is its turn.
        msg = f"{role.value.upper()}: {msg}"
        thread_msg = self.thread_messages.create(
            content=msg,
            thread_id=self.thread.id,
            # We ALWAYS store user role since only user role allowed currently
            role=Role.USER.value,
        )
        self._update_messages_hash(thread_msg)

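To make the caching trick above concrete, here is an illustrative view (not captured from a real run) of what a thread might hold after a few turns. Every entry is stored with role "user"; the uppercase prefix carries the true role:

    # Illustrative thread contents; ids and text are placeholders.
    #   USER: What is in the attached file?
    #   ASSISTANT: {"message": "A CSV of quarterly sales.", ...}  <- cached reply,
    #       re-inserted by _llm_response_preprocess further below
    #   USER: Result for Tool_id call_abc123: 42  <- tool result, in the format
    #       built by _llm_response_preprocess
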
    def _get_thread_messages(self, n: int = 20) -> List[LLMMessage]:
        """
        Get the last n messages in the thread, in cleaned-up form (LLMMessage).
        Args:
            n (int): number of messages to retrieve
        Returns:
            List[LLMMessage]: list of messages
        """
        if self.thread is None:
            raise ValueError("Thread is None")
        result = self.thread_messages.list(
            thread_id=self.thread.id,
            limit=n,
        )
        num = len(result.data)
        if result.has_more and num < n:  # type: ignore
            logger.warning(f"Retrieving last {num} messages, but there are more")
        thread_msgs = result.data
        for msg in thread_msgs:
            self.process_citations(msg)
        return [
            LLMMessage(
                # TODO: could be image, deal with it later
                content=m.content[0].text.value,  # type: ignore
                role=Role(m.role),
            )
            for m in thread_msgs
        ]

    def _wait_for_run(
        self,
        until_not: List[RunStatus] = [RunStatus.QUEUED, RunStatus.IN_PROGRESS],
        until: List[RunStatus] = [],
        timeout: int = 30,
    ) -> RunStatus:
        """
        Poll the run until it either:
        - EXITs the statuses specified in `until_not`, or
        - ENTERs the statuses specified in `until`
        """
        if self.thread is None or self.run is None:
            raise ValueError("Thread or Run is None")
        while True:
            run = self._get_run()
            if run.status not in until_not or run.status in until:
                return cast(RunStatus, run.status)
            time.sleep(1)
            timeout -= 1
            if timeout <= 0:
                return cast(RunStatus, RunStatus.TIMEOUT)

    async def _wait_for_run_async(
        self,
        until_not: List[RunStatus] = [RunStatus.QUEUED, RunStatus.IN_PROGRESS],
        until: List[RunStatus] = [],
        timeout: int = 30,
    ) -> RunStatus:
        """Async version of _wait_for_run"""
        if self.thread is None or self.run is None:
            raise ValueError("Thread or Run is None")
        while True:
            run = self._get_run()
            if run.status not in until_not or run.status in until:
                return cast(RunStatus, run.status)
            await asyncio.sleep(1)
            timeout -= 1
            if timeout <= 0:
                return cast(RunStatus, RunStatus.TIMEOUT)

    def set_system_message(self, msg: str) -> None:
        """
        Override ChatAgent's method.
        The Task may use this method to set the system message
        of the chat assistant.
        """
        super().set_system_message(msg)
        if self.assistant is None:
            raise ValueError("Assistant is None")
        self.assistant = self.assistants.update(self.assistant.id, instructions=msg)

    def _start_run(self) -> None:
        """
        Run the assistant on the thread.
        """
        if self.thread is None or self.assistant is None:
            raise ValueError("Thread or Assistant is None")
        self.run = self.runs.create(
            thread_id=self.thread.id,
            assistant_id=self.assistant.id,
        )

    def _run_result(self) -> LLMResponse:
        """Result from run completed on the thread."""
        status = self._wait_for_run(
            timeout=self.config.timeout,
        )
        return self._process_run_result(status)

    async def _run_result_async(self) -> LLMResponse:
        """(Async) Result from run completed on the thread."""
        status = await self._wait_for_run_async(
            timeout=self.config.timeout,
        )
        return self._process_run_result(status)

    def _process_run_result(self, status: RunStatus) -> LLMResponse:
        """Process the result of the run."""
        function_call: LLMFunctionCall | None = None
        response = ""
        tool_id = ""
        # IMPORTANT: FIRST save hash key to store result,
        # before it gets updated with the response
        key = self._cache_messages_key()
        if status == RunStatus.TIMEOUT:
            logger.warning("Timeout waiting for run to complete, return empty string")
        elif status == RunStatus.COMPLETED:
            messages = self._get_thread_messages(n=1)
            response = messages[0].content
            # update hash to include the response.
            self._update_messages_hash(messages[0])
        elif status == RunStatus.REQUIRES_ACTION:
            tool_calls = self._parse_run_required_action()
            # pick the FIRST tool call with type "function"
            tool_call_fn = [t for t in tool_calls if t.type == ToolType.FUNCTION][0]
            # TODO Handling only first tool/fn call for now
            # revisit later: multi-tools affects the task.run() loop.
            function_call = tool_call_fn.function
            tool_id = tool_call_fn.id
        result = LLMResponse(
            message=response,
            tool_id=tool_id,
            function_call=function_call,
            usage=None,  # TODO
            cached=False,  # TODO - revisit when able to insert Assistant responses
        )
        if self.llm.cache is not None:
            self.llm.cache.store(key, result.dict())
        return result

    def _parse_run_required_action(self) -> List[AssistantToolCall]:
        """
        Parse the required_action field of the run, i.e. get the list of tool calls.
        Currently only tool calls are supported.
        """
        # see https://platform.openai.com/docs/assistants/tools/function-calling
        run = self._get_run()
        if run.status != RunStatus.REQUIRES_ACTION:  # type: ignore
            return []

        if (action := run.required_action.type) != "submit_tool_outputs":
            raise ValueError(f"Unexpected required_action type {action}")
        tool_calls = run.required_action.submit_tool_outputs.tool_calls
        return [
            AssistantToolCall(
                id=tool_call.id,
                type=ToolType(tool_call.type),
                function=LLMFunctionCall.from_dict(tool_call.function.model_dump()),
            )
            for tool_call in tool_calls
        ]

    def _submit_tool_outputs(self, msg: LLMMessage) -> None:
        """
        Submit the tool (fn) outputs to the run/thread
        """
        if self.run is None or self.thread is None:
            raise ValueError("Run or Thread is None")
        tool_outputs = [
            {
                "tool_call_id": msg.tool_id,
                "output": msg.content,
            }
        ]
        # run enters queued, in_progress state after this
        self.runs.submit_tool_outputs(
            thread_id=self.thread.id,
            run_id=self.run.id,
            tool_outputs=tool_outputs,  # type: ignore
        )

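The REQUIRES_ACTION status and `_submit_tool_outputs` above are driven from the public response methods further below; a hedged sketch of one function-calling round trip, with the question and tool purely illustrative:

    # Illustrative round trip; assumes a function tool was enabled earlier.
    fn_msg = agent.llm_response("What is the temperature in Paris?")
    # run stops at REQUIRES_ACTION; fn_msg carries function_call and tool_id

    result = agent.agent_response(fn_msg)
    # the tool handler runs locally, producing the tool output

    final = agent.llm_response(result)
    # _llm_response_preprocess (below) sees the pending tool_id and submits
    # the output via _submit_tool_outputs, resuming the same run
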
    def process_citations(self, thread_msg: Message) -> None:
        """
        Process citations in the thread message.
        Modifies the thread message in-place.
        """
        # could there be multiple content items?
        # TODO content could be MessageContentImageFile; handle that later
        annotated_content = thread_msg.content[0].text  # type: ignore
        annotations = annotated_content.annotations
        citations = []
        # Iterate over the annotations and add footnotes
        for index, annotation in enumerate(annotations):
            # Replace the text with a footnote
            annotated_content.value = annotated_content.value.replace(
                annotation.text, f" [{index}]"
            )
            # Gather citations based on annotation attributes
            if file_citation := getattr(annotation, "file_citation", None):
                try:
                    cited_file = self.client.files.retrieve(file_citation.file_id)
                except Exception:
                    logger.warning(
                        f"""
                        Could not retrieve cited file with id {file_citation.file_id},
                        ignoring.
                        """
                    )
                    continue
                citations.append(
                    f"[{index}] '{file_citation.quote}',-- from {cited_file.filename}"
                )
            elif file_path := getattr(annotation, "file_path", None):
                cited_file = self.client.files.retrieve(file_path.file_id)
                citations.append(
                    f"[{index}] Click <here> to download {cited_file.filename}"
                )
            # Note: File download functionality not implemented above for brevity
        sep = "\n" if len(citations) > 0 else ""
        annotated_content.value += sep + "\n".join(citations)

    def _llm_response_preprocess(
        self,
        message: Optional[str | ChatDocument] = None,
    ) -> LLMResponse | None:
        """
        Preprocess message and return response if found in cache, else None.
        """
        is_tool_output = False
        if message is not None:
            # note: to_LLMMessage returns a list of LLMMessage,
            # which is allowed to have len > 1, in case the msg
            # represents results of multiple (non-assistant) tool-calls.
            # But for OAI Assistant, we only assume exactly one tool-call at a time.
            # TODO look into multi-tools
            llm_msg = ChatDocument.to_LLMMessage(message)[0]
            tool_id = llm_msg.tool_id
            if tool_id in self.pending_tool_ids:
                if isinstance(message, ChatDocument):
                    message.pop_tool_ids()
                result_msg = f"Result for Tool_id {tool_id}: {llm_msg.content}"
                if tool_id in self.cached_tool_ids:
                    self.cached_tool_ids.remove(tool_id)
                    # add actual result of cached fn-call
                    self._add_thread_message(result_msg, role=Role.USER)
                else:
                    is_tool_output = True
                    # submit tool/fn result to the thread/run
                    self._submit_tool_outputs(llm_msg)
                    # We cannot ACTUALLY add this result to thread now
                    # since run is in `action_required` state,
                    # so we just update the message hash
                    self._update_messages_hash(
                        LLMMessage(content=result_msg, role=Role.USER)
                    )
                self.pending_tool_ids.remove(tool_id)
            else:
                # add message to the thread
                self._add_thread_message(llm_msg.content, role=Role.USER)

        # When message is None, the thread may have no user msgs,
        # Note: system message is NOT placed in the thread by the OpenAI system.

        # check if we have cached the response.
        # TODO: handle the case of structured result (fn-call, tool, etc)
        response = self._cache_messages_lookup()
        if response is not None:
            response.cached = True
            # store the result in the thread so
            # it looks like assistant produced it
            if self.config.cache_responses:
                self._add_thread_message(
                    json.dumps(response.dict()), role=Role.ASSISTANT
                )
            return response  # type: ignore
        else:
            # create a run for this assistant on this thread,
            # i.e. actually "run"
            if not is_tool_output:
                # DO NOT start a run if we submitted tool outputs,
                # since submission of tool outputs resumes a run from
                # status = "requires_action"
                self._start_run()
            return None

    def _llm_response_postprocess(
        self,
        response: LLMResponse,
        cached: bool,
        message: Optional[str | ChatDocument] = None,
    ) -> Optional[ChatDocument]:
        # code from ChatAgent.llm_response_messages
        if response.function_call is not None:
            self.pending_tool_ids += [response.tool_id]
            if cached:
                # add to cached tools list so we don't create an Assistant run
                # in _llm_response_preprocess
                self.cached_tool_ids += [response.tool_id]
            response_str = str(response.function_call)
        else:
            response_str = response.message
        cache_str = "[red](cached)[/red]" if cached else ""
        if not settings.quiet:
            if not cached and self._get_code_logs_str():
                print(
                    f"[magenta]CODE-INTERPRETER LOGS:\n"
                    "-------------------------------\n"
                    f"{self._get_code_logs_str()}[/magenta]"
                )
            print(f"{cache_str}[green]" + response_str + "[/green]")
        cdoc = ChatDocument.from_LLMResponse(response, displayed=False)
        # Note message.metadata.tool_ids may have been popped above
        tool_ids = (
            []
            if (message is None or isinstance(message, str))
            else message.metadata.tool_ids
        )

        if response.tool_id != "":
            tool_ids.append(response.tool_id)
        cdoc.metadata.tool_ids = tool_ids
        return cdoc

    def llm_response(
        self, message: Optional[str | ChatDocument] = None
    ) -> Optional[ChatDocument]:
        """
        Override ChatAgent's method: this is the main LLM response method.
        In the ChatAgent, this updates `self.message_history` and then calls
        `self.llm_response_messages`, but since we are relying on the Assistant API
        to maintain conversation state, this method is simpler: Simply start a run
        on the message-thread, and wait for it to complete.

        Args:
            message (Optional[str | ChatDocument], optional): message to respond to
                (if absent, the LLM response will be based on the
                instructions in the system_message). Defaults to None.
        Returns:
            Optional[ChatDocument]: LLM response
        """
        response = self._llm_response_preprocess(message)
        cached = True
        if response is None:
            cached = False
            response = self._run_result()
        return self._llm_response_postprocess(response, cached=cached, message=message)

    async def llm_response_async(
        self, message: Optional[str | ChatDocument] = None
    ) -> Optional[ChatDocument]:
        """
        Async version of llm_response.
        """
        response = self._llm_response_preprocess(message)
        cached = True
        if response is None:
            cached = False
            response = await self._run_result_async()
        return self._llm_response_postprocess(response, cached=cached, message=message)

    def agent_response(
        self,
        msg: Optional[str | ChatDocument] = None,
    ) -> Optional[ChatDocument]:
        response = super().agent_response(msg)
        if msg is None:
            return response
        if response is None:
            return None
        try:
            # When the agent response is to a tool message,
            # we prefix it with "TOOL Result: " so that it is clear to the
            # LLM that this is the result of the last TOOL;
            # This ensures our caching trick works.
            if self.config.use_tools and len(self.get_tool_messages(msg)) > 0:
                response.content = "TOOL Result: " + response.content
            return response
        except Exception:
            return response