langroid 0.19.4__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +19 -13
- langroid/agent/special/arangodb/arangodb_agent.py +514 -0
- langroid/agent/special/arangodb/system_messages.py +157 -0
- langroid/agent/special/arangodb/tools.py +39 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +120 -54
- langroid/agent/special/neo4j/system_messages.py +120 -0
- langroid/agent/special/neo4j/tools.py +32 -0
- langroid/agent/special/sql/sql_chat_agent.py +8 -3
- langroid/agent/task.py +17 -1
- langroid/agent/tools/orchestration.py +1 -1
- langroid/agent/tools/recipient_tool.py +9 -0
- langroid/parsing/parser.py +6 -0
- {langroid-0.19.4.dist-info → langroid-0.20.0.dist-info}/METADATA +5 -1
- {langroid-0.19.4.dist-info → langroid-0.20.0.dist-info}/RECORD +18 -14
- pyproject.toml +7 -1
- langroid/agent/special/neo4j/utils/system_message.py +0 -64
- /langroid/agent/special/{neo4j/utils → arangodb}/__init__.py +0 -0
- {langroid-0.19.4.dist-info → langroid-0.20.0.dist-info}/LICENSE +0 -0
- {langroid-0.19.4.dist-info → langroid-0.20.0.dist-info}/WHEEL +0 -0
langroid/agent/base.py
CHANGED
@@ -94,6 +94,9 @@ class AgentConfig(BaseSettings):
|
|
94
94
|
respond_tools_only: bool = False # respond only to tool messages (not plain text)?
|
95
95
|
# allow multiple tool messages in a single response?
|
96
96
|
allow_multiple_tools: bool = True
|
97
|
+
human_prompt: str = (
|
98
|
+
"Human (respond or q, x to exit current level, " "or hit enter to continue)"
|
99
|
+
)
|
97
100
|
|
98
101
|
@validator("name")
|
99
102
|
def check_name_alphanum(cls, v: str) -> str:
|
@@ -411,16 +414,13 @@ class Agent(ABC):
|
|
411
414
|
results = self.handle_message(msg)
|
412
415
|
if results is None:
|
413
416
|
return None
|
414
|
-
if isinstance(results, ChatDocument):
|
415
|
-
# Preserve trail of tool_ids for OpenAI Assistant fn-calls
|
416
|
-
results.metadata.tool_ids = (
|
417
|
-
[] if isinstance(msg, str) else msg.metadata.tool_ids
|
418
|
-
)
|
419
|
-
return results
|
420
417
|
if not settings.quiet:
|
421
|
-
|
422
|
-
|
423
|
-
)
|
418
|
+
if isinstance(results, str):
|
419
|
+
results_str = results
|
420
|
+
elif isinstance(results, ChatDocument):
|
421
|
+
results_str = results.content
|
422
|
+
elif isinstance(results, dict):
|
423
|
+
results_str = json.dumps(results, indent=2)
|
424
424
|
console.print(f"[red]{self.indent}", end="")
|
425
425
|
print(f"[red]Agent: {escape(results_str)}")
|
426
426
|
maybe_json = len(extract_top_level_json(results_str)) > 0
|
@@ -428,6 +428,12 @@ class Agent(ABC):
|
|
428
428
|
content=results_str,
|
429
429
|
language="json" if maybe_json else "text",
|
430
430
|
)
|
431
|
+
if isinstance(results, ChatDocument):
|
432
|
+
# Preserve trail of tool_ids for OpenAI Assistant fn-calls
|
433
|
+
results.metadata.tool_ids = (
|
434
|
+
[] if isinstance(msg, str) else msg.metadata.tool_ids
|
435
|
+
)
|
436
|
+
return results
|
431
437
|
sender_name = self.config.name
|
432
438
|
if isinstance(msg, ChatDocument) and msg.function_call is not None:
|
433
439
|
# if result was from handling an LLM `function_call`,
|
@@ -655,9 +661,9 @@ class Agent(ABC):
|
|
655
661
|
user_msg = self.callbacks.get_user_response(prompt="")
|
656
662
|
else:
|
657
663
|
user_msg = Prompt.ask(
|
658
|
-
f"[blue]{self.indent}
|
659
|
-
|
660
|
-
f"
|
664
|
+
f"[blue]{self.indent}"
|
665
|
+
+ self.config.human_prompt
|
666
|
+
+ f"\n{self.indent}"
|
661
667
|
).strip()
|
662
668
|
|
663
669
|
tool_ids = []
|
@@ -668,7 +674,7 @@ class Agent(ABC):
|
|
668
674
|
return None
|
669
675
|
else:
|
670
676
|
if user_msg.startswith("SYSTEM"):
|
671
|
-
user_msg = user_msg
|
677
|
+
user_msg = user_msg.replace("SYSTEM", "").strip()
|
672
678
|
source = Entity.SYSTEM
|
673
679
|
sender = Entity.SYSTEM
|
674
680
|
else:
|
@@ -0,0 +1,514 @@
|
|
1
|
+
import datetime
|
2
|
+
import json
|
3
|
+
import logging
|
4
|
+
import time
|
5
|
+
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
|
6
|
+
|
7
|
+
from arango.client import ArangoClient
|
8
|
+
from arango.database import StandardDatabase
|
9
|
+
from arango.exceptions import ArangoError, ServerConnectionError
|
10
|
+
from numpy import ceil
|
11
|
+
from rich import print
|
12
|
+
from rich.console import Console
|
13
|
+
|
14
|
+
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
|
15
|
+
from langroid.agent.chat_document import ChatDocument
|
16
|
+
from langroid.agent.special.arangodb.system_messages import (
|
17
|
+
ADDRESSING_INSTRUCTION,
|
18
|
+
DEFAULT_ARANGO_CHAT_SYSTEM_MESSAGE,
|
19
|
+
DONE_INSTRUCTION,
|
20
|
+
SCHEMA_PROVIDED_SYS_MSG,
|
21
|
+
SCHEMA_TOOLS_SYS_MSG,
|
22
|
+
)
|
23
|
+
from langroid.agent.special.arangodb.tools import (
|
24
|
+
AQLCreationTool,
|
25
|
+
AQLRetrievalTool,
|
26
|
+
ArangoSchemaTool,
|
27
|
+
aql_retrieval_tool_name,
|
28
|
+
arango_schema_tool_name,
|
29
|
+
)
|
30
|
+
from langroid.agent.tools.orchestration import DoneTool, ForwardTool
|
31
|
+
from langroid.exceptions import LangroidImportError
|
32
|
+
from langroid.mytypes import Entity
|
33
|
+
from langroid.pydantic_v1 import BaseModel, BaseSettings
|
34
|
+
from langroid.utils.constants import SEND_TO
|
35
|
+
|
36
|
+
logger = logging.getLogger(__name__)
|
37
|
+
console = Console()
|
38
|
+
|
39
|
+
ARANGO_ERROR_MSG = "There was an error in your AQL Query"
|
40
|
+
T = TypeVar("T")
|
41
|
+
|
42
|
+
|
43
|
+
class ArangoSettings(BaseSettings):
|
44
|
+
client: ArangoClient | None = None
|
45
|
+
db: StandardDatabase | None = None
|
46
|
+
url: str = ""
|
47
|
+
username: str = ""
|
48
|
+
password: str = ""
|
49
|
+
database: str = ""
|
50
|
+
|
51
|
+
class Config:
|
52
|
+
env_prefix = "ARANGO_"
|
53
|
+
|
54
|
+
|
55
|
+
class QueryResult(BaseModel):
|
56
|
+
success: bool
|
57
|
+
data: Optional[
|
58
|
+
Union[
|
59
|
+
str,
|
60
|
+
int,
|
61
|
+
float,
|
62
|
+
bool,
|
63
|
+
None,
|
64
|
+
List[Any],
|
65
|
+
Dict[str, Any],
|
66
|
+
List[Dict[str, Any]],
|
67
|
+
]
|
68
|
+
] = None
|
69
|
+
|
70
|
+
class Config:
|
71
|
+
# Allow arbitrary types for flexibility
|
72
|
+
arbitrary_types_allowed = True
|
73
|
+
|
74
|
+
# Handle JSON serialization of special types
|
75
|
+
json_encoders = {
|
76
|
+
# Add custom encoders if needed, e.g.:
|
77
|
+
datetime.datetime: lambda v: v.isoformat(),
|
78
|
+
# Could add others for specific ArangoDB types
|
79
|
+
}
|
80
|
+
|
81
|
+
# Validate all assignments
|
82
|
+
validate_assignment = True
|
83
|
+
|
84
|
+
# Frozen=True if we want immutability
|
85
|
+
frozen = False
|
86
|
+
|
87
|
+
|
88
|
+
class ArangoChatAgentConfig(ChatAgentConfig):
|
89
|
+
arango_settings: ArangoSettings = ArangoSettings()
|
90
|
+
system_message: str = DEFAULT_ARANGO_CHAT_SYSTEM_MESSAGE
|
91
|
+
kg_schema: Optional[Dict[str, List[Dict[str, Any]]]] = None
|
92
|
+
database_created: bool = False
|
93
|
+
use_schema_tools: bool = True
|
94
|
+
use_functions_api: bool = True
|
95
|
+
max_result_tokens: int = 1000 # truncate long results to this many tokens
|
96
|
+
use_tools: bool = False
|
97
|
+
schema_sample_pct: float = 0
|
98
|
+
# whether the agent is used in a continuous chat with user,
|
99
|
+
# as opposed to returning a result from the task.run()
|
100
|
+
chat_mode: bool = False
|
101
|
+
addressing_prefix: str = ""
|
102
|
+
|
103
|
+
|
104
|
+
class ArangoChatAgent(ChatAgent):
|
105
|
+
def __init__(self, config: ArangoChatAgentConfig):
|
106
|
+
self.config: ArangoChatAgentConfig = config
|
107
|
+
self._validate_config()
|
108
|
+
self._import_arango()
|
109
|
+
self._initialize_db()
|
110
|
+
self._init_tools_sys_message()
|
111
|
+
self.init_state()
|
112
|
+
|
113
|
+
def init_state(self) -> None:
|
114
|
+
super().init_state()
|
115
|
+
self.current_retrieval_aql_query: str = ""
|
116
|
+
|
117
|
+
def _validate_config(self) -> None:
|
118
|
+
assert isinstance(self.config, ArangoChatAgentConfig)
|
119
|
+
if (
|
120
|
+
self.config.arango_settings.client is None
|
121
|
+
or self.config.arango_settings.db is None
|
122
|
+
):
|
123
|
+
if not all(
|
124
|
+
[
|
125
|
+
self.config.arango_settings.url,
|
126
|
+
self.config.arango_settings.username,
|
127
|
+
self.config.arango_settings.password,
|
128
|
+
self.config.arango_settings.database,
|
129
|
+
]
|
130
|
+
):
|
131
|
+
raise ValueError("ArangoDB connection info must be provided")
|
132
|
+
|
133
|
+
def _import_arango(self) -> None:
|
134
|
+
global ArangoClient
|
135
|
+
try:
|
136
|
+
from arango.client import ArangoClient
|
137
|
+
except ImportError:
|
138
|
+
raise LangroidImportError("python-arango", "arango")
|
139
|
+
|
140
|
+
def _has_any_data(self) -> bool:
|
141
|
+
for c in self.db.collections(): # type: ignore
|
142
|
+
if c["name"].startswith("_"):
|
143
|
+
continue
|
144
|
+
if self.db.collection(c["name"]).count() > 0: # type: ignore
|
145
|
+
return True
|
146
|
+
return False
|
147
|
+
|
148
|
+
def _initialize_db(self) -> None:
|
149
|
+
try:
|
150
|
+
logger.info("Initializing ArangoDB client connection...")
|
151
|
+
self.client = self.config.arango_settings.client or ArangoClient(
|
152
|
+
hosts=self.config.arango_settings.url
|
153
|
+
)
|
154
|
+
|
155
|
+
logger.info("Connecting to database...")
|
156
|
+
self.db = self.config.arango_settings.db or self.client.db(
|
157
|
+
self.config.arango_settings.database,
|
158
|
+
username=self.config.arango_settings.username,
|
159
|
+
password=self.config.arango_settings.password,
|
160
|
+
)
|
161
|
+
|
162
|
+
logger.info("Checking for existing data in collections...")
|
163
|
+
# Check if any non-system collection has data
|
164
|
+
self.config.database_created = self._has_any_data()
|
165
|
+
|
166
|
+
# If database has data, get schema
|
167
|
+
if self.config.database_created:
|
168
|
+
logger.info("Database has existing data, retrieving schema...")
|
169
|
+
# this updates self.config.kg_schema
|
170
|
+
self.arango_schema_tool(None)
|
171
|
+
else:
|
172
|
+
logger.info("No existing data found in database")
|
173
|
+
|
174
|
+
except Exception as e:
|
175
|
+
logger.error(f"Database initialization failed: {e}")
|
176
|
+
raise ConnectionError(f"Failed to initialize ArangoDB connection: {e}")
|
177
|
+
|
178
|
+
def close(self) -> None:
|
179
|
+
if self.client:
|
180
|
+
self.client.close()
|
181
|
+
|
182
|
+
@staticmethod
|
183
|
+
def cleanup_graph_db(db) -> None: # type: ignore
|
184
|
+
# First delete graphs to properly handle edge collections
|
185
|
+
for graph in db.graphs():
|
186
|
+
graph_name = graph["name"]
|
187
|
+
if not graph_name.startswith("_"): # Skip system graphs
|
188
|
+
try:
|
189
|
+
db.delete_graph(graph_name)
|
190
|
+
except Exception as e:
|
191
|
+
print(f"Failed to delete graph {graph_name}: {e}")
|
192
|
+
|
193
|
+
# Clear existing collections
|
194
|
+
for collection in db.collections():
|
195
|
+
if not collection["name"].startswith("_"): # Skip system collections
|
196
|
+
try:
|
197
|
+
db.delete_collection(collection["name"])
|
198
|
+
except Exception as e:
|
199
|
+
print(f"Failed to delete collection {collection['name']}: {e}")
|
200
|
+
|
201
|
+
def with_retry(
|
202
|
+
self, func: Callable[[], T], max_retries: int = 3, delay: float = 1.0
|
203
|
+
) -> T:
|
204
|
+
"""Execute a function with retries on connection error"""
|
205
|
+
for attempt in range(max_retries):
|
206
|
+
try:
|
207
|
+
return func()
|
208
|
+
except ArangoError:
|
209
|
+
if attempt == max_retries - 1:
|
210
|
+
raise
|
211
|
+
logger.warning(
|
212
|
+
f"Connection failed (attempt {attempt + 1}/{max_retries}). "
|
213
|
+
f"Retrying in {delay} seconds..."
|
214
|
+
)
|
215
|
+
time.sleep(delay)
|
216
|
+
# Reconnect if needed
|
217
|
+
self._initialize_db()
|
218
|
+
return func() # Final attempt after loop if not raised
|
219
|
+
|
220
|
+
def read_query(
|
221
|
+
self, query: str, bind_vars: Optional[Dict[Any, Any]] = None
|
222
|
+
) -> QueryResult:
|
223
|
+
"""Execute a read query with connection retry."""
|
224
|
+
if not self.db:
|
225
|
+
return QueryResult(
|
226
|
+
success=False, data="No database connection is established."
|
227
|
+
)
|
228
|
+
|
229
|
+
def execute_read() -> QueryResult:
|
230
|
+
try:
|
231
|
+
cursor = self.db.aql.execute(query, bind_vars=bind_vars)
|
232
|
+
records = [doc for doc in cursor] # type: ignore
|
233
|
+
logger.warning(f"Records retrieved: {records}")
|
234
|
+
return QueryResult(success=True, data=records if records else [])
|
235
|
+
except Exception as e:
|
236
|
+
if isinstance(e, ServerConnectionError):
|
237
|
+
raise
|
238
|
+
logger.error(f"Failed to execute query: {query}\n{e}")
|
239
|
+
error_message = self.retry_query(e, query)
|
240
|
+
return QueryResult(success=False, data=error_message)
|
241
|
+
|
242
|
+
try:
|
243
|
+
return self.with_retry(execute_read) # type: ignore
|
244
|
+
except Exception as e:
|
245
|
+
return QueryResult(
|
246
|
+
success=False, data=f"Failed after max retries: {str(e)}"
|
247
|
+
)
|
248
|
+
|
249
|
+
def write_query(
|
250
|
+
self, query: str, bind_vars: Optional[Dict[Any, Any]] = None
|
251
|
+
) -> QueryResult:
|
252
|
+
"""Execute a write query with connection retry."""
|
253
|
+
if not self.db:
|
254
|
+
return QueryResult(
|
255
|
+
success=False, data="No database connection is established."
|
256
|
+
)
|
257
|
+
|
258
|
+
def execute_write() -> QueryResult:
|
259
|
+
try:
|
260
|
+
self.db.aql.execute(query, bind_vars=bind_vars)
|
261
|
+
return QueryResult(success=True)
|
262
|
+
except Exception as e:
|
263
|
+
if isinstance(e, ServerConnectionError):
|
264
|
+
raise
|
265
|
+
logger.error(f"Failed to execute query: {query}\n{e}")
|
266
|
+
error_message = self.retry_query(e, query)
|
267
|
+
return QueryResult(success=False, data=error_message)
|
268
|
+
|
269
|
+
try:
|
270
|
+
return self.with_retry(execute_write) # type: ignore
|
271
|
+
except Exception as e:
|
272
|
+
return QueryResult(
|
273
|
+
success=False, data=f"Failed after max retries: {str(e)}"
|
274
|
+
)
|
275
|
+
|
276
|
+
def aql_retrieval_tool(self, msg: AQLRetrievalTool) -> str:
|
277
|
+
"""Handle AQL query for data retrieval"""
|
278
|
+
if not self.tried_schema:
|
279
|
+
return f"""
|
280
|
+
You need to use `{arango_schema_tool_name}` first to get the
|
281
|
+
database schema before using `{aql_retrieval_tool_name}`. This ensures
|
282
|
+
you know the correct collection names and edge definitions.
|
283
|
+
"""
|
284
|
+
elif not self.config.database_created:
|
285
|
+
return """
|
286
|
+
You need to create the database first using `{aql_creation_tool_name}`.
|
287
|
+
"""
|
288
|
+
query = msg.aql_query
|
289
|
+
self.current_retrieval_aql_query = query
|
290
|
+
logger.info(f"Executing AQL query: {query}")
|
291
|
+
response = self.read_query(query)
|
292
|
+
|
293
|
+
if isinstance(response.data, list) and len(response.data) == 0:
|
294
|
+
return """
|
295
|
+
No results found. Check if your collection names are correct -
|
296
|
+
they are case-sensitive. Use exact names from the schema.
|
297
|
+
Try modifying your query based on the RETRY-SUGGESTIONS
|
298
|
+
in your instructions.
|
299
|
+
"""
|
300
|
+
# truncate long results
|
301
|
+
result = str(response.data)
|
302
|
+
n_toks = self.num_tokens(result)
|
303
|
+
if n_toks > self.config.max_result_tokens:
|
304
|
+
logger.warning(
|
305
|
+
f"""
|
306
|
+
Your query resulted in a large result of
|
307
|
+
{n_toks} tokens,
|
308
|
+
which will be truncated to {self.config.max_result_tokens} tokens.
|
309
|
+
If this does not give satisfactory results,
|
310
|
+
please retry with a more focused query.
|
311
|
+
"""
|
312
|
+
)
|
313
|
+
if self.parser is not None:
|
314
|
+
result = self.parser.truncate_tokens(
|
315
|
+
result,
|
316
|
+
self.config.max_result_tokens,
|
317
|
+
)
|
318
|
+
else:
|
319
|
+
result = result[: self.config.max_result_tokens * 4] # truncate roughly
|
320
|
+
return result
|
321
|
+
|
322
|
+
def aql_creation_tool(self, msg: AQLCreationTool) -> str:
|
323
|
+
"""Handle AQL query for creating data"""
|
324
|
+
query = msg.aql_query
|
325
|
+
logger.info(f"Executing AQL query: {query}")
|
326
|
+
response = self.write_query(query)
|
327
|
+
|
328
|
+
if response.success:
|
329
|
+
self.config.database_created = True
|
330
|
+
return "AQL query executed successfully"
|
331
|
+
return str(response.data)
|
332
|
+
|
333
|
+
def arango_schema_tool(
|
334
|
+
self,
|
335
|
+
msg: ArangoSchemaTool | None,
|
336
|
+
) -> Dict[str, List[Dict[str, Any]]] | str:
|
337
|
+
"""Get database schema including collections, properties, and relationships"""
|
338
|
+
self.tried_schema = True
|
339
|
+
if self.config.kg_schema is not None and len(self.config.kg_schema) > 0:
|
340
|
+
return self.config.kg_schema
|
341
|
+
try:
|
342
|
+
# Get graph schemas
|
343
|
+
graph_schema = [
|
344
|
+
{"graph_name": g["name"], "edge_definitions": g["edge_definitions"]}
|
345
|
+
for g in self.db.graphs() # type: ignore
|
346
|
+
]
|
347
|
+
|
348
|
+
# Get collection schemas
|
349
|
+
collection_schema = []
|
350
|
+
for collection in self.db.collections(): # type: ignore
|
351
|
+
if collection["name"].startswith("_"): # Skip system collections
|
352
|
+
continue
|
353
|
+
|
354
|
+
col_name = collection["name"]
|
355
|
+
col_type = collection["type"]
|
356
|
+
col_size = self.db.collection(col_name).count()
|
357
|
+
|
358
|
+
if col_size == 0: # Skip empty collections
|
359
|
+
continue
|
360
|
+
|
361
|
+
# Calculate sample size
|
362
|
+
limit_amount = (
|
363
|
+
ceil(
|
364
|
+
self.config.schema_sample_pct * col_size / 100.0 # type: ignore
|
365
|
+
)
|
366
|
+
or 1
|
367
|
+
)
|
368
|
+
|
369
|
+
# Query to get sample documents and their properties
|
370
|
+
sample_query = f"""
|
371
|
+
FOR doc in {col_name}
|
372
|
+
LIMIT {limit_amount}
|
373
|
+
RETURN doc
|
374
|
+
"""
|
375
|
+
|
376
|
+
properties = []
|
377
|
+
example_doc = None
|
378
|
+
|
379
|
+
def simplify_doc(doc: Any) -> Any:
|
380
|
+
if isinstance(doc, list) and len(doc) > 0:
|
381
|
+
return [simplify_doc(doc[0])]
|
382
|
+
if isinstance(doc, dict):
|
383
|
+
return {k: simplify_doc(v) for k, v in doc.items()}
|
384
|
+
return doc
|
385
|
+
|
386
|
+
for doc in self.db.aql.execute(sample_query): # type: ignore
|
387
|
+
if example_doc is None:
|
388
|
+
example_doc = simplify_doc(doc)
|
389
|
+
for key, value in doc.items():
|
390
|
+
prop = {"name": key, "type": type(value).__name__}
|
391
|
+
if prop not in properties:
|
392
|
+
properties.append(prop)
|
393
|
+
|
394
|
+
collection_schema.append(
|
395
|
+
{
|
396
|
+
"collection_name": col_name,
|
397
|
+
"collection_type": col_type,
|
398
|
+
f"{col_type}_properties": properties,
|
399
|
+
f"example_{col_type}": example_doc,
|
400
|
+
}
|
401
|
+
)
|
402
|
+
|
403
|
+
schema = {
|
404
|
+
"Graph Schema": graph_schema,
|
405
|
+
"Collection Schema": collection_schema,
|
406
|
+
}
|
407
|
+
schema_str = json.dumps(schema, indent=2)
|
408
|
+
logger.warning(f"Schema retrieved:\n{schema_str}")
|
409
|
+
# save schema to file "logs/arangoo-schema.json"
|
410
|
+
with open("logs/arango-schema.json", "w") as f:
|
411
|
+
f.write(schema_str)
|
412
|
+
self.config.kg_schema = schema # type: ignore
|
413
|
+
return schema
|
414
|
+
|
415
|
+
except Exception as e:
|
416
|
+
logger.error(f"Schema retrieval failed: {str(e)}")
|
417
|
+
return f"Failed to retrieve schema: {str(e)}"
|
418
|
+
|
419
|
+
def _init_tools_sys_message(self) -> None:
|
420
|
+
"""Initialize system msg and enable tools"""
|
421
|
+
self.tried_schema = False
|
422
|
+
message = self._format_message()
|
423
|
+
self.config.system_message = self.config.system_message.format(mode=message)
|
424
|
+
|
425
|
+
if self.config.chat_mode:
|
426
|
+
self.config.addressing_prefix = self.config.addressing_prefix or SEND_TO
|
427
|
+
self.config.system_message += ADDRESSING_INSTRUCTION.format(
|
428
|
+
prefix=self.config.addressing_prefix
|
429
|
+
)
|
430
|
+
else:
|
431
|
+
self.config.system_message += DONE_INSTRUCTION
|
432
|
+
|
433
|
+
super().__init__(self.config)
|
434
|
+
# Note we are enabling GraphSchemaTool regardless of whether
|
435
|
+
# self.config.use_schema_tools is True or False, because
|
436
|
+
# even when schema provided, the agent may later want to get the schema,
|
437
|
+
# e.g. if the db evolves, or if it needs to bring in the schema
|
438
|
+
|
439
|
+
self.enable_message(
|
440
|
+
[
|
441
|
+
ArangoSchemaTool,
|
442
|
+
AQLRetrievalTool,
|
443
|
+
AQLCreationTool,
|
444
|
+
ForwardTool,
|
445
|
+
]
|
446
|
+
)
|
447
|
+
if not self.config.chat_mode:
|
448
|
+
self.enable_message(DoneTool)
|
449
|
+
|
450
|
+
def _format_message(self) -> str:
|
451
|
+
if self.db is None:
|
452
|
+
raise ValueError("Database connection not established")
|
453
|
+
|
454
|
+
assert isinstance(self.config, ArangoChatAgentConfig)
|
455
|
+
return (
|
456
|
+
SCHEMA_TOOLS_SYS_MSG
|
457
|
+
if self.config.use_schema_tools
|
458
|
+
else SCHEMA_PROVIDED_SYS_MSG.format(schema=self.arango_schema_tool(None))
|
459
|
+
)
|
460
|
+
|
461
|
+
def handle_message_fallback(
|
462
|
+
self, msg: str | ChatDocument
|
463
|
+
) -> str | ForwardTool | None:
|
464
|
+
"""When LLM sends a no-tool msg, assume user is the intended recipient,
|
465
|
+
and if in interactive mode, forward the msg to the user.
|
466
|
+
"""
|
467
|
+
done_tool_name = DoneTool.default_value("request")
|
468
|
+
forward_tool_name = ForwardTool.default_value("request")
|
469
|
+
aql_retrieval_tool_instructions = AQLRetrievalTool.instructions()
|
470
|
+
# TODO the aql_retrieval_tool_instructions may be empty/minimal
|
471
|
+
# when using self.config.use_functions_api = True.
|
472
|
+
tools_instruction = f"""
|
473
|
+
For example you may want to use the TOOL
|
474
|
+
`{aql_retrieval_tool_name}` according to these instructions:
|
475
|
+
{aql_retrieval_tool_instructions}
|
476
|
+
"""
|
477
|
+
if isinstance(msg, ChatDocument) and msg.metadata.sender == Entity.LLM:
|
478
|
+
if self.interactive:
|
479
|
+
return ForwardTool(agent="User")
|
480
|
+
else:
|
481
|
+
if self.config.chat_mode:
|
482
|
+
return f"""
|
483
|
+
Since you did not explicitly address the User, it is not clear
|
484
|
+
whether:
|
485
|
+
- you intend this to be the final response to the
|
486
|
+
user's query/request, in which case you must use the
|
487
|
+
`{forward_tool_name}` to indicate this.
|
488
|
+
- OR, you FORGOT to use an Appropriate TOOL,
|
489
|
+
in which case you should use the available tools to
|
490
|
+
make progress on the user's query/request.
|
491
|
+
{tools_instruction}
|
492
|
+
"""
|
493
|
+
return f"""
|
494
|
+
The intent of your response is not clear:
|
495
|
+
- if you intended this to be the FINAL answer to the user's query,
|
496
|
+
then use the `{done_tool_name}` to indicate so,
|
497
|
+
with the `content` set to the answer or result.
|
498
|
+
- otherwise, use one of the available tools to make progress
|
499
|
+
to arrive at the final answer.
|
500
|
+
{tools_instruction}
|
501
|
+
"""
|
502
|
+
return None
|
503
|
+
|
504
|
+
def retry_query(self, e: Exception, query: str) -> str:
|
505
|
+
"""Generate error message for failed AQL query"""
|
506
|
+
logger.error(f"AQL Query failed: {query}\nException: {e}")
|
507
|
+
|
508
|
+
error_message = f"""\
|
509
|
+
{ARANGO_ERROR_MSG}: '{query}'
|
510
|
+
{str(e)}
|
511
|
+
Please try again with a corrected query.
|
512
|
+
"""
|
513
|
+
|
514
|
+
return error_message
|