langroid 0.31.2__py3-none-any.whl → 0.33.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langroid-0.31.2.dist-info → langroid-0.33.3.dist-info}/METADATA +150 -124
- langroid-0.33.3.dist-info/RECORD +7 -0
- {langroid-0.31.2.dist-info → langroid-0.33.3.dist-info}/WHEEL +1 -1
- langroid-0.33.3.dist-info/entry_points.txt +4 -0
- pyproject.toml +317 -212
- langroid/__init__.py +0 -106
- langroid/agent/.chainlit/config.toml +0 -121
- langroid/agent/.chainlit/translations/bn.json +0 -231
- langroid/agent/.chainlit/translations/en-US.json +0 -229
- langroid/agent/.chainlit/translations/gu.json +0 -231
- langroid/agent/.chainlit/translations/he-IL.json +0 -231
- langroid/agent/.chainlit/translations/hi.json +0 -231
- langroid/agent/.chainlit/translations/kn.json +0 -231
- langroid/agent/.chainlit/translations/ml.json +0 -231
- langroid/agent/.chainlit/translations/mr.json +0 -231
- langroid/agent/.chainlit/translations/ta.json +0 -231
- langroid/agent/.chainlit/translations/te.json +0 -231
- langroid/agent/.chainlit/translations/zh-CN.json +0 -229
- langroid/agent/__init__.py +0 -41
- langroid/agent/base.py +0 -1981
- langroid/agent/batch.py +0 -398
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +0 -598
- langroid/agent/chat_agent.py +0 -1899
- langroid/agent/chat_document.py +0 -454
- langroid/agent/helpers.py +0 -0
- langroid/agent/junk +0 -13
- langroid/agent/openai_assistant.py +0 -882
- langroid/agent/special/__init__.py +0 -59
- langroid/agent/special/arangodb/__init__.py +0 -0
- langroid/agent/special/arangodb/arangodb_agent.py +0 -656
- langroid/agent/special/arangodb/system_messages.py +0 -186
- langroid/agent/special/arangodb/tools.py +0 -107
- langroid/agent/special/arangodb/utils.py +0 -36
- langroid/agent/special/doc_chat_agent.py +0 -1466
- langroid/agent/special/lance_doc_chat_agent.py +0 -262
- langroid/agent/special/lance_rag/__init__.py +0 -9
- langroid/agent/special/lance_rag/critic_agent.py +0 -198
- langroid/agent/special/lance_rag/lance_rag_task.py +0 -82
- langroid/agent/special/lance_rag/query_planner_agent.py +0 -260
- langroid/agent/special/lance_tools.py +0 -61
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +0 -174
- langroid/agent/special/neo4j/neo4j_chat_agent.py +0 -433
- langroid/agent/special/neo4j/system_messages.py +0 -120
- langroid/agent/special/neo4j/tools.py +0 -32
- langroid/agent/special/relevance_extractor_agent.py +0 -127
- langroid/agent/special/retriever_agent.py +0 -56
- langroid/agent/special/sql/__init__.py +0 -17
- langroid/agent/special/sql/sql_chat_agent.py +0 -654
- langroid/agent/special/sql/utils/__init__.py +0 -21
- langroid/agent/special/sql/utils/description_extractors.py +0 -190
- langroid/agent/special/sql/utils/populate_metadata.py +0 -85
- langroid/agent/special/sql/utils/system_message.py +0 -35
- langroid/agent/special/sql/utils/tools.py +0 -64
- langroid/agent/special/table_chat_agent.py +0 -263
- langroid/agent/structured_message.py +0 -9
- langroid/agent/task.py +0 -2093
- langroid/agent/tool_message.py +0 -393
- langroid/agent/tools/__init__.py +0 -38
- langroid/agent/tools/duckduckgo_search_tool.py +0 -50
- langroid/agent/tools/file_tools.py +0 -234
- langroid/agent/tools/google_search_tool.py +0 -39
- langroid/agent/tools/metaphor_search_tool.py +0 -67
- langroid/agent/tools/orchestration.py +0 -303
- langroid/agent/tools/recipient_tool.py +0 -235
- langroid/agent/tools/retrieval_tool.py +0 -32
- langroid/agent/tools/rewind_tool.py +0 -137
- langroid/agent/tools/segment_extract_tool.py +0 -41
- langroid/agent/typed_task.py +0 -19
- langroid/agent/xml_tool_message.py +0 -382
- langroid/agent_config.py +0 -0
- langroid/cachedb/__init__.py +0 -17
- langroid/cachedb/base.py +0 -58
- langroid/cachedb/momento_cachedb.py +0 -108
- langroid/cachedb/redis_cachedb.py +0 -153
- langroid/embedding_models/__init__.py +0 -39
- langroid/embedding_models/base.py +0 -74
- langroid/embedding_models/clustering.py +0 -189
- langroid/embedding_models/models.py +0 -461
- langroid/embedding_models/protoc/__init__.py +0 -0
- langroid/embedding_models/protoc/embeddings.proto +0 -19
- langroid/embedding_models/protoc/embeddings_pb2.py +0 -33
- langroid/embedding_models/protoc/embeddings_pb2.pyi +0 -50
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +0 -79
- langroid/embedding_models/remote_embeds.py +0 -153
- langroid/exceptions.py +0 -65
- langroid/experimental/team-save.py +0 -391
- langroid/language_models/.chainlit/config.toml +0 -121
- langroid/language_models/.chainlit/translations/en-US.json +0 -231
- langroid/language_models/__init__.py +0 -53
- langroid/language_models/azure_openai.py +0 -153
- langroid/language_models/base.py +0 -678
- langroid/language_models/config.py +0 -18
- langroid/language_models/mock_lm.py +0 -124
- langroid/language_models/openai_gpt.py +0 -1923
- langroid/language_models/prompt_formatter/__init__.py +0 -16
- langroid/language_models/prompt_formatter/base.py +0 -40
- langroid/language_models/prompt_formatter/hf_formatter.py +0 -132
- langroid/language_models/prompt_formatter/llama2_formatter.py +0 -75
- langroid/language_models/utils.py +0 -147
- langroid/mytypes.py +0 -84
- langroid/parsing/__init__.py +0 -52
- langroid/parsing/agent_chats.py +0 -38
- langroid/parsing/code-parsing.md +0 -86
- langroid/parsing/code_parser.py +0 -121
- langroid/parsing/config.py +0 -0
- langroid/parsing/document_parser.py +0 -718
- langroid/parsing/image_text.py +0 -32
- langroid/parsing/para_sentence_split.py +0 -62
- langroid/parsing/parse_json.py +0 -155
- langroid/parsing/parser.py +0 -313
- langroid/parsing/repo_loader.py +0 -790
- langroid/parsing/routing.py +0 -36
- langroid/parsing/search.py +0 -275
- langroid/parsing/spider.py +0 -102
- langroid/parsing/table_loader.py +0 -94
- langroid/parsing/url_loader.py +0 -111
- langroid/parsing/url_loader_cookies.py +0 -73
- langroid/parsing/urls.py +0 -273
- langroid/parsing/utils.py +0 -373
- langroid/parsing/web_search.py +0 -155
- langroid/prompts/__init__.py +0 -9
- langroid/prompts/chat-gpt4-system-prompt.md +0 -68
- langroid/prompts/dialog.py +0 -17
- langroid/prompts/prompts_config.py +0 -5
- langroid/prompts/templates.py +0 -141
- langroid/pydantic_v1/__init__.py +0 -10
- langroid/pydantic_v1/main.py +0 -4
- langroid/utils/.chainlit/config.toml +0 -121
- langroid/utils/.chainlit/translations/en-US.json +0 -231
- langroid/utils/__init__.py +0 -19
- langroid/utils/algorithms/__init__.py +0 -3
- langroid/utils/algorithms/graph.py +0 -103
- langroid/utils/configuration.py +0 -98
- langroid/utils/constants.py +0 -30
- langroid/utils/docker.py +0 -37
- langroid/utils/git_utils.py +0 -252
- langroid/utils/globals.py +0 -49
- langroid/utils/llms/__init__.py +0 -0
- langroid/utils/llms/strings.py +0 -8
- langroid/utils/logging.py +0 -135
- langroid/utils/object_registry.py +0 -66
- langroid/utils/output/__init__.py +0 -20
- langroid/utils/output/citations.py +0 -41
- langroid/utils/output/printing.py +0 -99
- langroid/utils/output/status.py +0 -40
- langroid/utils/pandas_utils.py +0 -30
- langroid/utils/pydantic_utils.py +0 -602
- langroid/utils/system.py +0 -286
- langroid/utils/types.py +0 -93
- langroid/utils/web/__init__.py +0 -0
- langroid/utils/web/login.py +0 -83
- langroid/vector_store/__init__.py +0 -50
- langroid/vector_store/base.py +0 -357
- langroid/vector_store/chromadb.py +0 -214
- langroid/vector_store/lancedb.py +0 -401
- langroid/vector_store/meilisearch.py +0 -299
- langroid/vector_store/momento.py +0 -278
- langroid/vector_store/qdrant_cloud.py +0 -6
- langroid/vector_store/qdrantdb.py +0 -468
- langroid-0.31.2.dist-info/RECORD +0 -162
- {langroid-0.31.2.dist-info → langroid-0.33.3.dist-info/licenses}/LICENSE +0 -0
@@ -1,656 +0,0 @@
|
|
1
|
-
import datetime
|
2
|
-
import json
|
3
|
-
import logging
|
4
|
-
import time
|
5
|
-
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
|
6
|
-
|
7
|
-
from arango.client import ArangoClient
|
8
|
-
from arango.database import StandardDatabase
|
9
|
-
from arango.exceptions import ArangoError, ServerConnectionError
|
10
|
-
from numpy import ceil
|
11
|
-
from rich import print
|
12
|
-
from rich.console import Console
|
13
|
-
|
14
|
-
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
|
15
|
-
from langroid.agent.chat_document import ChatDocument
|
16
|
-
from langroid.agent.special.arangodb.system_messages import (
|
17
|
-
ADDRESSING_INSTRUCTION,
|
18
|
-
DEFAULT_ARANGO_CHAT_SYSTEM_MESSAGE,
|
19
|
-
DONE_INSTRUCTION,
|
20
|
-
SCHEMA_PROVIDED_SYS_MSG,
|
21
|
-
SCHEMA_TOOLS_SYS_MSG,
|
22
|
-
)
|
23
|
-
from langroid.agent.special.arangodb.tools import (
|
24
|
-
AQLCreationTool,
|
25
|
-
AQLRetrievalTool,
|
26
|
-
ArangoSchemaTool,
|
27
|
-
aql_retrieval_tool_name,
|
28
|
-
arango_schema_tool_name,
|
29
|
-
)
|
30
|
-
from langroid.agent.special.arangodb.utils import count_fields, trim_schema
|
31
|
-
from langroid.agent.tools.orchestration import DoneTool, ForwardTool
|
32
|
-
from langroid.exceptions import LangroidImportError
|
33
|
-
from langroid.mytypes import Entity
|
34
|
-
from langroid.pydantic_v1 import BaseModel, BaseSettings
|
35
|
-
from langroid.utils.constants import SEND_TO
|
36
|
-
|
37
|
-
logger = logging.getLogger(__name__)
|
38
|
-
console = Console()
|
39
|
-
|
40
|
-
ARANGO_ERROR_MSG = "There was an error in your AQL Query"
|
41
|
-
T = TypeVar("T")
|
42
|
-
|
43
|
-
|
44
|
-
class ArangoSettings(BaseSettings):
    """Connection settings for an ArangoDB instance.

    Callers may pass ready-made ``client`` / ``db`` objects, or the
    ``url`` / ``username`` / ``password`` / ``database`` strings from which a
    connection is built. The settings use the ``ARANGO_`` env prefix.
    """

    # Pre-built connection objects (optional).
    client: Optional[ArangoClient] = None
    db: Optional[StandardDatabase] = None

    # Plain connection parameters; empty string means "not provided".
    url: str = ""
    username: str = ""
    password: str = ""
    database: str = ""

    class Config:
        env_prefix = "ARANGO_"
|
54
|
-
|
55
|
-
|
56
|
-
class QueryResult(BaseModel):
    """Outcome of an AQL query: a success flag plus an optional payload."""

    success: bool
    # Payload of the query; may be a scalar, a list, a dict, or None.
    # The member order of the Union is kept exactly as in the original.
    data: Optional[
        Union[
            str,
            int,
            float,
            bool,
            None,
            List[Any],
            Dict[str, Any],
            List[Dict[str, Any]],
        ]
    ] = None

    class Config:
        # Permit non-pydantic types in fields.
        arbitrary_types_allowed = True
        # Serialize datetimes as ISO-8601 strings; further encoders for
        # ArangoDB-specific types could be added here.
        json_encoders = {datetime.datetime: lambda v: v.isoformat()}
        # Re-validate on every attribute assignment.
        validate_assignment = True
        # Instances stay mutable (set True for immutability).
        frozen = False
|
87
|
-
|
88
|
-
|
89
|
-
class ArangoChatAgentConfig(ChatAgentConfig):
    """Configuration for ``ArangoChatAgent``."""

    # --- database connection and schema ---
    arango_settings: ArangoSettings = ArangoSettings()
    system_message: str = DEFAULT_ARANGO_CHAT_SYSTEM_MESSAGE
    kg_schema: Optional[Union[str, Dict[str, List[Dict[str, Any]]]]] = None
    database_created: bool = False
    prepopulate_schema: bool = True
    schema_sample_pct: float = 0
    max_schema_fields: int = 500  # max fields to show in schema

    # --- LLM / tool behaviour ---
    use_functions_api: bool = True
    use_tools: bool = False
    max_num_results: int = 10  # how many results to return from AQL query
    max_tries: int = 10  # how many attempts to answer user question

    # --- interaction mode ---
    # whether the agent is used in a continuous chat with user,
    # as opposed to returning a result from the task.run()
    chat_mode: bool = False
    addressing_prefix: str = ""
|
105
|
-
|
106
|
-
|
107
|
-
class ArangoChatAgent(ChatAgent):
|
108
|
-
def __init__(self, config: ArangoChatAgentConfig):
    """Build the agent: validate config, connect to the DB, enable tools."""
    super().__init__(config)
    self.config: ArangoChatAgentConfig = config
    self.init_state()
    # The setup steps run strictly in this order.
    setup_steps = (
        self._validate_config,
        self._import_arango,
        self._initialize_db,
        self._init_tools_sys_message,
    )
    for step in setup_steps:
        step()
|
116
|
-
|
117
|
-
def init_state(self) -> None:
    """Reset per-conversation state on top of the base agent's state."""
    super().init_state()
    # Number of attempts made so far to answer the user's question.
    self.num_tries = 0
    # Parameters of the schema request that counts as "already tried".
    self.current_schema_params: ArangoSchemaTool = ArangoSchemaTool()
    # Last retrieval query run, used to reject an identical repeat.
    self.current_retrieval_aql_query: str = ""
|
122
|
-
|
123
|
-
def user_response(
    self,
    msg: Optional[str | ChatDocument] = None,
) -> Optional[ChatDocument]:
    """Get the user's response; a non-empty reply resets the try counter.

    Returns whatever the base-class ``user_response`` returns (possibly None).
    """
    user_reply = super().user_response(msg)
    if user_reply is not None and user_reply.content != "":
        # reset number of tries if user responds
        self.num_tries = 0
    return user_reply
|
134
|
-
|
135
|
-
def llm_response(
    self, message: Optional[str | ChatDocument] = None
) -> Optional[ChatDocument]:
    """Produce the LLM response, enforcing the try limit and tool etiquette.

    - Once ``num_tries`` exceeds ``config.max_tries``, no LLM call is made:
      in chat mode a give-up message addressed to the user is returned,
      otherwise a ``DoneTool`` carrying the "exceeded" notice.
    - A user-sent ``ChatDocument`` gets a one-tool-at-a-time reminder
      appended to its content (the document is modified in place).
    - In chat mode, if the response both contains the addressing prefix and
      attempts a tool, the prefix is stripped from the content.
    """
    if self.num_tries > self.config.max_tries:
        if not self.config.chat_mode:
            return self.create_llm_response(
                tool_messages=[
                    DoneTool(
                        content=f"""
                        Exceeded maximum number of tries ({self.config.max_tries}).
                        """
                    )
                ]
            )
        return self.create_llm_response(
            content=f"""
            {self.config.addressing_prefix}User
            I give up, since I have exceeded the
            maximum number of tries ({self.config.max_tries}).
            Feel free to give me some hints!
            """
        )

    from_user = (
        isinstance(message, ChatDocument)
        and message.metadata.sender == Entity.USER
    )
    if from_user:
        reminder = """
        (REMEMBER, Do NOT use more than ONE TOOL/FUNCTION at a time!
        you must WAIT for a helper to send you the RESULT(S) before
        making another TOOL/FUNCTION call)
        """
        message.content = message.content + "\n" + reminder

    response = super().llm_response(message)
    if response is None or not self.config.chat_mode:
        return response

    prefix = self.config.addressing_prefix
    if prefix in response.content and self.has_tool_message_attempt(response):
        # response contains both a user-addressing and a tool, which
        # is not allowed, so remove the user-addressing prefix
        response.content = response.content.replace(prefix, "")
    return response
|
184
|
-
|
185
|
-
def _validate_config(self) -> None:
    """Check that the config holds enough information to reach the database.

    Raises:
        ValueError: if ``client`` or ``db`` is missing and any of
            url/username/password/database is empty.
    """
    assert isinstance(self.config, ArangoChatAgentConfig)
    settings = self.config.arango_settings
    if settings.client is not None and settings.db is not None:
        # Both connection objects were supplied; nothing else is needed.
        return
    required = [
        settings.url,
        settings.username,
        settings.password,
        settings.database,
    ]
    if not all(required):
        raise ValueError("ArangoDB connection info must be provided")
|
200
|
-
|
201
|
-
def _import_arango(self) -> None:
    """Import ``ArangoClient`` into module scope.

    Raises:
        LangroidImportError: if the ``arango`` package cannot be imported;
            the original ``ImportError`` is attached as the cause.
    """
    global ArangoClient
    try:
        from arango.client import ArangoClient
    except ImportError as err:
        # Chain explicitly so the underlying import failure stays visible.
        raise LangroidImportError("python-arango", "arango") from err
|
207
|
-
|
208
|
-
def _has_any_data(self) -> bool:
|
209
|
-
for c in self.db.collections(): # type: ignore
|
210
|
-
if c["name"].startswith("_"):
|
211
|
-
continue
|
212
|
-
if self.db.collection(c["name"]).count() > 0: # type: ignore
|
213
|
-
return True
|
214
|
-
return False
|
215
|
-
|
216
|
-
def _initialize_db(self) -> None:
    """Connect to ArangoDB and, if it already holds data, load its schema.

    Uses the ``client`` / ``db`` objects from the settings when they are
    truthy, otherwise builds them from url / database / username / password.
    Sets ``config.database_created`` to whether any non-system collection
    has documents.

    Raises:
        ConnectionError: wrapping any exception raised during setup.
    """
    settings = self.config.arango_settings
    try:
        logger.info("Initializing ArangoDB client connection...")
        client = settings.client
        if not client:
            client = ArangoClient(hosts=settings.url)
        self.client = client

        logger.info("Connecting to database...")
        db = settings.db
        if not db:
            db = self.client.db(
                settings.database,
                username=settings.username,
                password=settings.password,
            )
        self.db = db

        logger.info("Checking for existing data in collections...")
        # Check if any non-system collection has data
        has_data = self._has_any_data()
        self.config.database_created = has_data

        if not has_data:
            logger.info("No existing data found in database")
        else:
            # If database has data, get schema
            logger.info("Database has existing data, retrieving schema...")
            # this updates self.config.kg_schema
            self.arango_schema_tool(None)

    except Exception as e:
        logger.error(f"Database initialization failed: {e}")
        raise ConnectionError(f"Failed to initialize ArangoDB connection: {e}")
|
245
|
-
|
246
|
-
def close(self) -> None:
    """Close the ArangoDB client if one is set."""
    client = self.client
    if not client:
        return
    client.close()
|
249
|
-
|
250
|
-
@staticmethod
|
251
|
-
def cleanup_graph_db(db) -> None: # type: ignore
|
252
|
-
# First delete graphs to properly handle edge collections
|
253
|
-
for graph in db.graphs():
|
254
|
-
graph_name = graph["name"]
|
255
|
-
if not graph_name.startswith("_"): # Skip system graphs
|
256
|
-
try:
|
257
|
-
db.delete_graph(graph_name)
|
258
|
-
except Exception as e:
|
259
|
-
print(f"Failed to delete graph {graph_name}: {e}")
|
260
|
-
|
261
|
-
# Clear existing collections
|
262
|
-
for collection in db.collections():
|
263
|
-
if not collection["name"].startswith("_"): # Skip system collections
|
264
|
-
try:
|
265
|
-
db.delete_collection(collection["name"])
|
266
|
-
except Exception as e:
|
267
|
-
print(f"Failed to delete collection {collection['name']}: {e}")
|
268
|
-
|
269
|
-
def with_retry(
    self, func: Callable[[], T], max_retries: int = 3, delay: float = 1.0
) -> T:
    """Call ``func``, retrying after an ``ArangoError``.

    Between attempts: log a warning, sleep ``delay`` seconds, and call
    ``_initialize_db()`` again. The ``ArangoError`` of the last attempt is
    re-raised. With ``max_retries <= 0`` the function is called once with
    no retry handling.
    """
    final_attempt = max_retries - 1
    attempt = 0
    while attempt < max_retries:
        try:
            return func()
        except ArangoError:
            if attempt == final_attempt:
                raise
            logger.warning(
                f"Connection failed (attempt {attempt + 1}/{max_retries}). "
                f"Retrying in {delay} seconds..."
            )
            time.sleep(delay)
            # Reconnect if needed
            self._initialize_db()
        attempt += 1
    # Only reached when the loop body never ran (max_retries <= 0).
    return func()
|
287
|
-
|
288
|
-
def read_query(
    self, query: str, bind_vars: Optional[Dict[Any, Any]] = None
) -> QueryResult:
    """Run a read (retrieval) AQL query, retrying via ``with_retry``.

    Returns a ``QueryResult``: on success ``data`` is the list of records,
    cut to ``config.max_num_results``; on failure ``data`` is an error
    message. ``ServerConnectionError`` is re-raised out of the inner call;
    any other exception there is converted into a failed result.
    """
    if not self.db:
        return QueryResult(
            success=False, data="No database connection is established."
        )

    def execute_read() -> QueryResult:
        try:
            cursor = self.db.aql.execute(query, bind_vars=bind_vars)
            fetched = list(cursor)  # type: ignore
            records = fetched[: self.config.max_num_results]
            logger.warning(f"Records retrieved: {records}")
            return QueryResult(success=True, data=records)
        except ServerConnectionError:
            # Let connection failures reach with_retry.
            raise
        except Exception as e:
            logger.error(f"Failed to execute query: {query}\n{e}")
            return QueryResult(success=False, data=self.retry_query(e, query))

    try:
        result = self.with_retry(execute_read)  # type: ignore
    except Exception as e:
        result = QueryResult(
            success=False, data=f"Failed after max retries: {str(e)}"
        )
    return result
|
317
|
-
|
318
|
-
def write_query(
    self, query: str, bind_vars: Optional[Dict[Any, Any]] = None
) -> QueryResult:
    """Run a write AQL query, retrying via ``with_retry``.

    Returns a ``QueryResult`` with ``success=True`` and no data on success,
    or ``success=False`` with an error message in ``data`` on failure.
    ``ServerConnectionError`` is re-raised out of the inner call; any other
    exception there is converted into a failed result.
    """
    if not self.db:
        return QueryResult(
            success=False, data="No database connection is established."
        )

    def execute_write() -> QueryResult:
        try:
            self.db.aql.execute(query, bind_vars=bind_vars)
        except ServerConnectionError:
            # Let connection failures reach with_retry.
            raise
        except Exception as e:
            logger.error(f"Failed to execute query: {query}\n{e}")
            return QueryResult(success=False, data=self.retry_query(e, query))
        return QueryResult(success=True)

    try:
        result = self.with_retry(execute_write)  # type: ignore
    except Exception as e:
        result = QueryResult(
            success=False, data=f"Failed after max retries: {str(e)}"
        )
    return result
|
344
|
-
|
345
|
-
def aql_retrieval_tool(self, msg: AQLRetrievalTool) -> str:
    """Handle an AQL retrieval request from the LLM.

    Returns instructional text instead of running the query when the schema
    tool has not been used yet, when the database has not been created, or
    when the query equals the previous retrieval query. Otherwise it runs
    the query and returns ``str()`` of the result data, or a "no results"
    hint when the data is an empty list. Only a request that gets past the
    first two checks increments ``num_tries``.
    """
    if not self.tried_schema:
        return f"""
        You need to use `{arango_schema_tool_name}` first to get the
        database schema before using `{aql_retrieval_tool_name}`. This ensures
        you know the correct collection names and edge definitions.
        """
    elif not self.config.database_created:
        # This message used to be a plain (non-f) string, so the LLM was shown
        # the literal text `{aql_creation_tool_name}`. Resolve the tool name the
        # same way handle_message_fallback does for DoneTool / ForwardTool.
        aql_creation_tool_name = AQLCreationTool.default_value("request")
        return f"""
        You need to create the database first using `{aql_creation_tool_name}`.
        """
    self.num_tries += 1
    query = msg.aql_query
    if query == self.current_retrieval_aql_query:
        return """
        You have already tried this query, so you will get the same results again!
        If you need to retry, please MODIFY the query to get different results.
        """
    self.current_retrieval_aql_query = query
    logger.info(f"Executing AQL query: {query}")
    response = self.read_query(query)

    if isinstance(response.data, list) and len(response.data) == 0:
        return """
        No results found. Check if your collection names are correct -
        they are case-sensitive. Use exact names from the schema.
        Try modifying your query based on the RETRY-SUGGESTIONS
        in your instructions.
        """
    return str(response.data)
|
376
|
-
|
377
|
-
def aql_creation_tool(self, msg: AQLCreationTool) -> str:
    """Handle an AQL data-creation request from the LLM.

    Counts as one try. On success marks the database as created and returns
    a confirmation string; otherwise returns ``str()`` of the failure data.
    """
    self.num_tries += 1
    query = msg.aql_query
    logger.info(f"Executing AQL query: {query}")
    result = self.write_query(query)

    if not result.success:
        return str(result.data)
    self.config.database_created = True
    return "AQL query executed successfully"
|
388
|
-
|
389
|
-
def arango_schema_tool(
    self,
    msg: ArangoSchemaTool | None,
) -> Dict[str, List[Dict[str, Any]]] | str:
    """Return the database schema: graphs plus non-empty, non-system collections.

    Args:
        msg: the LLM's schema request, or None when the agent itself is
            pre-populating the full schema. ``msg.collections`` restricts the
            result to the named collections (None or empty means all).
            ``msg.properties`` selects full property lists and an example
            document per collection (True), or only names/types and, for
            edge collections, the from/to collection names (False).

    Returns:
        The schema dict, or a string in these cases: the request repeats
        ``current_schema_params``; the schema exceeded
        ``config.max_schema_fields`` and was trimmed (JSON plus a CAUTION
        note); retrieval failed (error message). With ``msg=None`` and a
        non-empty ``config.kg_schema``, that cached value is returned as is.
    """
    import os  # local import: only needed for the best-effort schema dump

    if (
        msg is not None
        and msg.collections == self.current_schema_params.collections
        and msg.properties == self.current_schema_params.properties
    ):
        return """
        You have already tried this schema TOOL, so you will get the same results
        again! Please MODIFY the tool params `collections` or `properties` to get
        different results.
        """

    if msg is not None:
        collections = msg.collections
        properties = msg.properties
    else:
        collections = None
        properties = True
    self.tried_schema = True
    if (
        self.config.kg_schema is not None
        and len(self.config.kg_schema) > 0
        and msg is None
    ):
        # we are trying to pre-populate full schema before the agent runs,
        # so get it if it's already available
        # (Note of course that this "full schema" may actually be incomplete)
        return self.config.kg_schema

    # increment tries only if the LLM is asking for the schema,
    # in which case msg will not be None
    self.num_tries += msg is not None

    try:
        # Get graph schemas (keeping full graph info)
        graph_schema = [
            {"graph_name": g["name"], "edge_definitions": g["edge_definitions"]}
            for g in self.db.graphs()  # type: ignore
        ]

        # Get collection schemas
        collection_schema = []
        for collection in self.db.collections():  # type: ignore
            if collection["name"].startswith("_"):
                continue

            col_name = collection["name"]
            if collections and col_name not in collections:
                continue

            col_type = collection["type"]
            col_size = self.db.collection(col_name).count()

            if col_size == 0:
                continue

            if properties:
                # Full property collection with sampling
                lim = self.config.schema_sample_pct * col_size  # type: ignore
                # numpy's ceil returns a float; cast so the query reads
                # `LIMIT 3`, not `LIMIT 3.0`. Always sample at least one doc.
                limit_amount = int(ceil(lim / 100.0)) or 1
                sample_query = f"""
                    FOR doc in {col_name}
                    LIMIT {limit_amount}
                    RETURN doc
                """

                properties_list = []
                example_doc = None

                def simplify_doc(doc: Any) -> Any:
                    # Keep only the first element of each non-empty list,
                    # recursively, so the example document stays small.
                    if isinstance(doc, list) and len(doc) > 0:
                        return [simplify_doc(doc[0])]
                    if isinstance(doc, dict):
                        return {k: simplify_doc(v) for k, v in doc.items()}
                    return doc

                for doc in self.db.aql.execute(sample_query):  # type: ignore
                    if example_doc is None:
                        example_doc = simplify_doc(doc)
                    for key, value in doc.items():
                        prop = {"name": key, "type": type(value).__name__}
                        if prop not in properties_list:
                            properties_list.append(prop)

                collection_schema.append(
                    {
                        "collection_name": col_name,
                        "collection_type": col_type,
                        f"{col_type}_properties": properties_list,
                        f"example_{col_type}": example_doc,
                    }
                )
            else:
                # Basic info + from/to for edges only
                collection_info = {
                    "collection_name": col_name,
                    "collection_type": col_type,
                }
                if col_type == "edge":
                    # Get a sample edge to extract from/to fields
                    sample_edge = next(
                        self.db.aql.execute(  # type: ignore
                            f"FOR e IN {col_name} LIMIT 1 RETURN e"
                        ),
                        None,
                    )
                    if sample_edge:
                        collection_info["from_collection"] = sample_edge[
                            "_from"
                        ].split("/")[0]
                        collection_info["to_collection"] = sample_edge["_to"].split(
                            "/"
                        )[0]

                collection_schema.append(collection_info)

        schema = {
            "Graph Schema": graph_schema,
            "Collection Schema": collection_schema,
        }
        schema_str = json.dumps(schema, indent=2)
        logger.warning(f"Schema retrieved:\n{schema_str}")
        # Best-effort debug dump: a missing `logs/` directory or an unwritable
        # file must not turn a successful retrieval into a failure.
        schema_path = "logs/arango-schema.json"
        try:
            os.makedirs(os.path.dirname(schema_path), exist_ok=True)
            with open(schema_path, "w") as f:
                f.write(schema_str)
        except OSError as e:
            logger.warning(f"Could not write schema to {schema_path}: {e}")
        if (n_fields := count_fields(schema)) > self.config.max_schema_fields:
            logger.warning(
                f"""
                Schema has {n_fields} fields, which exceeds the maximum of
                {self.config.max_schema_fields}. Showing a trimmed version
                that only includes edge info and no other properties.
                """
            )
            schema = trim_schema(schema)
            n_fields = count_fields(schema)
            logger.warning(f"Schema trimmed down to {n_fields} fields.")
            schema_str = (
                json.dumps(schema)
                + "\n"
                + f"""

                CAUTION: The requested schema was too large, so
                the schema has been trimmed down to show only all collection names,
                their types,
                and edge relationships (from/to collections) without any properties.
                To find out more about the schema, you can EITHER:
                - Use the `{arango_schema_tool_name}` tool again with the
                  `properties` arg set to True, and `collections` arg set to
                    specific collections you want to know more about, OR
                - Use the `{aql_retrieval_tool_name}` tool to learn more about
                  the schema by querying the database.

                """
            )
            # NOTE(review): nesting reconstructed from a flattened listing --
            # the cache is updated only for the pre-populate call, but the
            # trimmed text is returned in both cases. Confirm against upstream.
            if msg is None:
                self.config.kg_schema = schema_str
            return schema_str
        self.config.kg_schema = schema
        return schema

    except Exception as e:
        logger.error(f"Schema retrieval failed: {str(e)}")
        return f"Failed to retrieve schema: {str(e)}"
|
559
|
-
|
560
|
-
def _init_tools_sys_message(self) -> None:
    """Finalize the system message and enable the agent's tools.

    Fills the ``{mode}`` placeholder of ``config.system_message``, appends
    the addressing instruction (chat mode) or the done instruction, calls
    the base ``__init__`` again with the updated config, then enables the
    schema / retrieval / creation / forward tools, plus ``DoneTool`` when
    not in chat mode.
    """
    self.tried_schema = False
    mode_text = self._format_message()
    # NOTE(review): str.format is applied to the current system_message; if
    # this config object was already formatted by an earlier agent, literal
    # braces in it could break this call -- confirm configs are not reused.
    self.config.system_message = self.config.system_message.format(mode=mode_text)

    if not self.config.chat_mode:
        self.config.system_message += DONE_INSTRUCTION
    else:
        self.config.addressing_prefix = self.config.addressing_prefix or SEND_TO
        self.config.system_message += ADDRESSING_INSTRUCTION.format(
            prefix=self.config.addressing_prefix
        )

    super().__init__(self.config)
    # Note we are enabling GraphSchemaTool regardless of whether
    # self.config.prepopulate_schema is True or False, because
    # even when schema provided, the agent may later want to get the schema,
    # e.g. if the db evolves, or schema was trimmed due to size, or
    # if it needs to bring in the schema into recent context.
    always_enabled = [
        ArangoSchemaTool,
        AQLRetrievalTool,
        AQLCreationTool,
        ForwardTool,
    ]
    self.enable_message(always_enabled)
    if not self.config.chat_mode:
        self.enable_message(DoneTool)
|
591
|
-
|
592
|
-
def _format_message(self) -> str:
|
593
|
-
if self.db is None:
|
594
|
-
raise ValueError("Database connection not established")
|
595
|
-
|
596
|
-
assert isinstance(self.config, ArangoChatAgentConfig)
|
597
|
-
return (
|
598
|
-
SCHEMA_TOOLS_SYS_MSG
|
599
|
-
if not self.config.prepopulate_schema
|
600
|
-
else SCHEMA_PROVIDED_SYS_MSG.format(schema=self.arango_schema_tool(None))
|
601
|
-
)
|
602
|
-
|
603
|
-
def handle_message_fallback(
    self, msg: str | ChatDocument
) -> str | ForwardTool | None:
    """Respond to an LLM message that contained no tool.

    Returns None unless ``msg`` is a ``ChatDocument`` sent by the LLM.
    For an LLM message: in interactive mode it is forwarded to the user;
    otherwise the LLM receives a nudge asking it to either finish (via the
    forward tool in chat mode, the done tool otherwise) or use a tool.
    """
    done_tool_name = DoneTool.default_value("request")
    forward_tool_name = ForwardTool.default_value("request")
    aql_retrieval_tool_instructions = AQLRetrievalTool.instructions()
    # TODO the aql_retrieval_tool_instructions may be empty/minimal
    # when using self.config.use_functions_api = True.
    tools_instruction = f"""
    For example you may want to use the TOOL
    `{aql_retrieval_tool_name}` according to these instructions:
       {aql_retrieval_tool_instructions}
    """

    sent_by_llm = isinstance(msg, ChatDocument) and msg.metadata.sender == Entity.LLM
    if not sent_by_llm:
        return None

    if self.interactive:
        return ForwardTool(agent="User")

    if self.config.chat_mode:
        return f"""
        Since you did not explicitly address the User, it is not clear
        whether:
        - you intend this to be the final response to the
          user's query/request, in which case you must use the
          `{forward_tool_name}` to indicate this.
        - OR, you FORGOT to use an Appropriate TOOL,
          in which case you should use the available tools to
          make progress on the user's query/request.
          {tools_instruction}
        """

    return f"""
    The intent of your response is not clear:
    - if you intended this to be the FINAL answer to the user's query,
        then use the `{done_tool_name}` to indicate so,
        with the `content` set to the answer or result.
    - otherwise, use one of the available tools to make progress
        to arrive at the final answer.
        {tools_instruction}
    """
|
645
|
-
|
646
|
-
def retry_query(self, e: Exception, query: str) -> str:
    """Log a failed AQL query and build the error text to return for it.

    Despite the name, nothing is re-executed here: the method only logs and
    returns a message containing the query and the exception text.
    """
    logger.error(f"AQL Query failed: {query}\nException: {e}")

    return f"""\
    {ARANGO_ERROR_MSG}: '{query}'
    {str(e)}
    Please try again with a corrected query.
    """
|