langroid 0.58.2__py3-none-any.whl → 0.59.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +39 -17
- langroid/agent/base.py-e +2216 -0
- langroid/agent/callbacks/chainlit.py +2 -1
- langroid/agent/chat_agent.py +73 -55
- langroid/agent/chat_agent.py-e +2086 -0
- langroid/agent/chat_document.py +7 -7
- langroid/agent/chat_document.py-e +513 -0
- langroid/agent/openai_assistant.py +9 -9
- langroid/agent/openai_assistant.py-e +882 -0
- langroid/agent/special/arangodb/arangodb_agent.py +10 -18
- langroid/agent/special/arangodb/arangodb_agent.py-e +648 -0
- langroid/agent/special/arangodb/tools.py +3 -3
- langroid/agent/special/doc_chat_agent.py +16 -14
- langroid/agent/special/lance_rag/critic_agent.py +2 -2
- langroid/agent/special/lance_rag/query_planner_agent.py +4 -4
- langroid/agent/special/lance_tools.py +6 -5
- langroid/agent/special/lance_tools.py-e +61 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +3 -7
- langroid/agent/special/neo4j/neo4j_chat_agent.py-e +430 -0
- langroid/agent/special/relevance_extractor_agent.py +1 -1
- langroid/agent/special/sql/sql_chat_agent.py +11 -3
- langroid/agent/task.py +9 -87
- langroid/agent/task.py-e +2418 -0
- langroid/agent/tool_message.py +33 -17
- langroid/agent/tool_message.py-e +400 -0
- langroid/agent/tools/file_tools.py +4 -2
- langroid/agent/tools/file_tools.py-e +234 -0
- langroid/agent/tools/mcp/fastmcp_client.py +19 -6
- langroid/agent/tools/mcp/fastmcp_client.py-e +584 -0
- langroid/agent/tools/orchestration.py +22 -17
- langroid/agent/tools/orchestration.py-e +301 -0
- langroid/agent/tools/recipient_tool.py +3 -3
- langroid/agent/tools/task_tool.py +22 -16
- langroid/agent/tools/task_tool.py-e +249 -0
- langroid/agent/xml_tool_message.py +90 -35
- langroid/agent/xml_tool_message.py-e +392 -0
- langroid/cachedb/base.py +1 -1
- langroid/embedding_models/base.py +2 -2
- langroid/embedding_models/models.py +3 -7
- langroid/embedding_models/models.py-e +563 -0
- langroid/exceptions.py +4 -1
- langroid/language_models/azure_openai.py +2 -2
- langroid/language_models/azure_openai.py-e +134 -0
- langroid/language_models/base.py +6 -4
- langroid/language_models/base.py-e +812 -0
- langroid/language_models/client_cache.py +64 -0
- langroid/language_models/config.py +2 -4
- langroid/language_models/config.py-e +18 -0
- langroid/language_models/model_info.py +9 -1
- langroid/language_models/model_info.py-e +483 -0
- langroid/language_models/openai_gpt.py +119 -20
- langroid/language_models/openai_gpt.py-e +2280 -0
- langroid/language_models/provider_params.py +3 -22
- langroid/language_models/provider_params.py-e +153 -0
- langroid/mytypes.py +11 -4
- langroid/mytypes.py-e +132 -0
- langroid/parsing/code_parser.py +1 -1
- langroid/parsing/file_attachment.py +1 -1
- langroid/parsing/file_attachment.py-e +246 -0
- langroid/parsing/md_parser.py +14 -4
- langroid/parsing/md_parser.py-e +574 -0
- langroid/parsing/parser.py +22 -7
- langroid/parsing/parser.py-e +410 -0
- langroid/parsing/repo_loader.py +3 -1
- langroid/parsing/repo_loader.py-e +812 -0
- langroid/parsing/search.py +1 -1
- langroid/parsing/url_loader.py +17 -51
- langroid/parsing/url_loader.py-e +683 -0
- langroid/parsing/urls.py +5 -4
- langroid/parsing/urls.py-e +279 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/pydantic_v1/__init__.py +45 -6
- langroid/pydantic_v1/__init__.py-e +36 -0
- langroid/pydantic_v1/main.py +11 -4
- langroid/pydantic_v1/main.py-e +11 -0
- langroid/utils/configuration.py +13 -11
- langroid/utils/configuration.py-e +141 -0
- langroid/utils/constants.py +1 -1
- langroid/utils/constants.py-e +32 -0
- langroid/utils/globals.py +21 -5
- langroid/utils/globals.py-e +49 -0
- langroid/utils/html_logger.py +2 -1
- langroid/utils/html_logger.py-e +825 -0
- langroid/utils/object_registry.py +1 -1
- langroid/utils/object_registry.py-e +66 -0
- langroid/utils/pydantic_utils.py +55 -28
- langroid/utils/pydantic_utils.py-e +602 -0
- langroid/utils/types.py +2 -2
- langroid/utils/types.py-e +113 -0
- langroid/vector_store/base.py +3 -3
- langroid/vector_store/lancedb.py +5 -5
- langroid/vector_store/lancedb.py-e +404 -0
- langroid/vector_store/meilisearch.py +2 -2
- langroid/vector_store/pineconedb.py +4 -4
- langroid/vector_store/pineconedb.py-e +427 -0
- langroid/vector_store/postgres.py +1 -1
- langroid/vector_store/qdrantdb.py +3 -3
- langroid/vector_store/weaviatedb.py +1 -1
- {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/METADATA +3 -2
- langroid-0.59.0b1.dist-info/RECORD +181 -0
- langroid/agent/special/doc_chat_task.py +0 -0
- langroid/mcp/__init__.py +0 -1
- langroid/mcp/server/__init__.py +0 -1
- langroid-0.58.2.dist-info/RECORD +0 -145
- {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/WHEEL +0 -0
- {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/licenses/LICENSE +0 -0
langroid/agent/chat_document.py
CHANGED
@@ -6,6 +6,8 @@ from collections import OrderedDict
|
|
6
6
|
from enum import Enum
|
7
7
|
from typing import Any, Dict, List, Optional, Union, cast
|
8
8
|
|
9
|
+
from pydantic import BaseModel, ConfigDict
|
10
|
+
|
9
11
|
from langroid.agent.tool_message import ToolMessage
|
10
12
|
from langroid.agent.xml_tool_message import XMLToolMessage
|
11
13
|
from langroid.language_models.base import (
|
@@ -21,7 +23,6 @@ from langroid.mytypes import DocMetaData, Document, Entity
|
|
21
23
|
from langroid.parsing.agent_chats import parse_message
|
22
24
|
from langroid.parsing.file_attachment import FileAttachment
|
23
25
|
from langroid.parsing.parse_json import extract_top_level_json, top_level_json_field
|
24
|
-
from langroid.pydantic_v1 import BaseModel, Extra
|
25
26
|
from langroid.utils.object_registry import ObjectRegistry
|
26
27
|
from langroid.utils.output.printing import shorten_text
|
27
28
|
from langroid.utils.types import to_string
|
@@ -29,8 +30,7 @@ from langroid.utils.types import to_string
|
|
29
30
|
|
30
31
|
class ChatDocAttachment(BaseModel):
|
31
32
|
# any additional data that should be attached to the document
|
32
|
-
class Config:
|
33
|
-
extra = Extra.allow
|
33
|
+
model_config = ConfigDict(extra="allow")
|
34
34
|
|
35
35
|
|
36
36
|
class StatusCode(str, Enum):
|
@@ -89,7 +89,7 @@ class ChatDocLoggerFields(BaseModel):
|
|
89
89
|
|
90
90
|
@classmethod
|
91
91
|
def tsv_header(cls) -> str:
|
92
|
-
field_names = cls().dict().keys()
|
92
|
+
field_names = cls().model_dump().keys()
|
93
93
|
return "\t".join(field_names)
|
94
94
|
|
95
95
|
|
@@ -259,7 +259,7 @@ class ChatDocument(Document):
|
|
259
259
|
def tsv_str(self) -> str:
|
260
260
|
fields = self.log_fields()
|
261
261
|
fields.content = shorten_text(fields.content, 80)
|
262
|
-
field_values = fields.dict().values()
|
262
|
+
field_values = fields.model_dump().values()
|
263
263
|
return "\t".join(str(v) for v in field_values)
|
264
264
|
|
265
265
|
def pop_tool_ids(self) -> None:
|
@@ -510,5 +510,5 @@ class ChatDocument(Document):
|
|
510
510
|
]
|
511
511
|
|
512
512
|
|
513
|
-
LLMMessage.update_forward_refs()
|
514
|
-
ChatDocMetaData.update_forward_refs()
|
513
|
+
LLMMessage.model_rebuild()
|
514
|
+
ChatDocMetaData.model_rebuild()
|
@@ -0,0 +1,513 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import copy
|
4
|
+
import json
|
5
|
+
from collections import OrderedDict
|
6
|
+
from enum import Enum
|
7
|
+
from typing import Any, Dict, List, Optional, Union, cast
|
8
|
+
|
9
|
+
from langroid.agent.tool_message import ToolMessage
|
10
|
+
from langroid.agent.xml_tool_message import XMLToolMessage
|
11
|
+
from langroid.language_models.base import (
|
12
|
+
LLMFunctionCall,
|
13
|
+
LLMMessage,
|
14
|
+
LLMResponse,
|
15
|
+
LLMTokenUsage,
|
16
|
+
OpenAIToolCall,
|
17
|
+
Role,
|
18
|
+
ToolChoiceTypes,
|
19
|
+
)
|
20
|
+
from langroid.mytypes import DocMetaData, Document, Entity
|
21
|
+
from langroid.parsing.agent_chats import parse_message
|
22
|
+
from langroid.parsing.file_attachment import FileAttachment
|
23
|
+
from langroid.parsing.parse_json import extract_top_level_json, top_level_json_field
|
24
|
+
from pydantic import BaseModel, ConfigDict
|
25
|
+
from langroid.utils.object_registry import ObjectRegistry
|
26
|
+
from langroid.utils.output.printing import shorten_text
|
27
|
+
from langroid.utils.types import to_string
|
28
|
+
|
29
|
+
|
30
|
+
class ChatDocAttachment(BaseModel):
|
31
|
+
# any additional data that should be attached to the document
|
32
|
+
model_config = ConfigDict(extra="allow")
|
33
|
+
|
34
|
+
|
35
|
+
class StatusCode(str, Enum):
|
36
|
+
"""Codes meant to be returned by task.run(). Some are not used yet."""
|
37
|
+
|
38
|
+
OK = "OK"
|
39
|
+
ERROR = "ERROR"
|
40
|
+
DONE = "DONE"
|
41
|
+
STALLED = "STALLED"
|
42
|
+
INF_LOOP = "INF_LOOP"
|
43
|
+
KILL = "KILL"
|
44
|
+
FIXED_TURNS = "FIXED_TURNS" # reached intended number of turns
|
45
|
+
MAX_TURNS = "MAX_TURNS" # hit max-turns limit
|
46
|
+
MAX_COST = "MAX_COST"
|
47
|
+
MAX_TOKENS = "MAX_TOKENS"
|
48
|
+
TIMEOUT = "TIMEOUT"
|
49
|
+
NO_ANSWER = "NO_ANSWER"
|
50
|
+
USER_QUIT = "USER_QUIT"
|
51
|
+
|
52
|
+
|
53
|
+
class ChatDocMetaData(DocMetaData):
|
54
|
+
parent_id: str = "" # msg (ChatDocument) to which this is a response
|
55
|
+
child_id: str = "" # ChatDocument that has response to this message
|
56
|
+
agent_id: str = "" # ChatAgent that generated this message
|
57
|
+
msg_idx: int = -1 # index of this message in the agent `message_history`
|
58
|
+
sender: Entity # sender of the message
|
59
|
+
# tool_id corresponding to single tool result in ChatDocument.content
|
60
|
+
oai_tool_id: str | None = None
|
61
|
+
tool_ids: List[str] = [] # stack of tool_ids; used by OpenAIAssistant
|
62
|
+
block: None | Entity = None
|
63
|
+
sender_name: str = ""
|
64
|
+
recipient: str = ""
|
65
|
+
usage: Optional[LLMTokenUsage] = None
|
66
|
+
cached: bool = False
|
67
|
+
displayed: bool = False
|
68
|
+
has_citation: bool = False
|
69
|
+
status: Optional[StatusCode] = None
|
70
|
+
|
71
|
+
@property
|
72
|
+
def parent(self) -> Optional["ChatDocument"]:
|
73
|
+
return ChatDocument.from_id(self.parent_id)
|
74
|
+
|
75
|
+
@property
|
76
|
+
def child(self) -> Optional["ChatDocument"]:
|
77
|
+
return ChatDocument.from_id(self.child_id)
|
78
|
+
|
79
|
+
|
80
|
+
class ChatDocLoggerFields(BaseModel):
|
81
|
+
sender_entity: Entity = Entity.USER
|
82
|
+
sender_name: str = ""
|
83
|
+
recipient: str = ""
|
84
|
+
block: Entity | None = None
|
85
|
+
tool_type: str = ""
|
86
|
+
tool: str = ""
|
87
|
+
content: str = ""
|
88
|
+
|
89
|
+
@classmethod
|
90
|
+
def tsv_header(cls) -> str:
|
91
|
+
field_names = cls().model_dump().keys()
|
92
|
+
return "\t".join(field_names)
|
93
|
+
|
94
|
+
|
95
|
+
class ChatDocument(Document):
|
96
|
+
"""
|
97
|
+
Represents a message in a conversation among agents. All responders of an agent
|
98
|
+
have signature ChatDocument -> ChatDocument (modulo None, str, etc),
|
99
|
+
and so does the Task.run() method.
|
100
|
+
|
101
|
+
Attributes:
|
102
|
+
oai_tool_calls (Optional[List[OpenAIToolCall]]):
|
103
|
+
Tool-calls from an OpenAI-compatible API
|
104
|
+
oai_tool_id2results (Optional[OrderedDict[str, str]]):
|
105
|
+
Results of tool-calls from OpenAI (dict is a map of tool_id -> result)
|
106
|
+
oai_tool_choice: ToolChoiceTypes | Dict[str, str]: Param controlling how the
|
107
|
+
LLM should choose tool-use in its response
|
108
|
+
(auto, none, required, or a specific tool)
|
109
|
+
function_call (Optional[LLMFunctionCall]):
|
110
|
+
Function-call from an OpenAI-compatible API
|
111
|
+
(deprecated by OpenAI, in favor of tool-calls)
|
112
|
+
tool_messages (List[ToolMessage]): Langroid ToolMessages extracted from
|
113
|
+
- `content` field (via JSON parsing),
|
114
|
+
- `oai_tool_calls`, or
|
115
|
+
- `function_call`
|
116
|
+
metadata (ChatDocMetaData): Metadata for the message, e.g. sender, recipient.
|
117
|
+
attachment (None | ChatDocAttachment): Any additional data attached.
|
118
|
+
"""
|
119
|
+
|
120
|
+
reasoning: str = "" # reasoning produced by a reasoning LLM
|
121
|
+
content_any: Any = None # to hold arbitrary data returned by responders
|
122
|
+
files: List[FileAttachment] = [] # list of file attachments
|
123
|
+
oai_tool_calls: Optional[List[OpenAIToolCall]] = None
|
124
|
+
oai_tool_id2result: Optional[OrderedDict[str, str]] = None
|
125
|
+
oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto"
|
126
|
+
function_call: Optional[LLMFunctionCall] = None
|
127
|
+
# tools that are explicitly added by agent response/handler,
|
128
|
+
# or tools recognized in the ChatDocument as handle-able tools
|
129
|
+
tool_messages: List[ToolMessage] = []
|
130
|
+
# all known tools in the msg that are in an agent's llm_tools_known list,
|
131
|
+
# even if non-used/handled
|
132
|
+
all_tool_messages: List[ToolMessage] = []
|
133
|
+
|
134
|
+
metadata: ChatDocMetaData
|
135
|
+
attachment: None | ChatDocAttachment = None
|
136
|
+
|
137
|
+
def __init__(self, **data: Any):
|
138
|
+
super().__init__(**data)
|
139
|
+
ObjectRegistry.register_object(self)
|
140
|
+
|
141
|
+
@property
|
142
|
+
def parent(self) -> Optional["ChatDocument"]:
|
143
|
+
return ChatDocument.from_id(self.metadata.parent_id)
|
144
|
+
|
145
|
+
@property
|
146
|
+
def child(self) -> Optional["ChatDocument"]:
|
147
|
+
return ChatDocument.from_id(self.metadata.child_id)
|
148
|
+
|
149
|
+
@staticmethod
|
150
|
+
def deepcopy(doc: ChatDocument) -> ChatDocument:
|
151
|
+
new_doc = copy.deepcopy(doc)
|
152
|
+
new_doc.metadata.id = ObjectRegistry.new_id()
|
153
|
+
new_doc.metadata.child_id = ""
|
154
|
+
new_doc.metadata.parent_id = ""
|
155
|
+
ObjectRegistry.register_object(new_doc)
|
156
|
+
return new_doc
|
157
|
+
|
158
|
+
@staticmethod
|
159
|
+
def from_id(id: str) -> Optional["ChatDocument"]:
|
160
|
+
return cast(ChatDocument, ObjectRegistry.get(id))
|
161
|
+
|
162
|
+
@staticmethod
|
163
|
+
def delete_id(id: str) -> None:
|
164
|
+
"""Remove ChatDocument with given id from ObjectRegistry,
|
165
|
+
and all its descendants.
|
166
|
+
"""
|
167
|
+
chat_doc = ChatDocument.from_id(id)
|
168
|
+
# first delete all descendants
|
169
|
+
while chat_doc is not None:
|
170
|
+
next_chat_doc = chat_doc.child
|
171
|
+
ObjectRegistry.remove(chat_doc.id())
|
172
|
+
chat_doc = next_chat_doc
|
173
|
+
|
174
|
+
def __str__(self) -> str:
|
175
|
+
fields = self.log_fields()
|
176
|
+
tool_str = ""
|
177
|
+
if fields.tool_type != "":
|
178
|
+
tool_str = f"{fields.tool_type}[{fields.tool}]: "
|
179
|
+
recipient_str = ""
|
180
|
+
if fields.recipient != "":
|
181
|
+
recipient_str = f"=>{fields.recipient}: "
|
182
|
+
return (
|
183
|
+
f"{fields.sender_entity}[{fields.sender_name}] "
|
184
|
+
f"{recipient_str}{tool_str}{fields.content}"
|
185
|
+
)
|
186
|
+
|
187
|
+
def get_tool_names(self) -> List[str]:
|
188
|
+
"""
|
189
|
+
Get names of attempted tool usages (JSON or non-JSON) in the content
|
190
|
+
of the message.
|
191
|
+
Returns:
|
192
|
+
List[str]: list of *attempted* tool names
|
193
|
+
(We say "attempted" since we ONLY look at the `request` component of the
|
194
|
+
tool-call representation, and we're not fully parsing it into the
|
195
|
+
corresponding tool message class)
|
196
|
+
|
197
|
+
"""
|
198
|
+
tool_candidates = XMLToolMessage.find_candidates(self.content)
|
199
|
+
if len(tool_candidates) == 0:
|
200
|
+
tool_candidates = extract_top_level_json(self.content)
|
201
|
+
if len(tool_candidates) == 0:
|
202
|
+
return []
|
203
|
+
tools = [json.loads(tc).get("request") for tc in tool_candidates]
|
204
|
+
else:
|
205
|
+
tool_dicts = [
|
206
|
+
XMLToolMessage.extract_field_values(tc) for tc in tool_candidates
|
207
|
+
]
|
208
|
+
tools = [td.get("request") for td in tool_dicts if td is not None]
|
209
|
+
return [str(tool) for tool in tools if tool is not None]
|
210
|
+
|
211
|
+
def log_fields(self) -> ChatDocLoggerFields:
|
212
|
+
"""
|
213
|
+
Fields for logging in csv/tsv logger
|
214
|
+
Returns:
|
215
|
+
List[str]: list of fields
|
216
|
+
"""
|
217
|
+
tool_type = "" # FUNC or TOOL
|
218
|
+
tool = "" # tool name or function name
|
219
|
+
|
220
|
+
# Skip tool detection for system messages - they contain tool instructions,
|
221
|
+
# not actual tool calls
|
222
|
+
if self.metadata.sender != Entity.SYSTEM:
|
223
|
+
oai_tools = (
|
224
|
+
[]
|
225
|
+
if self.oai_tool_calls is None
|
226
|
+
else [t for t in self.oai_tool_calls if t.function is not None]
|
227
|
+
)
|
228
|
+
if self.function_call is not None:
|
229
|
+
tool_type = "FUNC"
|
230
|
+
tool = self.function_call.name
|
231
|
+
elif len(oai_tools) > 0:
|
232
|
+
tool_type = "OAI_TOOL"
|
233
|
+
tool = ",".join(t.function.name for t in oai_tools) # type: ignore
|
234
|
+
else:
|
235
|
+
try:
|
236
|
+
json_tools = self.get_tool_names()
|
237
|
+
except Exception:
|
238
|
+
json_tools = []
|
239
|
+
if json_tools != []:
|
240
|
+
tool_type = "TOOL"
|
241
|
+
tool = json_tools[0]
|
242
|
+
recipient = self.metadata.recipient
|
243
|
+
content = self.content
|
244
|
+
sender_entity = self.metadata.sender
|
245
|
+
sender_name = self.metadata.sender_name
|
246
|
+
if tool_type == "FUNC":
|
247
|
+
content += str(self.function_call)
|
248
|
+
return ChatDocLoggerFields(
|
249
|
+
sender_entity=sender_entity,
|
250
|
+
sender_name=sender_name,
|
251
|
+
recipient=recipient,
|
252
|
+
block=self.metadata.block,
|
253
|
+
tool_type=tool_type,
|
254
|
+
tool=tool,
|
255
|
+
content=content,
|
256
|
+
)
|
257
|
+
|
258
|
+
def tsv_str(self) -> str:
|
259
|
+
fields = self.log_fields()
|
260
|
+
fields.content = shorten_text(fields.content, 80)
|
261
|
+
field_values = fields.model_dump().values()
|
262
|
+
return "\t".join(str(v) for v in field_values)
|
263
|
+
|
264
|
+
def pop_tool_ids(self) -> None:
|
265
|
+
"""
|
266
|
+
Pop the last tool_id from the stack of tool_ids.
|
267
|
+
"""
|
268
|
+
if len(self.metadata.tool_ids) > 0:
|
269
|
+
self.metadata.tool_ids.pop()
|
270
|
+
|
271
|
+
@staticmethod
|
272
|
+
def _clean_fn_call(fc: LLMFunctionCall | None) -> None:
|
273
|
+
# Sometimes an OpenAI LLM (esp gpt-4o) may generate a function-call
|
274
|
+
# with oddities:
|
275
|
+
# (a) the `name` is set, as well as `arguments.request` is set,
|
276
|
+
# and in langroid we use the `request` value as the `name`.
|
277
|
+
# In this case we override the `name` with the `request` value.
|
278
|
+
# (b) the `name` looks like "functions blah" or just "functions"
|
279
|
+
# In this case we strip the "functions" part.
|
280
|
+
if fc is None:
|
281
|
+
return
|
282
|
+
fc.name = fc.name.replace("functions", "").strip()
|
283
|
+
if fc.arguments is not None:
|
284
|
+
request = fc.arguments.get("request")
|
285
|
+
if request is not None and request != "":
|
286
|
+
fc.name = request
|
287
|
+
fc.arguments.pop("request")
|
288
|
+
|
289
|
+
@staticmethod
|
290
|
+
def from_LLMResponse(
|
291
|
+
response: LLMResponse,
|
292
|
+
displayed: bool = False,
|
293
|
+
) -> "ChatDocument":
|
294
|
+
"""
|
295
|
+
Convert LLMResponse to ChatDocument.
|
296
|
+
Args:
|
297
|
+
response (LLMResponse): LLMResponse to convert.
|
298
|
+
displayed (bool): Whether this response was displayed to the user.
|
299
|
+
Returns:
|
300
|
+
ChatDocument: ChatDocument representation of this LLMResponse.
|
301
|
+
"""
|
302
|
+
recipient, message = response.get_recipient_and_message()
|
303
|
+
message = message.strip()
|
304
|
+
if message in ["''", '""']:
|
305
|
+
message = ""
|
306
|
+
if response.function_call is not None:
|
307
|
+
ChatDocument._clean_fn_call(response.function_call)
|
308
|
+
if response.oai_tool_calls is not None:
|
309
|
+
# there must be at least one if it's not None
|
310
|
+
for oai_tc in response.oai_tool_calls:
|
311
|
+
ChatDocument._clean_fn_call(oai_tc.function)
|
312
|
+
return ChatDocument(
|
313
|
+
content=message,
|
314
|
+
reasoning=response.reasoning,
|
315
|
+
content_any=message,
|
316
|
+
oai_tool_calls=response.oai_tool_calls,
|
317
|
+
function_call=response.function_call,
|
318
|
+
metadata=ChatDocMetaData(
|
319
|
+
source=Entity.LLM,
|
320
|
+
sender=Entity.LLM,
|
321
|
+
usage=response.usage,
|
322
|
+
displayed=displayed,
|
323
|
+
cached=response.cached,
|
324
|
+
recipient=recipient,
|
325
|
+
),
|
326
|
+
)
|
327
|
+
|
328
|
+
@staticmethod
|
329
|
+
def from_str(msg: str) -> "ChatDocument":
|
330
|
+
# first check whether msg is structured as TO <recipient>: <message>
|
331
|
+
recipient, message = parse_message(msg)
|
332
|
+
if recipient == "":
|
333
|
+
# check if any top level json specifies a 'recipient'
|
334
|
+
recipient = top_level_json_field(msg, "recipient")
|
335
|
+
message = msg # retain the whole msg in this case
|
336
|
+
return ChatDocument(
|
337
|
+
content=message,
|
338
|
+
content_any=message,
|
339
|
+
metadata=ChatDocMetaData(
|
340
|
+
source=Entity.USER,
|
341
|
+
sender=Entity.USER,
|
342
|
+
recipient=recipient,
|
343
|
+
),
|
344
|
+
)
|
345
|
+
|
346
|
+
@staticmethod
|
347
|
+
def from_LLMMessage(
|
348
|
+
message: LLMMessage,
|
349
|
+
sender_name: str = "",
|
350
|
+
recipient: str = "",
|
351
|
+
) -> "ChatDocument":
|
352
|
+
"""
|
353
|
+
Convert LLMMessage to ChatDocument.
|
354
|
+
|
355
|
+
Args:
|
356
|
+
message (LLMMessage): LLMMessage to convert.
|
357
|
+
sender_name (str): Name of the sender. Defaults to "".
|
358
|
+
recipient (str): Name of the recipient. Defaults to "".
|
359
|
+
|
360
|
+
Returns:
|
361
|
+
ChatDocument: ChatDocument representation of this LLMMessage.
|
362
|
+
"""
|
363
|
+
# Map LLMMessage Role to ChatDocument Entity
|
364
|
+
role_to_entity = {
|
365
|
+
Role.USER: Entity.USER,
|
366
|
+
Role.SYSTEM: Entity.SYSTEM,
|
367
|
+
Role.ASSISTANT: Entity.LLM,
|
368
|
+
Role.FUNCTION: Entity.LLM,
|
369
|
+
Role.TOOL: Entity.LLM,
|
370
|
+
}
|
371
|
+
|
372
|
+
sender_entity = role_to_entity.get(message.role, Entity.USER)
|
373
|
+
|
374
|
+
return ChatDocument(
|
375
|
+
content=message.content or "",
|
376
|
+
content_any=message.content,
|
377
|
+
files=message.files,
|
378
|
+
function_call=message.function_call,
|
379
|
+
oai_tool_calls=message.tool_calls,
|
380
|
+
metadata=ChatDocMetaData(
|
381
|
+
source=sender_entity,
|
382
|
+
sender=sender_entity,
|
383
|
+
sender_name=sender_name,
|
384
|
+
recipient=recipient,
|
385
|
+
oai_tool_id=message.tool_call_id,
|
386
|
+
tool_ids=[message.tool_id] if message.tool_id else [],
|
387
|
+
),
|
388
|
+
)
|
389
|
+
|
390
|
+
@staticmethod
|
391
|
+
def to_LLMMessage(
|
392
|
+
message: Union[str, "ChatDocument"],
|
393
|
+
oai_tools: Optional[List[OpenAIToolCall]] = None,
|
394
|
+
) -> List[LLMMessage]:
|
395
|
+
"""
|
396
|
+
Convert to list of LLMMessage, to incorporate into msg-history sent to LLM API.
|
397
|
+
Usually there will be just a single LLMMessage, but when the ChatDocument
|
398
|
+
contains results from multiple OpenAI tool-calls, we would have a sequence
|
399
|
+
LLMMessages, one per tool-call result.
|
400
|
+
|
401
|
+
Args:
|
402
|
+
message (str|ChatDocument): Message to convert.
|
403
|
+
oai_tools (Optional[List[OpenAIToolCall]]): Tool-calls currently awaiting
|
404
|
+
response, from the ChatAgent's latest message.
|
405
|
+
Returns:
|
406
|
+
List[LLMMessage]: list of LLMMessages corresponding to this ChatDocument.
|
407
|
+
"""
|
408
|
+
|
409
|
+
sender_role = Role.USER
|
410
|
+
if isinstance(message, str):
|
411
|
+
message = ChatDocument.from_str(message)
|
412
|
+
content = message.content or to_string(message.content_any) or ""
|
413
|
+
fun_call = message.function_call
|
414
|
+
oai_tool_calls = message.oai_tool_calls
|
415
|
+
if message.metadata.sender == Entity.USER and fun_call is not None:
|
416
|
+
# This may happen when a (parent agent's) LLM generates a
|
417
|
+
# a Function-call, and it ends up being sent to the current task's
|
418
|
+
# LLM (possibly because the function-call is mis-named or has other
|
419
|
+
# issues and couldn't be handled by handler methods).
|
420
|
+
# But a function-call can only be generated by an entity with
|
421
|
+
# Role.ASSISTANT, so we instead put the content of the function-call
|
422
|
+
# in the content of the message.
|
423
|
+
content += " " + str(fun_call)
|
424
|
+
fun_call = None
|
425
|
+
if message.metadata.sender == Entity.USER and oai_tool_calls is not None:
|
426
|
+
# same reasoning as for function-call above
|
427
|
+
content += " " + "\n\n".join(str(tc) for tc in oai_tool_calls)
|
428
|
+
oai_tool_calls = None
|
429
|
+
# some LLM APIs (e.g. gemini) don't like empty msg
|
430
|
+
content = content or " "
|
431
|
+
sender_name = message.metadata.sender_name
|
432
|
+
tool_ids = message.metadata.tool_ids
|
433
|
+
tool_id = tool_ids[-1] if len(tool_ids) > 0 else ""
|
434
|
+
chat_document_id = message.id()
|
435
|
+
if message.metadata.sender == Entity.SYSTEM:
|
436
|
+
sender_role = Role.SYSTEM
|
437
|
+
if (
|
438
|
+
message.metadata.parent is not None
|
439
|
+
and message.metadata.parent.function_call is not None
|
440
|
+
):
|
441
|
+
# This is a response to a function call, so set the role to FUNCTION.
|
442
|
+
sender_role = Role.FUNCTION
|
443
|
+
sender_name = message.metadata.parent.function_call.name
|
444
|
+
elif oai_tools is not None and len(oai_tools) > 0:
|
445
|
+
pending_tool_ids = [tc.id for tc in oai_tools]
|
446
|
+
# The ChatAgent has pending OpenAI tool-call(s),
|
447
|
+
# so the current ChatDocument contains
|
448
|
+
# results for some/all/none of them.
|
449
|
+
|
450
|
+
if len(oai_tools) == 1:
|
451
|
+
# Case 1:
|
452
|
+
# There was exactly 1 pending tool-call, and in this case
|
453
|
+
# the result would be a plain string in `content`
|
454
|
+
return [
|
455
|
+
LLMMessage(
|
456
|
+
role=Role.TOOL,
|
457
|
+
tool_call_id=oai_tools[0].id,
|
458
|
+
content=content,
|
459
|
+
files=message.files,
|
460
|
+
chat_document_id=chat_document_id,
|
461
|
+
)
|
462
|
+
]
|
463
|
+
|
464
|
+
elif (
|
465
|
+
message.metadata.oai_tool_id is not None
|
466
|
+
and message.metadata.oai_tool_id in pending_tool_ids
|
467
|
+
):
|
468
|
+
# Case 2:
|
469
|
+
# ChatDocument.content has result of a single tool-call
|
470
|
+
return [
|
471
|
+
LLMMessage(
|
472
|
+
role=Role.TOOL,
|
473
|
+
tool_call_id=message.metadata.oai_tool_id,
|
474
|
+
content=content,
|
475
|
+
files=message.files,
|
476
|
+
chat_document_id=chat_document_id,
|
477
|
+
)
|
478
|
+
]
|
479
|
+
elif message.oai_tool_id2result is not None:
|
480
|
+
# Case 2:
|
481
|
+
# There were > 1 tool-calls awaiting response,
|
482
|
+
assert (
|
483
|
+
len(message.oai_tool_id2result) > 1
|
484
|
+
), "oai_tool_id2result must have more than 1 item."
|
485
|
+
return [
|
486
|
+
LLMMessage(
|
487
|
+
role=Role.TOOL,
|
488
|
+
tool_call_id=tool_id,
|
489
|
+
content=result or " ",
|
490
|
+
files=message.files,
|
491
|
+
chat_document_id=chat_document_id,
|
492
|
+
)
|
493
|
+
for tool_id, result in message.oai_tool_id2result.items()
|
494
|
+
]
|
495
|
+
elif message.metadata.sender == Entity.LLM:
|
496
|
+
sender_role = Role.ASSISTANT
|
497
|
+
|
498
|
+
return [
|
499
|
+
LLMMessage(
|
500
|
+
role=sender_role,
|
501
|
+
tool_id=tool_id, # for OpenAI Assistant
|
502
|
+
content=content,
|
503
|
+
files=message.files,
|
504
|
+
function_call=fun_call,
|
505
|
+
tool_calls=oai_tool_calls,
|
506
|
+
name=sender_name,
|
507
|
+
chat_document_id=chat_document_id,
|
508
|
+
)
|
509
|
+
]
|
510
|
+
|
511
|
+
|
512
|
+
LLMMessage.update_forward_refs()
|
513
|
+
ChatDocMetaData.update_forward_refs()
|
@@ -15,6 +15,7 @@ from openai.types.beta.assistant_update_params import (
|
|
15
15
|
)
|
16
16
|
from openai.types.beta.threads import Message, Run
|
17
17
|
from openai.types.beta.threads.runs import RunStep
|
18
|
+
from pydantic import BaseModel
|
18
19
|
from rich import print
|
19
20
|
|
20
21
|
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
|
@@ -26,7 +27,6 @@ from langroid.language_models.openai_gpt import (
|
|
26
27
|
OpenAIGPT,
|
27
28
|
OpenAIGPTConfig,
|
28
29
|
)
|
29
|
-
from langroid.pydantic_v1 import BaseModel
|
30
30
|
from langroid.utils.configuration import settings
|
31
31
|
from langroid.utils.system import generate_user_id, update_hash
|
32
32
|
|
@@ -44,7 +44,7 @@ class AssistantTool(BaseModel):
|
|
44
44
|
function: Dict[str, Any] | None = None
|
45
45
|
|
46
46
|
def dct(self) -> Dict[str, Any]:
|
47
|
-
d = super().dict()
|
47
|
+
d = super().model_dump()
|
48
48
|
d["type"] = d["type"].value
|
49
49
|
if self.type != ToolType.FUNCTION:
|
50
50
|
d.pop("function")
|
@@ -72,14 +72,14 @@ class RunStatus(str, Enum):
|
|
72
72
|
class OpenAIAssistantConfig(ChatAgentConfig):
|
73
73
|
use_cached_assistant: bool = False # set in script via user dialog
|
74
74
|
assistant_id: str | None = None
|
75
|
-
use_tools = False
|
76
|
-
use_functions_api = True
|
75
|
+
use_tools: bool = False
|
76
|
+
use_functions_api: bool = True
|
77
77
|
use_cached_thread: bool = False # set in script via user dialog
|
78
78
|
thread_id: str | None = None
|
79
79
|
# set to True once we can add Assistant msgs in threads
|
80
80
|
cache_responses: bool = True
|
81
81
|
timeout: int = 30 # can be different from llm.timeout
|
82
|
-
llm = OpenAIGPTConfig(chat_model=OpenAIChatModel.GPT4o)
|
82
|
+
llm: OpenAIGPTConfig = OpenAIGPTConfig(chat_model=OpenAIChatModel.GPT4o)
|
83
83
|
tools: List[AssistantTool] = []
|
84
84
|
files: List[str] = []
|
85
85
|
|
@@ -214,7 +214,7 @@ class OpenAIAssistant(ChatAgent):
|
|
214
214
|
[
|
215
215
|
{
|
216
216
|
"type": "function", # type: ignore
|
217
|
-
"function": f.dict(),
|
217
|
+
"function": f.model_dump(),
|
218
218
|
}
|
219
219
|
for f in functions
|
220
220
|
]
|
@@ -272,7 +272,7 @@ class OpenAIAssistant(ChatAgent):
|
|
272
272
|
cached_dict = self.llm.cache.retrieve(key)
|
273
273
|
if cached_dict is None:
|
274
274
|
return None
|
275
|
-
return LLMResponse.parse_obj(cached_dict)
|
275
|
+
return LLMResponse.model_validate(cached_dict)
|
276
276
|
|
277
277
|
def _cache_store(self) -> None:
|
278
278
|
"""
|
@@ -638,7 +638,7 @@ class OpenAIAssistant(ChatAgent):
|
|
638
638
|
cached=False, # TODO - revisit when able to insert Assistant responses
|
639
639
|
)
|
640
640
|
if self.llm.cache is not None:
|
641
|
-
self.llm.cache.store(key, result.dict())
|
641
|
+
self.llm.cache.store(key, result.model_dump())
|
642
642
|
return result
|
643
643
|
|
644
644
|
def _parse_run_required_action(self) -> List[AssistantToolCall]:
|
@@ -773,7 +773,7 @@ class OpenAIAssistant(ChatAgent):
|
|
773
773
|
# it looks like assistant produced it
|
774
774
|
if self.config.cache_responses:
|
775
775
|
self._add_thread_message(
|
776
|
-
json.dumps(response.dict()), role=Role.ASSISTANT
|
776
|
+
json.dumps(response.model_dump()), role=Role.ASSISTANT
|
777
777
|
)
|
778
778
|
return response # type: ignore
|
779
779
|
else:
|