langroid 0.6.6__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +206 -21
- langroid/agent/callbacks/chainlit.py +1 -1
- langroid/agent/chat_agent.py +124 -29
- langroid/agent/chat_document.py +132 -28
- langroid/agent/openai_assistant.py +8 -3
- langroid/agent/special/neo4j/neo4j_chat_agent.py +1 -9
- langroid/agent/special/sql/sql_chat_agent.py +69 -13
- langroid/agent/task.py +36 -9
- langroid/agent/tool_message.py +8 -5
- langroid/agent/tools/rewind_tool.py +1 -1
- langroid/language_models/.chainlit/config.toml +121 -0
- langroid/language_models/.chainlit/translations/en-US.json +231 -0
- langroid/language_models/base.py +111 -10
- langroid/language_models/mock_lm.py +10 -1
- langroid/language_models/openai_gpt.py +260 -36
- {langroid-0.6.6.dist-info → langroid-0.8.0.dist-info}/METADATA +3 -1
- {langroid-0.6.6.dist-info → langroid-0.8.0.dist-info}/RECORD +20 -18
- pyproject.toml +1 -1
- {langroid-0.6.6.dist-info → langroid-0.8.0.dist-info}/LICENSE +0 -0
- {langroid-0.6.6.dist-info → langroid-0.8.0.dist-info}/WHEEL +0 -0
langroid/agent/chat_document.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
3
3
|
import copy
|
4
4
|
import json
|
5
5
|
from enum import Enum
|
6
|
-
from typing import Any, List, Optional, Union, cast
|
6
|
+
from typing import Any, Dict, List, Optional, Union, cast
|
7
7
|
|
8
8
|
from langroid.agent.tool_message import ToolMessage
|
9
9
|
from langroid.language_models.base import (
|
@@ -11,7 +11,9 @@ from langroid.language_models.base import (
|
|
11
11
|
LLMMessage,
|
12
12
|
LLMResponse,
|
13
13
|
LLMTokenUsage,
|
14
|
+
OpenAIToolCall,
|
14
15
|
Role,
|
16
|
+
ToolChoiceTypes,
|
15
17
|
)
|
16
18
|
from langroid.mytypes import DocMetaData, Document, Entity
|
17
19
|
from langroid.parsing.agent_chats import parse_message
|
@@ -51,6 +53,8 @@ class ChatDocMetaData(DocMetaData):
|
|
51
53
|
agent_id: str = "" # ChatAgent that generated this message
|
52
54
|
msg_idx: int = -1 # index of this message in the agent `message_history`
|
53
55
|
sender: Entity # sender of the message
|
56
|
+
# tool_id corresponding to single tool result in ChatDocument.content
|
57
|
+
oai_tool_id: str | None = None
|
54
58
|
tool_ids: List[str] = [] # stack of tool_ids; used by OpenAIAssistant
|
55
59
|
block: None | Entity = None
|
56
60
|
sender_name: str = ""
|
@@ -86,6 +90,31 @@ class ChatDocLoggerFields(BaseModel):
|
|
86
90
|
|
87
91
|
|
88
92
|
class ChatDocument(Document):
|
93
|
+
"""
|
94
|
+
Represents a message in a conversation among agents. All responders of an agent
|
95
|
+
have signature ChatDocument -> ChatDocument (modulo None, str, etc),
|
96
|
+
and so does the Task.run() method.
|
97
|
+
|
98
|
+
Attributes:
|
99
|
+
oai_tool_calls (List[OpenAIToolCall]): Tool-calls from an OpenAI-compatible API
|
100
|
+
oai_tool_id2result (Dict[str, str]): Results of tool-calls from OpenAI
|
101
|
+
(dict is a map of tool_id -> result)
|
102
|
+
oai_tool_choice: ToolChoiceTypes | Dict[str, str]: Param controlling how the
|
103
|
+
LLM should choose tool-use in its response
|
104
|
+
(auto, none, required, or a specific tool)
|
105
|
+
function_call (LLMFunctionCall): Function-call from an OpenAI-compatible API
|
106
|
+
(deprecated; use oai_tool_calls instead)
|
107
|
+
tool_messages (List[ToolMessage]): Langroid ToolMessages extracted from
|
108
|
+
- `content` field (via JSON parsing),
|
109
|
+
- `oai_tool_calls`, or
|
110
|
+
- `function_call`
|
111
|
+
metadata (ChatDocMetaData): Metadata for the message, e.g. sender, recipient.
|
112
|
+
attachment (None | ChatDocAttachment): Any additional data attached.
|
113
|
+
"""
|
114
|
+
|
115
|
+
oai_tool_calls: Optional[List[OpenAIToolCall]] = None
|
116
|
+
oai_tool_id2result: Optional[Dict[str, str]] = None
|
117
|
+
oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto"
|
89
118
|
function_call: Optional[LLMFunctionCall] = None
|
90
119
|
tool_messages: List[ToolMessage] = []
|
91
120
|
metadata: ChatDocMetaData
|
@@ -198,6 +227,24 @@ class ChatDocument(Document):
|
|
198
227
|
if len(self.metadata.tool_ids) > 0:
|
199
228
|
self.metadata.tool_ids.pop()
|
200
229
|
|
230
|
+
@staticmethod
|
231
|
+
def _clean_fn_call(fc: LLMFunctionCall | None) -> None:
|
232
|
+
# Sometimes an OpenAI LLM (esp gpt-4o) may generate a function-call
|
233
|
+
# with oddities:
|
234
|
+
# (a) the `name` is set, as well as `arguments.request` is set,
|
235
|
+
# and in langroid we use the `request` value as the `name`.
|
236
|
+
# In this case we override the `name` with the `request` value.
|
237
|
+
# (b) the `name` looks like "functions blah" or just "functions"
|
238
|
+
# In this case we strip the "functions" part.
|
239
|
+
if fc is None:
|
240
|
+
return
|
241
|
+
fc.name = fc.name.replace("functions", "").strip()
|
242
|
+
if fc.arguments is not None:
|
243
|
+
request = fc.arguments.get("request")
|
244
|
+
if request is not None and request != "":
|
245
|
+
fc.name = request
|
246
|
+
fc.arguments.pop("request")
|
247
|
+
|
201
248
|
@staticmethod
|
202
249
|
def from_LLMResponse(
|
203
250
|
response: LLMResponse,
|
@@ -216,22 +263,14 @@ class ChatDocument(Document):
|
|
216
263
|
if message in ["''", '""']:
|
217
264
|
message = ""
|
218
265
|
if response.function_call is not None:
|
219
|
-
|
220
|
-
|
221
|
-
#
|
222
|
-
|
223
|
-
|
224
|
-
# (b) the `name` looks like "functions blah" or just "functions"
|
225
|
-
# In this case we strip the "functions" part.
|
226
|
-
fc = response.function_call
|
227
|
-
fc.name = fc.name.replace("functions", "").strip()
|
228
|
-
if fc.arguments is not None:
|
229
|
-
request = fc.arguments.get("request")
|
230
|
-
if request is not None and request != "":
|
231
|
-
fc.name = request
|
232
|
-
fc.arguments.pop("request")
|
266
|
+
ChatDocument._clean_fn_call(response.function_call)
|
267
|
+
if response.oai_tool_calls is not None:
|
268
|
+
# there must be at least one if it's not None
|
269
|
+
for oai_tc in response.oai_tool_calls:
|
270
|
+
ChatDocument._clean_fn_call(oai_tc.function)
|
233
271
|
return ChatDocument(
|
234
272
|
content=message,
|
273
|
+
oai_tool_calls=response.oai_tool_calls,
|
235
274
|
function_call=response.function_call,
|
236
275
|
metadata=ChatDocMetaData(
|
237
276
|
source=Entity.LLM,
|
@@ -261,24 +300,33 @@ class ChatDocument(Document):
|
|
261
300
|
)
|
262
301
|
|
263
302
|
@staticmethod
|
264
|
-
def to_LLMMessage(
|
303
|
+
def to_LLMMessage(
|
304
|
+
message: Union[str, "ChatDocument"],
|
305
|
+
oai_tools: Optional[List[OpenAIToolCall]] = None,
|
306
|
+
) -> List[LLMMessage]:
|
265
307
|
"""
|
266
|
-
Convert to LLMMessage
|
308
|
+
Convert to list of LLMMessage, to incorporate into msg-history sent to LLM API.
|
309
|
+
Usually there will be just a single LLMMessage, but when the ChatDocument
|
310
|
+
contains results from multiple OpenAI tool-calls, we would have a sequence
|
311
|
+
of LLMMessages, one per tool-call result.
|
267
312
|
|
268
313
|
Args:
|
269
314
|
message (str|ChatDocument): Message to convert.
|
315
|
+
oai_tools (Optional[List[OpenAIToolCall]]): Tool-calls currently awaiting
|
316
|
+
response, from the ChatAgent's latest message.
|
270
317
|
Returns:
|
271
|
-
LLMMessage:
|
272
|
-
|
318
|
+
List[LLMMessage]: list of LLMMessages corresponding to this ChatDocument.
|
273
319
|
"""
|
274
320
|
sender_name = None
|
275
321
|
sender_role = Role.USER
|
276
322
|
fun_call = None
|
277
|
-
|
323
|
+
oai_tool_calls = None
|
324
|
+
tool_id = "" # for OpenAI Assistant
|
278
325
|
chat_document_id: str = ""
|
279
326
|
if isinstance(message, ChatDocument):
|
280
327
|
content = message.content
|
281
328
|
fun_call = message.function_call
|
329
|
+
oai_tool_calls = message.oai_tool_calls
|
282
330
|
if message.metadata.sender == Entity.USER and fun_call is not None:
|
283
331
|
# This may happen when a (parent agent's) LLM generates a
|
284
332
|
# Function-call, and it ends up being sent to the current task's
|
@@ -289,6 +337,10 @@ class ChatDocument(Document):
|
|
289
337
|
# in the content of the message.
|
290
338
|
content += " " + str(fun_call)
|
291
339
|
fun_call = None
|
340
|
+
if message.metadata.sender == Entity.USER and oai_tool_calls is not None:
|
341
|
+
# same reasoning as for function-call above
|
342
|
+
content += " " + "\n\n".join(str(tc) for tc in oai_tool_calls)
|
343
|
+
oai_tool_calls = None
|
292
344
|
sender_name = message.metadata.sender_name
|
293
345
|
tool_ids = message.metadata.tool_ids
|
294
346
|
tool_id = tool_ids[-1] if len(tool_ids) > 0 else ""
|
@@ -299,22 +351,74 @@ class ChatDocument(Document):
|
|
299
351
|
message.metadata.parent is not None
|
300
352
|
and message.metadata.parent.function_call is not None
|
301
353
|
):
|
354
|
+
# This is a response to a function call, so set the role to FUNCTION.
|
302
355
|
sender_role = Role.FUNCTION
|
303
356
|
sender_name = message.metadata.parent.function_call.name
|
357
|
+
elif oai_tools is not None and len(oai_tools) > 0:
|
358
|
+
pending_tool_ids = [tc.id for tc in oai_tools]
|
359
|
+
# The ChatAgent has pending OpenAI tool-call(s),
|
360
|
+
# so the current ChatDocument contains
|
361
|
+
# results for some/all/none of them.
|
362
|
+
|
363
|
+
if len(oai_tools) == 1:
|
364
|
+
# Case 1:
|
365
|
+
# There was exactly 1 pending tool-call, and in this case
|
366
|
+
# the result would be a plain string in `content`
|
367
|
+
return [
|
368
|
+
LLMMessage(
|
369
|
+
role=Role.TOOL,
|
370
|
+
tool_call_id=oai_tools[0].id,
|
371
|
+
content=content,
|
372
|
+
chat_document_id=chat_document_id,
|
373
|
+
)
|
374
|
+
]
|
375
|
+
|
376
|
+
elif (
|
377
|
+
message.metadata.oai_tool_id is not None
|
378
|
+
and message.metadata.oai_tool_id in pending_tool_ids
|
379
|
+
):
|
380
|
+
# Case 2:
|
381
|
+
# ChatDocument.content has result of a single tool-call
|
382
|
+
return [
|
383
|
+
LLMMessage(
|
384
|
+
role=Role.TOOL,
|
385
|
+
tool_call_id=message.metadata.oai_tool_id,
|
386
|
+
content=content,
|
387
|
+
chat_document_id=chat_document_id,
|
388
|
+
)
|
389
|
+
]
|
390
|
+
elif message.oai_tool_id2result is not None:
|
391
|
+
# Case 3:
|
392
|
+
# There were > 1 tool-calls awaiting response,
|
393
|
+
assert (
|
394
|
+
len(message.oai_tool_id2result) > 1
|
395
|
+
), "oai_tool_id2result must have more than 1 item."
|
396
|
+
return [
|
397
|
+
LLMMessage(
|
398
|
+
role=Role.TOOL,
|
399
|
+
tool_call_id=tool_id,
|
400
|
+
content=result,
|
401
|
+
chat_document_id=chat_document_id,
|
402
|
+
)
|
403
|
+
for tool_id, result in message.oai_tool_id2result.items()
|
404
|
+
]
|
304
405
|
elif message.metadata.sender == Entity.LLM:
|
305
406
|
sender_role = Role.ASSISTANT
|
306
407
|
else:
|
307
408
|
# LLM can only respond to text content, so extract it
|
308
409
|
content = message
|
309
410
|
|
310
|
-
return
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
411
|
+
return [
|
412
|
+
LLMMessage(
|
413
|
+
role=sender_role,
|
414
|
+
tool_id=tool_id, # for OpenAI Assistant
|
415
|
+
content=content,
|
416
|
+
function_call=fun_call,
|
417
|
+
tool_calls=oai_tool_calls,
|
418
|
+
name=sender_name,
|
419
|
+
chat_document_id=chat_document_id,
|
420
|
+
)
|
421
|
+
]
|
318
422
|
|
319
423
|
|
320
424
|
LLMMessage.update_forward_refs()
|
@@ -79,7 +79,7 @@ class OpenAIAssistantConfig(ChatAgentConfig):
|
|
79
79
|
# set to True once we can add Assistant msgs in threads
|
80
80
|
cache_responses: bool = True
|
81
81
|
timeout: int = 30 # can be different from llm.timeout
|
82
|
-
llm = OpenAIGPTConfig(chat_model=OpenAIChatModel.
|
82
|
+
llm = OpenAIGPTConfig(chat_model=OpenAIChatModel.GPT4o)
|
83
83
|
tools: List[AssistantTool] = []
|
84
84
|
files: List[str] = []
|
85
85
|
|
@@ -192,7 +192,7 @@ class OpenAIAssistant(ChatAgent):
|
|
192
192
|
self.set_system_message(sys_msg.content)
|
193
193
|
if not self.config.use_functions_api:
|
194
194
|
return
|
195
|
-
functions, _ = self._function_args()
|
195
|
+
functions, _, _, _ = self._function_args()
|
196
196
|
if functions is None:
|
197
197
|
return
|
198
198
|
# add the functions to the assistant:
|
@@ -720,7 +720,12 @@ class OpenAIAssistant(ChatAgent):
|
|
720
720
|
"""
|
721
721
|
is_tool_output = False
|
722
722
|
if message is not None:
|
723
|
-
|
723
|
+
# note: to_LLMMessage returns a list of LLMMessage,
|
724
|
+
# which is allowed to have len > 1, in case the msg
|
725
|
+
# represents results of multiple (non-assistant) tool-calls.
|
726
|
+
# But for OAI Assistant, we only assume exactly one tool-call at a time.
|
727
|
+
# TODO look into multi-tools
|
728
|
+
llm_msg = ChatDocument.to_LLMMessage(message)[0]
|
724
729
|
tool_id = llm_msg.tool_id
|
725
730
|
if tool_id in self.pending_tool_ids:
|
726
731
|
if isinstance(message, ChatDocument):
|
@@ -6,7 +6,6 @@ from rich import print
|
|
6
6
|
from rich.console import Console
|
7
7
|
|
8
8
|
from langroid.agent import ToolMessage
|
9
|
-
from langroid.parsing.parse_json import datetime_to_json
|
10
9
|
from langroid.pydantic_v1 import BaseModel, BaseSettings
|
11
10
|
|
12
11
|
if TYPE_CHECKING:
|
@@ -285,14 +284,7 @@ class Neo4jChatAgent(ChatAgent):
|
|
285
284
|
names you found in the schema.
|
286
285
|
Or retry using one of the RETRY-SUGGESTIONS in your instructions.
|
287
286
|
"""
|
288
|
-
|
289
|
-
try:
|
290
|
-
json_str = json.dumps(response.data, default=datetime_to_json)
|
291
|
-
return json_str
|
292
|
-
except TypeError:
|
293
|
-
return str(response.data)
|
294
|
-
else:
|
295
|
-
return str(response.data)
|
287
|
+
return str(response.data)
|
296
288
|
|
297
289
|
def create_query(self, msg: CypherCreationTool) -> str:
|
298
290
|
""" "
|
@@ -14,11 +14,12 @@ from rich import print
|
|
14
14
|
from rich.console import Console
|
15
15
|
|
16
16
|
from langroid.exceptions import LangroidImportError
|
17
|
+
from langroid.utils.constants import DONE
|
17
18
|
|
18
19
|
try:
|
19
20
|
from sqlalchemy import MetaData, Row, create_engine, inspect, text
|
20
21
|
from sqlalchemy.engine import Engine
|
21
|
-
from sqlalchemy.exc import SQLAlchemyError
|
22
|
+
from sqlalchemy.exc import ResourceClosedError, SQLAlchemyError
|
22
23
|
from sqlalchemy.orm import Session, sessionmaker
|
23
24
|
except ImportError as e:
|
24
25
|
raise LangroidImportError(extra="sql", error=str(e))
|
@@ -49,8 +50,8 @@ logger = logging.getLogger(__name__)
|
|
49
50
|
|
50
51
|
console = Console()
|
51
52
|
|
52
|
-
DEFAULT_SQL_CHAT_SYSTEM_MESSAGE = """
|
53
|
-
{mode}
|
53
|
+
DEFAULT_SQL_CHAT_SYSTEM_MESSAGE = f"""
|
54
|
+
{{mode}}
|
54
55
|
|
55
56
|
You do not need to attempt answering a question with just one query.
|
56
57
|
You could make a sequence of SQL queries to help you write the final query.
|
@@ -64,16 +65,19 @@ are "Male" and "Female".
|
|
64
65
|
|
65
66
|
Start by asking what I would like to know about the data.
|
66
67
|
|
68
|
+
When you have FINISHED the given query or database update task,
|
69
|
+
say {DONE} and show your answer.
|
70
|
+
|
67
71
|
"""
|
68
72
|
|
69
|
-
ADDRESSING_INSTRUCTION = """
|
73
|
+
ADDRESSING_INSTRUCTION = f"""
|
70
74
|
IMPORTANT - Whenever you are NOT writing a SQL query, make sure you address the user
|
71
|
-
using {prefix}User. You MUST use the EXACT syntax {prefix} !!!
|
75
|
+
using {{prefix}}User. You MUST use the EXACT syntax {{prefix}} !!!
|
72
76
|
|
73
77
|
In other words, you ALWAYS write EITHER:
|
74
78
|
- a SQL query using the `run_query` tool,
|
75
|
-
- OR address the user using {prefix}User
|
76
|
-
|
79
|
+
- OR address the user using {{prefix}}User, and include {DONE} to indicate your
|
80
|
+
task is FINISHED.
|
77
81
|
"""
|
78
82
|
|
79
83
|
|
@@ -135,6 +139,9 @@ class SQLChatAgent(ChatAgent):
|
|
135
139
|
Agent for chatting with a SQL database
|
136
140
|
"""
|
137
141
|
|
142
|
+
used_run_query: bool = False
|
143
|
+
llm_responded: bool = False
|
144
|
+
|
138
145
|
def __init__(self, config: "SQLChatAgentConfig") -> None:
|
139
146
|
"""Initialize the SQLChatAgent.
|
140
147
|
|
@@ -246,7 +253,49 @@ class SQLChatAgent(ChatAgent):
|
|
246
253
|
self.enable_message(GetTableSchemaTool)
|
247
254
|
self.enable_message(GetColumnDescriptionsTool)
|
248
255
|
|
249
|
-
def
|
256
|
+
def llm_response(
|
257
|
+
self, message: Optional[str | ChatDocument] = None
|
258
|
+
) -> Optional[ChatDocument]:
|
259
|
+
self.llm_responded = True
|
260
|
+
return super().llm_response(message)
|
261
|
+
|
262
|
+
def user_response(
|
263
|
+
self,
|
264
|
+
msg: Optional[str | ChatDocument] = None,
|
265
|
+
) -> Optional[ChatDocument]:
|
266
|
+
self.llm_responded = False
|
267
|
+
self.used_run_query = False
|
268
|
+
return super().user_response(msg)
|
269
|
+
|
270
|
+
def handle_message_fallback(
|
271
|
+
self, msg: str | ChatDocument
|
272
|
+
) -> str | ChatDocument | None:
|
273
|
+
|
274
|
+
if not self.llm_responded:
|
275
|
+
return None
|
276
|
+
if self.used_run_query:
|
277
|
+
prefix = (
|
278
|
+
self.config.addressing_prefix + "User"
|
279
|
+
if self.config.addressing_prefix
|
280
|
+
else ""
|
281
|
+
)
|
282
|
+
return (
|
283
|
+
DONE + prefix + (msg.content if isinstance(msg, ChatDocument) else msg)
|
284
|
+
)
|
285
|
+
|
286
|
+
else:
|
287
|
+
reminder = """
|
288
|
+
You may have forgotten to use the `run_query` tool to execute an SQL query
|
289
|
+
for the user's question/request
|
290
|
+
"""
|
291
|
+
if self.config.addressing_prefix != "":
|
292
|
+
reminder += f"""
|
293
|
+
OR you may have forgotten to address the user using the prefix
|
294
|
+
{self.config.addressing_prefix}
|
295
|
+
"""
|
296
|
+
return reminder
|
297
|
+
|
298
|
+
def _agent_response(
|
250
299
|
self,
|
251
300
|
msg: Optional[str | ChatDocument] = None,
|
252
301
|
) -> Optional[ChatDocument]:
|
@@ -326,16 +375,23 @@ class SQLChatAgent(ChatAgent):
|
|
326
375
|
"""
|
327
376
|
query = msg.query
|
328
377
|
session = self.Session
|
329
|
-
|
330
|
-
|
378
|
+
self.used_run_query = True
|
331
379
|
try:
|
332
380
|
logger.info(f"Executing SQL query: {query}")
|
333
381
|
|
334
382
|
query_result = session.execute(text(query))
|
335
383
|
session.commit()
|
336
|
-
|
337
|
-
|
338
|
-
|
384
|
+
try:
|
385
|
+
# attempt to fetch results: should work for normal SELECT queries
|
386
|
+
rows = query_result.fetchall()
|
387
|
+
response_message = self._format_rows(rows)
|
388
|
+
except ResourceClosedError:
|
389
|
+
# If we get here, it's a non-SELECT query (UPDATE, INSERT, DELETE)
|
390
|
+
affected_rows = query_result.rowcount # type: ignore
|
391
|
+
response_message = f"""
|
392
|
+
Non-SELECT query executed successfully.
|
393
|
+
Rows affected: {affected_rows}
|
394
|
+
"""
|
339
395
|
|
340
396
|
except SQLAlchemyError as e:
|
341
397
|
session.rollback()
|
langroid/agent/task.py
CHANGED
@@ -923,7 +923,6 @@ class Task:
|
|
923
923
|
# create dummy msg for logging
|
924
924
|
log_doc = ChatDocument(
|
925
925
|
content="[CANNOT RESPOND]",
|
926
|
-
function_call=None,
|
927
926
|
metadata=ChatDocMetaData(
|
928
927
|
sender=r if isinstance(r, Entity) else Entity.USER,
|
929
928
|
sender_name=str(r),
|
@@ -1027,7 +1026,6 @@ class Task:
|
|
1027
1026
|
# create dummy msg for logging
|
1028
1027
|
log_doc = ChatDocument(
|
1029
1028
|
content="[CANNOT RESPOND]",
|
1030
|
-
function_call=None,
|
1031
1029
|
metadata=ChatDocMetaData(
|
1032
1030
|
sender=r if isinstance(r, Entity) else Entity.USER,
|
1033
1031
|
sender_name=str(r),
|
@@ -1111,7 +1109,7 @@ class Task:
|
|
1111
1109
|
self.pending_sender = r
|
1112
1110
|
self.pending_message = result
|
1113
1111
|
# set the parent/child links ONLY if not already set by agent internally,
|
1114
|
-
# which may happen when using the RewindTool
|
1112
|
+
# which may happen when using the RewindTool, or in other scenarios.
|
1115
1113
|
if parent is not None and not result.metadata.parent_id:
|
1116
1114
|
result.metadata.parent_id = parent.id()
|
1117
1115
|
if parent is not None and not parent.metadata.child_id:
|
@@ -1185,8 +1183,24 @@ class Task:
|
|
1185
1183
|
max_cost=self.max_cost,
|
1186
1184
|
max_tokens=self.max_tokens,
|
1187
1185
|
)
|
1186
|
+
if result is not None:
|
1187
|
+
content, id2result, oai_tool_id = self.agent._process_tool_results(
|
1188
|
+
result.content,
|
1189
|
+
result.oai_tool_id2result,
|
1190
|
+
(
|
1191
|
+
self.pending_message.oai_tool_calls
|
1192
|
+
if isinstance(self.pending_message, ChatDocument)
|
1193
|
+
else None
|
1194
|
+
),
|
1195
|
+
)
|
1196
|
+
result.content = content
|
1197
|
+
result.oai_tool_id2result = id2result
|
1198
|
+
result.metadata.oai_tool_id = oai_tool_id
|
1199
|
+
|
1188
1200
|
result_str = ( # only used by callback to display content and possible tool
|
1189
|
-
"NONE"
|
1201
|
+
"NONE"
|
1202
|
+
if result is None
|
1203
|
+
else "\n\n".join(str(m) for m in ChatDocument.to_LLMMessage(result))
|
1190
1204
|
)
|
1191
1205
|
maybe_tool = len(extract_top_level_json(result_str)) > 0
|
1192
1206
|
self.callbacks.show_subtask_response(
|
@@ -1266,7 +1280,11 @@ class Task:
|
|
1266
1280
|
max_cost=self.max_cost,
|
1267
1281
|
max_tokens=self.max_tokens,
|
1268
1282
|
)
|
1269
|
-
result_str =
|
1283
|
+
result_str = ( # only used by callback to display content and possible tool
|
1284
|
+
"NONE"
|
1285
|
+
if result is None
|
1286
|
+
else "\n\n".join(str(m) for m in ChatDocument.to_LLMMessage(result))
|
1287
|
+
)
|
1270
1288
|
maybe_tool = len(extract_top_level_json(result_str)) > 0
|
1271
1289
|
self.callbacks.show_subtask_response(
|
1272
1290
|
task=e,
|
@@ -1301,6 +1319,8 @@ class Task:
|
|
1301
1319
|
if DONE in content:
|
1302
1320
|
# assuming it is of the form "DONE: <content>"
|
1303
1321
|
content = content.replace(DONE, "").strip()
|
1322
|
+
oai_tool_calls = result_msg.oai_tool_calls if result_msg else None
|
1323
|
+
oai_tool_id2result = result_msg.oai_tool_id2result if result_msg else None
|
1304
1324
|
fun_call = result_msg.function_call if result_msg else None
|
1305
1325
|
tool_messages = result_msg.tool_messages if result_msg else []
|
1306
1326
|
block = result_msg.metadata.block if result_msg else None
|
@@ -1312,6 +1332,8 @@ class Task:
|
|
1312
1332
|
# since to the "parent" task, this result is equivalent to a response from USER
|
1313
1333
|
result_doc = ChatDocument(
|
1314
1334
|
content=content,
|
1335
|
+
oai_tool_calls=oai_tool_calls,
|
1336
|
+
oai_tool_id2result=oai_tool_id2result,
|
1315
1337
|
function_call=fun_call,
|
1316
1338
|
tool_messages=tool_messages,
|
1317
1339
|
metadata=ChatDocMetaData(
|
@@ -1346,6 +1368,8 @@ class Task:
|
|
1346
1368
|
isinstance(msg, ChatDocument)
|
1347
1369
|
and msg.content.strip() in [PASS, ""]
|
1348
1370
|
and msg.function_call is None
|
1371
|
+
and msg.oai_tool_calls is None
|
1372
|
+
and msg.oai_tool_id2result is None
|
1349
1373
|
and msg.tool_messages == []
|
1350
1374
|
)
|
1351
1375
|
)
|
@@ -1474,8 +1498,8 @@ class Task:
|
|
1474
1498
|
and (result.content in USER_QUIT_STRINGS or DONE in result.content)
|
1475
1499
|
and result.metadata.sender == Entity.USER
|
1476
1500
|
)
|
1477
|
-
if self._level == 0 and self.
|
1478
|
-
# for top-level task,
|
1501
|
+
if self._level == 0 and self._user_can_respond() and self.only_user_quits_root:
|
1502
|
+
# for top-level task, only user can quit out
|
1479
1503
|
return (user_quit, StatusCode.USER_QUIT if user_quit else StatusCode.OK)
|
1480
1504
|
|
1481
1505
|
if self.is_done:
|
@@ -1630,14 +1654,17 @@ class Task:
|
|
1630
1654
|
and recipient != self.name # case sensitive
|
1631
1655
|
)
|
1632
1656
|
|
1633
|
-
def
|
1634
|
-
|
1657
|
+
def _user_can_respond(self) -> bool:
|
1658
|
+
return self.interactive or (
|
1635
1659
|
# regardless of self.interactive, if a msg is explicitly addressed to
|
1636
1660
|
# user, then wait for user response
|
1637
1661
|
self.pending_message is not None
|
1638
1662
|
and self.pending_message.metadata.recipient == Entity.USER
|
1639
1663
|
)
|
1640
1664
|
|
1665
|
+
def _can_respond(self, e: Responder) -> bool:
|
1666
|
+
user_can_respond = self._user_can_respond()
|
1667
|
+
|
1641
1668
|
if self.pending_sender == e or (e == Entity.USER and not user_can_respond):
|
1642
1669
|
# sender is same as e (an entity cannot respond to its own msg),
|
1643
1670
|
# or user cannot respond
|
langroid/agent/tool_message.py
CHANGED
@@ -39,6 +39,7 @@ class ToolMessage(ABC, BaseModel):
|
|
39
39
|
|
40
40
|
request: str
|
41
41
|
purpose: str
|
42
|
+
id: str = "" # placeholder for OpenAI-API tool_call_id
|
42
43
|
|
43
44
|
class Config:
|
44
45
|
arbitrary_types_allowed = False
|
@@ -46,7 +47,7 @@ class ToolMessage(ABC, BaseModel):
|
|
46
47
|
validate_assignment = True
|
47
48
|
# do not include these fields in the generated schema
|
48
49
|
# since we don't require the LLM to specify them
|
49
|
-
schema_extra = {"exclude": {"purpose"}}
|
50
|
+
schema_extra = {"exclude": {"purpose", "id"}}
|
50
51
|
|
51
52
|
@classmethod
|
52
53
|
def instructions(cls) -> str:
|
@@ -108,13 +109,13 @@ class ToolMessage(ABC, BaseModel):
|
|
108
109
|
return "\n\n".join(examples_jsons)
|
109
110
|
|
110
111
|
def to_json(self) -> str:
|
111
|
-
return self.json(indent=4, exclude=
|
112
|
+
return self.json(indent=4, exclude=self.Config.schema_extra["exclude"])
|
112
113
|
|
113
114
|
def json_example(self) -> str:
|
114
|
-
return self.json(indent=4, exclude=
|
115
|
+
return self.json(indent=4, exclude=self.Config.schema_extra["exclude"])
|
115
116
|
|
116
117
|
def dict_example(self) -> Dict[str, Any]:
|
117
|
-
return self.dict(exclude=
|
118
|
+
return self.dict(exclude=self.Config.schema_extra["exclude"])
|
118
119
|
|
119
120
|
@classmethod
|
120
121
|
def default_value(cls, f: str) -> Any:
|
@@ -218,7 +219,9 @@ class ToolMessage(ABC, BaseModel):
|
|
218
219
|
if "description" not in parameters["properties"][name]:
|
219
220
|
parameters["properties"][name]["description"] = description
|
220
221
|
|
221
|
-
excludes = ["
|
222
|
+
excludes = cls.Config.schema_extra["exclude"]
|
223
|
+
if not request:
|
224
|
+
excludes = excludes.union({"request"})
|
222
225
|
# exclude 'excludes' from parameters["properties"]:
|
223
226
|
parameters["properties"] = {
|
224
227
|
field: details
|
@@ -127,7 +127,7 @@ class RewindTool(ToolMessage):
|
|
127
127
|
result_doc.metadata.msg_idx = idx
|
128
128
|
|
129
129
|
# replace the message at idx with this new message
|
130
|
-
agent.message_history.
|
130
|
+
agent.message_history.extend(ChatDocument.to_LLMMessage(result_doc))
|
131
131
|
|
132
132
|
# set the replaced doc's parent's child to this result_doc
|
133
133
|
if parent is not None:
|