langroid 0.33.4__py3-none-any.whl → 0.33.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +106 -0
- langroid/agent/__init__.py +41 -0
- langroid/agent/base.py +1983 -0
- langroid/agent/batch.py +398 -0
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +598 -0
- langroid/agent/chat_agent.py +1899 -0
- langroid/agent/chat_document.py +454 -0
- langroid/agent/openai_assistant.py +882 -0
- langroid/agent/special/__init__.py +59 -0
- langroid/agent/special/arangodb/__init__.py +0 -0
- langroid/agent/special/arangodb/arangodb_agent.py +656 -0
- langroid/agent/special/arangodb/system_messages.py +186 -0
- langroid/agent/special/arangodb/tools.py +107 -0
- langroid/agent/special/arangodb/utils.py +36 -0
- langroid/agent/special/doc_chat_agent.py +1466 -0
- langroid/agent/special/lance_doc_chat_agent.py +262 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +198 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +82 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +260 -0
- langroid/agent/special/lance_tools.py +61 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +433 -0
- langroid/agent/special/neo4j/system_messages.py +120 -0
- langroid/agent/special/neo4j/tools.py +32 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +56 -0
- langroid/agent/special/sql/__init__.py +17 -0
- langroid/agent/special/sql/sql_chat_agent.py +654 -0
- langroid/agent/special/sql/utils/__init__.py +21 -0
- langroid/agent/special/sql/utils/description_extractors.py +190 -0
- langroid/agent/special/sql/utils/populate_metadata.py +85 -0
- langroid/agent/special/sql/utils/system_message.py +35 -0
- langroid/agent/special/sql/utils/tools.py +64 -0
- langroid/agent/special/table_chat_agent.py +263 -0
- langroid/agent/task.py +2095 -0
- langroid/agent/tool_message.py +393 -0
- langroid/agent/tools/__init__.py +38 -0
- langroid/agent/tools/duckduckgo_search_tool.py +50 -0
- langroid/agent/tools/file_tools.py +234 -0
- langroid/agent/tools/google_search_tool.py +39 -0
- langroid/agent/tools/metaphor_search_tool.py +68 -0
- langroid/agent/tools/orchestration.py +303 -0
- langroid/agent/tools/recipient_tool.py +235 -0
- langroid/agent/tools/retrieval_tool.py +32 -0
- langroid/agent/tools/rewind_tool.py +137 -0
- langroid/agent/tools/segment_extract_tool.py +41 -0
- langroid/agent/xml_tool_message.py +382 -0
- langroid/cachedb/__init__.py +17 -0
- langroid/cachedb/base.py +58 -0
- langroid/cachedb/momento_cachedb.py +108 -0
- langroid/cachedb/redis_cachedb.py +153 -0
- langroid/embedding_models/__init__.py +39 -0
- langroid/embedding_models/base.py +74 -0
- langroid/embedding_models/models.py +461 -0
- langroid/embedding_models/protoc/__init__.py +0 -0
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/exceptions.py +71 -0
- langroid/language_models/__init__.py +53 -0
- langroid/language_models/azure_openai.py +153 -0
- langroid/language_models/base.py +678 -0
- langroid/language_models/config.py +18 -0
- langroid/language_models/mock_lm.py +124 -0
- langroid/language_models/openai_gpt.py +1964 -0
- langroid/language_models/prompt_formatter/__init__.py +16 -0
- langroid/language_models/prompt_formatter/base.py +40 -0
- langroid/language_models/prompt_formatter/hf_formatter.py +132 -0
- langroid/language_models/prompt_formatter/llama2_formatter.py +75 -0
- langroid/language_models/utils.py +151 -0
- langroid/mytypes.py +84 -0
- langroid/parsing/__init__.py +52 -0
- langroid/parsing/agent_chats.py +38 -0
- langroid/parsing/code_parser.py +121 -0
- langroid/parsing/document_parser.py +718 -0
- langroid/parsing/para_sentence_split.py +62 -0
- langroid/parsing/parse_json.py +155 -0
- langroid/parsing/parser.py +313 -0
- langroid/parsing/repo_loader.py +790 -0
- langroid/parsing/routing.py +36 -0
- langroid/parsing/search.py +275 -0
- langroid/parsing/spider.py +102 -0
- langroid/parsing/table_loader.py +94 -0
- langroid/parsing/url_loader.py +111 -0
- langroid/parsing/urls.py +273 -0
- langroid/parsing/utils.py +373 -0
- langroid/parsing/web_search.py +156 -0
- langroid/prompts/__init__.py +9 -0
- langroid/prompts/dialog.py +17 -0
- langroid/prompts/prompts_config.py +5 -0
- langroid/prompts/templates.py +141 -0
- langroid/pydantic_v1/__init__.py +10 -0
- langroid/pydantic_v1/main.py +4 -0
- langroid/utils/__init__.py +19 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +98 -0
- langroid/utils/constants.py +30 -0
- langroid/utils/git_utils.py +252 -0
- langroid/utils/globals.py +49 -0
- langroid/utils/logging.py +135 -0
- langroid/utils/object_registry.py +66 -0
- langroid/utils/output/__init__.py +20 -0
- langroid/utils/output/citations.py +41 -0
- langroid/utils/output/printing.py +99 -0
- langroid/utils/output/status.py +40 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +602 -0
- langroid/utils/system.py +286 -0
- langroid/utils/types.py +93 -0
- langroid/vector_store/__init__.py +50 -0
- langroid/vector_store/base.py +359 -0
- langroid/vector_store/chromadb.py +214 -0
- langroid/vector_store/lancedb.py +406 -0
- langroid/vector_store/meilisearch.py +299 -0
- langroid/vector_store/momento.py +278 -0
- langroid/vector_store/qdrantdb.py +468 -0
- {langroid-0.33.4.dist-info → langroid-0.33.7.dist-info}/METADATA +95 -94
- langroid-0.33.7.dist-info/RECORD +127 -0
- {langroid-0.33.4.dist-info → langroid-0.33.7.dist-info}/WHEEL +1 -1
- langroid-0.33.4.dist-info/RECORD +0 -7
- langroid-0.33.4.dist-info/entry_points.txt +0 -4
- pyproject.toml +0 -356
- {langroid-0.33.4.dist-info → langroid-0.33.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,454 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import copy
|
4
|
+
import json
|
5
|
+
from collections import OrderedDict
|
6
|
+
from enum import Enum
|
7
|
+
from typing import Any, Dict, List, Optional, Union, cast
|
8
|
+
|
9
|
+
from langroid.agent.tool_message import ToolMessage
|
10
|
+
from langroid.agent.xml_tool_message import XMLToolMessage
|
11
|
+
from langroid.language_models.base import (
|
12
|
+
LLMFunctionCall,
|
13
|
+
LLMMessage,
|
14
|
+
LLMResponse,
|
15
|
+
LLMTokenUsage,
|
16
|
+
OpenAIToolCall,
|
17
|
+
Role,
|
18
|
+
ToolChoiceTypes,
|
19
|
+
)
|
20
|
+
from langroid.mytypes import DocMetaData, Document, Entity
|
21
|
+
from langroid.parsing.agent_chats import parse_message
|
22
|
+
from langroid.parsing.parse_json import extract_top_level_json, top_level_json_field
|
23
|
+
from langroid.pydantic_v1 import BaseModel, Extra
|
24
|
+
from langroid.utils.object_registry import ObjectRegistry
|
25
|
+
from langroid.utils.output.printing import shorten_text
|
26
|
+
from langroid.utils.types import to_string
|
27
|
+
|
28
|
+
|
29
|
+
class ChatDocAttachment(BaseModel):
    # Open-ended container for extra data attached to a ChatDocument.
    # No fields are declared here; the Config below lets callers set
    # arbitrary attributes on instances.
    class Config:
        # accept (rather than reject/ignore) fields that are not declared
        extra = Extra.allow
|
33
|
+
|
34
|
+
|
35
|
+
class StatusCode(str, Enum):
    """Outcome codes intended as return values of task.run().

    Not every code is in use yet. Each member's value equals its name, and
    because the enum mixes in `str`, members compare equal to plain strings.
    """

    OK = "OK"
    ERROR = "ERROR"
    DONE = "DONE"
    STALLED = "STALLED"
    INF_LOOP = "INF_LOOP"
    KILL = "KILL"
    # turn-count related outcomes
    FIXED_TURNS = "FIXED_TURNS"  # the intended number of turns was reached
    MAX_TURNS = "MAX_TURNS"  # the max-turns limit was hit
    # budget / time related outcomes
    MAX_COST = "MAX_COST"
    MAX_TOKENS = "MAX_TOKENS"
    TIMEOUT = "TIMEOUT"
    NO_ANSWER = "NO_ANSWER"
    USER_QUIT = "USER_QUIT"
|
51
|
+
|
52
|
+
|
53
|
+
class ChatDocMetaData(DocMetaData):
    # --- links to related objects (ids resolved via ChatDocument.from_id) ---
    parent_id: str = ""  # id of the msg (ChatDocument) this one responds to
    child_id: str = ""  # id of the ChatDocument holding the response to this msg
    agent_id: str = ""  # id of the ChatAgent that generated this message
    msg_idx: int = -1  # index of this message in the agent `message_history`
    sender: Entity  # sender of the message (required, no default)
    # tool_id of the single tool result carried in ChatDocument.content
    oai_tool_id: Optional[str] = None
    tool_ids: List[str] = []  # stack of tool_ids; used by OpenAIAssistant
    block: Optional[Entity] = None
    sender_name: str = ""
    recipient: str = ""
    usage: Optional[LLMTokenUsage] = None
    cached: bool = False
    displayed: bool = False
    has_citation: bool = False
    status: Optional[StatusCode] = None

    @property
    def parent(self) -> Optional["ChatDocument"]:
        """Look up `parent_id` with `ChatDocument.from_id`."""
        parent_doc = ChatDocument.from_id(self.parent_id)
        return parent_doc

    @property
    def child(self) -> Optional["ChatDocument"]:
        """Look up `child_id` with `ChatDocument.from_id`."""
        child_doc = ChatDocument.from_id(self.child_id)
        return child_doc
|
78
|
+
|
79
|
+
|
80
|
+
class ChatDocLoggerFields(BaseModel):
    # One row of the tab-separated chat log; field order is the column order.
    sender_entity: Entity = Entity.USER
    sender_name: str = ""
    recipient: str = ""
    block: Optional[Entity] = None
    tool_type: str = ""
    tool: str = ""
    content: str = ""

    @classmethod
    def tsv_header(cls) -> str:
        """Return the field names of this model joined by tab characters."""
        # Build an all-defaults instance; iterating its dict yields the
        # field names.
        default_row = cls().dict()
        return "\t".join(default_row)
|
93
|
+
|
94
|
+
|
95
|
+
class ChatDocument(Document):
    """
    Represents a message in a conversation among agents. All responders of an agent
    have signature ChatDocument -> ChatDocument (modulo None, str, etc),
    and so does the Task.run() method.

    Attributes:
        content_any (Any): Arbitrary data returned by responders.
        oai_tool_calls (Optional[List[OpenAIToolCall]]):
            Tool-calls from an OpenAI-compatible API
        oai_tool_id2result (Optional[OrderedDict[str, str]]):
            Results of tool-calls from OpenAI (dict is a map of tool_id -> result)
        oai_tool_choice (ToolChoiceTypes | Dict[str, Dict[str, str] | str]):
            Param controlling how the LLM should choose tool-use in its response
            (auto, none, required, or a specific tool)
        function_call (Optional[LLMFunctionCall]):
            Function-call from an OpenAI-compatible API
            (deprecated by OpenAI, in favor of tool-calls)
        tool_messages (List[ToolMessage]): Langroid ToolMessages extracted from
            - `content` field (via JSON parsing),
            - `oai_tool_calls`, or
            - `function_call`
        all_tool_messages (List[ToolMessage]): All known tools in the msg that
            are in an agent's llm_tools_known list, even if not used/handled.
        metadata (ChatDocMetaData): Metadata for the message, e.g. sender, recipient.
        attachment (None | ChatDocAttachment): Any additional data attached.
    """

    content_any: Any = None  # to hold arbitrary data returned by responders
    oai_tool_calls: Optional[List[OpenAIToolCall]] = None
    oai_tool_id2result: Optional[OrderedDict[str, str]] = None
    oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto"
    function_call: Optional[LLMFunctionCall] = None
    # tools that are explicitly added by agent response/handler,
    # or tools recognized in the ChatDocument as handle-able tools
    tool_messages: List[ToolMessage] = []
    # all known tools in the msg that are in an agent's llm_tools_known list,
    # even if non-used/handled
    all_tool_messages: List[ToolMessage] = []

    metadata: ChatDocMetaData
    attachment: None | ChatDocAttachment = None

    def __init__(self, **data: Any):
        """Build the document from field values and register it in ObjectRegistry."""
        super().__init__(**data)
        ObjectRegistry.register_object(self)

    @property
    def parent(self) -> Optional["ChatDocument"]:
        """Look up `metadata.parent_id` with `ChatDocument.from_id`."""
        return ChatDocument.from_id(self.metadata.parent_id)

    @property
    def child(self) -> Optional["ChatDocument"]:
        """Look up `metadata.child_id` with `ChatDocument.from_id`."""
        return ChatDocument.from_id(self.metadata.child_id)

    @staticmethod
    def deepcopy(doc: ChatDocument) -> ChatDocument:
        """
        Deep-copy a ChatDocument as a new, unlinked document.

        The copy gets a fresh id from `ObjectRegistry.new_id()`, its parent/child
        links are cleared, and it is registered in the ObjectRegistry.

        Args:
            doc (ChatDocument): document to copy.
        Returns:
            ChatDocument: the registered copy.
        """
        new_doc = copy.deepcopy(doc)
        new_doc.metadata.id = ObjectRegistry.new_id()
        new_doc.metadata.child_id = ""
        new_doc.metadata.parent_id = ""
        # register explicitly rather than relying on __init__ being invoked
        # for the copy
        ObjectRegistry.register_object(new_doc)
        return new_doc

    @staticmethod
    def from_id(id: str) -> Optional["ChatDocument"]:
        """
        Fetch the object stored in the ObjectRegistry under `id`.

        Args:
            id (str): registry id to look up.
        Returns:
            Optional[ChatDocument]: whatever `ObjectRegistry.get(id)` returns;
                callers must handle a None result.
        """
        # cast to the Optional type so type-checkers do not assume non-None
        return cast("Optional[ChatDocument]", ObjectRegistry.get(id))

    @staticmethod
    def delete_id(id: str) -> None:
        """Remove ChatDocument with given id from ObjectRegistry,
        and all its descendants.
        """
        chat_doc = ChatDocument.from_id(id)
        # walk down the chain of `child` links, removing each document
        while chat_doc is not None:
            # fetch the child BEFORE removing the current doc from the registry
            next_chat_doc = chat_doc.child
            ObjectRegistry.remove(chat_doc.id())
            chat_doc = next_chat_doc

    def __str__(self) -> str:
        """Render as `<sender>[<name>] =><recipient>: <TYPE>[<tool>]: <content>`,
        omitting the recipient and tool parts when they are empty."""
        fields = self.log_fields()
        tool_str = ""
        if fields.tool_type != "":
            tool_str = f"{fields.tool_type}[{fields.tool}]: "
        recipient_str = ""
        if fields.recipient != "":
            recipient_str = f"=>{fields.recipient}: "
        return (
            f"{fields.sender_entity}[{fields.sender_name}] "
            f"{recipient_str}{tool_str}{fields.content}"
        )

    def get_tool_names(self) -> List[str]:
        """
        Get names of attempted tool usages (XML or JSON) in the content
            of the message.
        Returns:
            List[str]: list of *attempted* tool names
            (We say "attempted" since we ONLY look at the `request` component of the
            tool-call representation, and we're not fully parsing it into the
            corresponding tool message class)

        """
        xml_candidates = XMLToolMessage.find_candidates(self.content)
        requests: List[Any] = []
        if xml_candidates:
            # XML-formatted tools take precedence over JSON-formatted ones
            for candidate in xml_candidates:
                field_values = XMLToolMessage.extract_field_values(candidate)
                if field_values is not None:
                    requests.append(field_values.get("request"))
        else:
            for candidate in extract_top_level_json(self.content):
                try:
                    parsed = json.loads(candidate)
                except json.JSONDecodeError:
                    # an unparseable candidate carries no tool name; skip it
                    # rather than losing the names of the other candidates
                    continue
                # only JSON objects can carry a `request` field
                if isinstance(parsed, dict):
                    requests.append(parsed.get("request"))
        return [str(req) for req in requests if req is not None]

    def log_fields(self) -> ChatDocLoggerFields:
        """
        Fields for logging in csv/tsv logger
        Returns:
            ChatDocLoggerFields: the values to log for this message
        """
        tool_type = ""  # FUNC or TOOL
        tool = ""  # tool name or function name
        content = self.content
        if self.function_call is not None:
            tool_type = "FUNC"
            tool = self.function_call.name
            # show the function-call alongside the text content in the log
            content += str(self.function_call)
        else:
            try:
                tool_names = self.get_tool_names()
            except Exception:
                # best-effort: logging must not fail because tool detection did
                tool_names = []
            if tool_names:
                tool_type = "TOOL"
                tool = tool_names[0]  # only the first attempted tool is logged
        return ChatDocLoggerFields(
            sender_entity=self.metadata.sender,
            sender_name=self.metadata.sender_name,
            recipient=self.metadata.recipient,
            block=self.metadata.block,
            tool_type=tool_type,
            tool=tool,
            content=content,
        )

    def tsv_str(self) -> str:
        """Return the log fields as one tab-separated row, with the content
        passed through `shorten_text(..., 80)`."""
        fields = self.log_fields()
        fields.content = shorten_text(fields.content, 80)
        return "\t".join(str(v) for v in fields.dict().values())

    def pop_tool_ids(self) -> None:
        """
        Pop the last tool_id from the stack of tool_ids.
        Does nothing when the stack is empty.
        """
        if self.metadata.tool_ids:
            self.metadata.tool_ids.pop()

    @staticmethod
    def _clean_fn_call(fc: LLMFunctionCall | None) -> None:
        """
        Normalize, in place, oddities in an LLM-generated function-call.

        Sometimes an OpenAI LLM (esp gpt-4o) may generate a function-call
        with oddities:
        (a) the `name` is set, as well as `arguments.request` is set,
            and in langroid we use the `request` value as the `name`.
            In this case we override the `name` with the `request` value.
        (b) the `name` looks like "functions blah" or just "functions"
            In this case we strip the "functions" part.

        Args:
            fc (LLMFunctionCall | None): function-call to clean; None is a no-op.
        """
        if fc is None:
            return
        # (b) Drop only standalone, whitespace-delimited "functions" tokens.
        # A plain substring replace would also mangle legitimate names that
        # merely contain the word (e.g. "list_functions" -> "list_").
        fc.name = " ".join(tok for tok in fc.name.split() if tok != "functions")
        if fc.arguments is not None:
            # (a) a non-empty `request` argument is the real tool name
            request = fc.arguments.get("request")
            if request is not None and request != "":
                fc.name = request
                fc.arguments.pop("request")

    @staticmethod
    def from_LLMResponse(
        response: LLMResponse,
        displayed: bool = False,
    ) -> "ChatDocument":
        """
        Convert LLMResponse to ChatDocument.
        Args:
            response (LLMResponse): LLMResponse to convert.
            displayed (bool): Whether this response was displayed to the user.
        Returns:
            ChatDocument: ChatDocument representation of this LLMResponse.
        """
        recipient, message = response.get_recipient_and_message()
        message = message.strip()
        if message in ["''", '""']:
            # a message consisting only of a pair of quotes is treated as empty
            message = ""
        if response.function_call is not None:
            ChatDocument._clean_fn_call(response.function_call)
        if response.oai_tool_calls is not None:
            # there must be at least one if it's not None
            for oai_tc in response.oai_tool_calls:
                ChatDocument._clean_fn_call(oai_tc.function)
        return ChatDocument(
            content=message,
            content_any=message,
            oai_tool_calls=response.oai_tool_calls,
            function_call=response.function_call,
            metadata=ChatDocMetaData(
                source=Entity.LLM,
                sender=Entity.LLM,
                usage=response.usage,
                displayed=displayed,
                cached=response.cached,
                recipient=recipient,
            ),
        )

    @staticmethod
    def from_str(msg: str) -> "ChatDocument":
        """
        Build a USER-sent ChatDocument from a plain string.

        Args:
            msg (str): message text, optionally addressed to a recipient.
        Returns:
            ChatDocument: document with sender and source set to Entity.USER.
        """
        # first check whether msg is structured as TO <recipient>: <message>
        recipient, message = parse_message(msg)
        if recipient == "":
            # check if any top level json specifies a 'recipient'
            recipient = top_level_json_field(msg, "recipient")
            message = msg  # retain the whole msg in this case
        return ChatDocument(
            content=message,
            content_any=message,
            metadata=ChatDocMetaData(
                source=Entity.USER,
                sender=Entity.USER,
                recipient=recipient,
            ),
        )

    @staticmethod
    def to_LLMMessage(
        message: Union[str, "ChatDocument"],
        oai_tools: Optional[List[OpenAIToolCall]] = None,
    ) -> List[LLMMessage]:
        """
        Convert to list of LLMMessage, to incorporate into msg-history sent to LLM API.
        Usually there will be just a single LLMMessage, but when the ChatDocument
        contains results from multiple OpenAI tool-calls, we would have a sequence
        of LLMMessages, one per tool-call result.

        Args:
            message (str|ChatDocument): Message to convert.
            oai_tools (Optional[List[OpenAIToolCall]]): Tool-calls currently awaiting
                response, from the ChatAgent's latest message.
        Returns:
            List[LLMMessage]: list of LLMMessages corresponding to this ChatDocument.
        """
        sender_name = None
        sender_role = Role.USER
        fun_call = None
        oai_tool_calls = None
        tool_id = ""  # for OpenAI Assistant
        chat_document_id: str = ""
        if isinstance(message, ChatDocument):
            content = message.content or to_string(message.content_any) or ""
            fun_call = message.function_call
            oai_tool_calls = message.oai_tool_calls
            if message.metadata.sender == Entity.USER and fun_call is not None:
                # This may happen when a (parent agent's) LLM generates
                # a Function-call, and it ends up being sent to the current task's
                # LLM (possibly because the function-call is mis-named or has other
                # issues and couldn't be handled by handler methods).
                # But a function-call can only be generated by an entity with
                # Role.ASSISTANT, so we instead put the content of the function-call
                # in the content of the message.
                content += " " + str(fun_call)
                fun_call = None
            if message.metadata.sender == Entity.USER and oai_tool_calls is not None:
                # same reasoning as for function-call above
                content += " " + "\n\n".join(str(tc) for tc in oai_tool_calls)
                oai_tool_calls = None
            sender_name = message.metadata.sender_name
            tool_ids = message.metadata.tool_ids
            tool_id = tool_ids[-1] if len(tool_ids) > 0 else ""
            chat_document_id = message.id()
            if message.metadata.sender == Entity.SYSTEM:
                sender_role = Role.SYSTEM
            # look the parent up once; each access is a registry lookup
            parent = message.metadata.parent
            if parent is not None and parent.function_call is not None:
                # This is a response to a function call, so set the role to FUNCTION.
                sender_role = Role.FUNCTION
                sender_name = parent.function_call.name
            elif oai_tools is not None and len(oai_tools) > 0:
                pending_tool_ids = [tc.id for tc in oai_tools]
                # The ChatAgent has pending OpenAI tool-call(s),
                # so the current ChatDocument contains
                # results for some/all/none of them.

                if len(oai_tools) == 1:
                    # Case 1:
                    # There was exactly 1 pending tool-call, and in this case
                    # the result would be a plain string in `content`
                    return [
                        LLMMessage(
                            role=Role.TOOL,
                            tool_call_id=oai_tools[0].id,
                            content=content,
                            chat_document_id=chat_document_id,
                        )
                    ]

                elif (
                    message.metadata.oai_tool_id is not None
                    and message.metadata.oai_tool_id in pending_tool_ids
                ):
                    # Case 2:
                    # ChatDocument.content has result of a single tool-call
                    return [
                        LLMMessage(
                            role=Role.TOOL,
                            tool_call_id=message.metadata.oai_tool_id,
                            content=content,
                            chat_document_id=chat_document_id,
                        )
                    ]
                elif message.oai_tool_id2result is not None:
                    # Case 3:
                    # There were > 1 tool-calls awaiting response,
                    # and the results are given as a tool_id -> result map
                    assert (
                        len(message.oai_tool_id2result) > 1
                    ), "oai_tool_id2result must have more than 1 item."
                    return [
                        LLMMessage(
                            role=Role.TOOL,
                            tool_call_id=result_tool_id,
                            content=tool_result,
                            chat_document_id=chat_document_id,
                        )
                        for result_tool_id, tool_result in (
                            message.oai_tool_id2result.items()
                        )
                    ]
                # if none of the cases above matched, fall through to the
                # generic single-message return at the end
            elif message.metadata.sender == Entity.LLM:
                sender_role = Role.ASSISTANT
        else:
            # LLM can only respond to text content, so extract it
            content = message

        return [
            LLMMessage(
                role=sender_role,
                tool_id=tool_id,  # for OpenAI Assistant
                content=content,
                function_call=fun_call,
                tool_calls=oai_tool_calls,
                name=sender_name,
                chat_document_id=chat_document_id,
            )
        ]
|
451
|
+
|
452
|
+
|
453
|
+
# Module-level side effect, run once at import after all classes above exist:
# ask each pydantic model to resolve its forward-referenced (string) field
# annotations. NOTE(review): presumably required because this module uses
# `from __future__ import annotations` -- confirm against pydantic v1 docs.
LLMMessage.update_forward_refs()
|
454
|
+
ChatDocMetaData.update_forward_refs()
|