langroid 0.58.2__py3-none-any.whl → 0.59.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106) hide show
  1. langroid/agent/base.py +39 -17
  2. langroid/agent/base.py-e +2216 -0
  3. langroid/agent/callbacks/chainlit.py +2 -1
  4. langroid/agent/chat_agent.py +73 -55
  5. langroid/agent/chat_agent.py-e +2086 -0
  6. langroid/agent/chat_document.py +7 -7
  7. langroid/agent/chat_document.py-e +513 -0
  8. langroid/agent/openai_assistant.py +9 -9
  9. langroid/agent/openai_assistant.py-e +882 -0
  10. langroid/agent/special/arangodb/arangodb_agent.py +10 -18
  11. langroid/agent/special/arangodb/arangodb_agent.py-e +648 -0
  12. langroid/agent/special/arangodb/tools.py +3 -3
  13. langroid/agent/special/doc_chat_agent.py +16 -14
  14. langroid/agent/special/lance_rag/critic_agent.py +2 -2
  15. langroid/agent/special/lance_rag/query_planner_agent.py +4 -4
  16. langroid/agent/special/lance_tools.py +6 -5
  17. langroid/agent/special/lance_tools.py-e +61 -0
  18. langroid/agent/special/neo4j/neo4j_chat_agent.py +3 -7
  19. langroid/agent/special/neo4j/neo4j_chat_agent.py-e +430 -0
  20. langroid/agent/special/relevance_extractor_agent.py +1 -1
  21. langroid/agent/special/sql/sql_chat_agent.py +11 -3
  22. langroid/agent/task.py +9 -87
  23. langroid/agent/task.py-e +2418 -0
  24. langroid/agent/tool_message.py +33 -17
  25. langroid/agent/tool_message.py-e +400 -0
  26. langroid/agent/tools/file_tools.py +4 -2
  27. langroid/agent/tools/file_tools.py-e +234 -0
  28. langroid/agent/tools/mcp/fastmcp_client.py +19 -6
  29. langroid/agent/tools/mcp/fastmcp_client.py-e +584 -0
  30. langroid/agent/tools/orchestration.py +22 -17
  31. langroid/agent/tools/orchestration.py-e +301 -0
  32. langroid/agent/tools/recipient_tool.py +3 -3
  33. langroid/agent/tools/task_tool.py +22 -16
  34. langroid/agent/tools/task_tool.py-e +249 -0
  35. langroid/agent/xml_tool_message.py +90 -35
  36. langroid/agent/xml_tool_message.py-e +392 -0
  37. langroid/cachedb/base.py +1 -1
  38. langroid/embedding_models/base.py +2 -2
  39. langroid/embedding_models/models.py +3 -7
  40. langroid/embedding_models/models.py-e +563 -0
  41. langroid/exceptions.py +4 -1
  42. langroid/language_models/azure_openai.py +2 -2
  43. langroid/language_models/azure_openai.py-e +134 -0
  44. langroid/language_models/base.py +6 -4
  45. langroid/language_models/base.py-e +812 -0
  46. langroid/language_models/client_cache.py +64 -0
  47. langroid/language_models/config.py +2 -4
  48. langroid/language_models/config.py-e +18 -0
  49. langroid/language_models/model_info.py +9 -1
  50. langroid/language_models/model_info.py-e +483 -0
  51. langroid/language_models/openai_gpt.py +119 -20
  52. langroid/language_models/openai_gpt.py-e +2280 -0
  53. langroid/language_models/provider_params.py +3 -22
  54. langroid/language_models/provider_params.py-e +153 -0
  55. langroid/mytypes.py +11 -4
  56. langroid/mytypes.py-e +132 -0
  57. langroid/parsing/code_parser.py +1 -1
  58. langroid/parsing/file_attachment.py +1 -1
  59. langroid/parsing/file_attachment.py-e +246 -0
  60. langroid/parsing/md_parser.py +14 -4
  61. langroid/parsing/md_parser.py-e +574 -0
  62. langroid/parsing/parser.py +22 -7
  63. langroid/parsing/parser.py-e +410 -0
  64. langroid/parsing/repo_loader.py +3 -1
  65. langroid/parsing/repo_loader.py-e +812 -0
  66. langroid/parsing/search.py +1 -1
  67. langroid/parsing/url_loader.py +17 -51
  68. langroid/parsing/url_loader.py-e +683 -0
  69. langroid/parsing/urls.py +5 -4
  70. langroid/parsing/urls.py-e +279 -0
  71. langroid/prompts/prompts_config.py +1 -1
  72. langroid/pydantic_v1/__init__.py +45 -6
  73. langroid/pydantic_v1/__init__.py-e +36 -0
  74. langroid/pydantic_v1/main.py +11 -4
  75. langroid/pydantic_v1/main.py-e +11 -0
  76. langroid/utils/configuration.py +13 -11
  77. langroid/utils/configuration.py-e +141 -0
  78. langroid/utils/constants.py +1 -1
  79. langroid/utils/constants.py-e +32 -0
  80. langroid/utils/globals.py +21 -5
  81. langroid/utils/globals.py-e +49 -0
  82. langroid/utils/html_logger.py +2 -1
  83. langroid/utils/html_logger.py-e +825 -0
  84. langroid/utils/object_registry.py +1 -1
  85. langroid/utils/object_registry.py-e +66 -0
  86. langroid/utils/pydantic_utils.py +55 -28
  87. langroid/utils/pydantic_utils.py-e +602 -0
  88. langroid/utils/types.py +2 -2
  89. langroid/utils/types.py-e +113 -0
  90. langroid/vector_store/base.py +3 -3
  91. langroid/vector_store/lancedb.py +5 -5
  92. langroid/vector_store/lancedb.py-e +404 -0
  93. langroid/vector_store/meilisearch.py +2 -2
  94. langroid/vector_store/pineconedb.py +4 -4
  95. langroid/vector_store/pineconedb.py-e +427 -0
  96. langroid/vector_store/postgres.py +1 -1
  97. langroid/vector_store/qdrantdb.py +3 -3
  98. langroid/vector_store/weaviatedb.py +1 -1
  99. {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/METADATA +3 -2
  100. langroid-0.59.0b1.dist-info/RECORD +181 -0
  101. langroid/agent/special/doc_chat_task.py +0 -0
  102. langroid/mcp/__init__.py +0 -1
  103. langroid/mcp/server/__init__.py +0 -1
  104. langroid-0.58.2.dist-info/RECORD +0 -145
  105. {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/WHEEL +0 -0
  106. {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/licenses/LICENSE +0 -0
@@ -8,6 +8,8 @@ from arango.client import ArangoClient
8
8
  from arango.database import StandardDatabase
9
9
  from arango.exceptions import ArangoError, ServerConnectionError
10
10
  from numpy import ceil
11
+ from pydantic import BaseModel, ConfigDict
12
+ from pydantic_settings import BaseSettings, SettingsConfigDict
11
13
  from rich import print
12
14
  from rich.console import Console
13
15
 
@@ -31,7 +33,6 @@ from langroid.agent.special.arangodb.utils import count_fields, trim_schema
31
33
  from langroid.agent.tools.orchestration import DoneTool, ForwardTool
32
34
  from langroid.exceptions import LangroidImportError
33
35
  from langroid.mytypes import Entity
34
- from langroid.pydantic_v1 import BaseModel, BaseSettings
35
36
  from langroid.utils.constants import SEND_TO
36
37
 
37
38
  logger = logging.getLogger(__name__)
@@ -49,8 +50,7 @@ class ArangoSettings(BaseSettings):
49
50
  password: str = ""
50
51
  database: str = ""
51
52
 
52
- class Config:
53
- env_prefix = "ARANGO_"
53
+ model_config = SettingsConfigDict(env_prefix="ARANGO_")
54
54
 
55
55
 
56
56
  class QueryResult(BaseModel):
@@ -68,22 +68,14 @@ class QueryResult(BaseModel):
68
68
  ]
69
69
  ] = None
70
70
 
71
- class Config:
72
- # Allow arbitrary types for flexibility
73
- arbitrary_types_allowed = True
74
-
75
- # Handle JSON serialization of special types
76
- json_encoders = {
77
- # Add custom encoders if needed, e.g.:
71
+ model_config = ConfigDict(
72
+ arbitrary_types_allowed=True,
73
+ json_encoders={
78
74
  datetime.datetime: lambda v: v.isoformat(),
79
- # Could add others for specific ArangoDB types
80
- }
81
-
82
- # Validate all assignments
83
- validate_assignment = True
84
-
85
- # Frozen=True if we want immutability
86
- frozen = False
75
+ },
76
+ validate_assignment=True,
77
+ frozen=False,
78
+ )
87
79
 
88
80
 
89
81
  class ArangoChatAgentConfig(ChatAgentConfig):
@@ -0,0 +1,648 @@
1
+ import datetime
2
+ import json
3
+ import logging
4
+ import time
5
+ from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
6
+
7
+ from arango.client import ArangoClient
8
+ from arango.database import StandardDatabase
9
+ from arango.exceptions import ArangoError, ServerConnectionError
10
+ from numpy import ceil
11
+ from pydantic_settings import BaseSettings
12
+ from rich import print
13
+ from rich.console import Console
14
+
15
+ from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
16
+ from langroid.agent.chat_document import ChatDocument
17
+ from langroid.agent.special.arangodb.system_messages import (
18
+ ADDRESSING_INSTRUCTION,
19
+ DEFAULT_ARANGO_CHAT_SYSTEM_MESSAGE,
20
+ DONE_INSTRUCTION,
21
+ SCHEMA_PROVIDED_SYS_MSG,
22
+ SCHEMA_TOOLS_SYS_MSG,
23
+ )
24
+ from langroid.agent.special.arangodb.tools import (
25
+ AQLCreationTool,
26
+ AQLRetrievalTool,
27
+ ArangoSchemaTool,
28
+ aql_retrieval_tool_name,
29
+ arango_schema_tool_name,
30
+ )
31
+ from langroid.agent.special.arangodb.utils import count_fields, trim_schema
32
+ from langroid.agent.tools.orchestration import DoneTool, ForwardTool
33
+ from langroid.exceptions import LangroidImportError
34
+ from langroid.mytypes import Entity
35
+ from pydantic import BaseModel, ConfigDict
36
+ from langroid.utils.constants import SEND_TO
37
+
38
+ logger = logging.getLogger(__name__)
39
+ console = Console()
40
+
41
+ ARANGO_ERROR_MSG = "There was an error in your AQL Query"
42
+ T = TypeVar("T")
43
+
44
+
45
+ class ArangoSettings(BaseSettings):
46
+ client: ArangoClient | None = None
47
+ db: StandardDatabase | None = None
48
+ url: str = ""
49
+ username: str = ""
50
+ password: str = ""
51
+ database: str = ""
52
+
53
+ model_config = ConfigDict(env_prefix="ARANGO_")
54
+
55
+
56
+ class QueryResult(BaseModel):
57
+ success: bool
58
+ data: Optional[
59
+ Union[
60
+ str,
61
+ int,
62
+ float,
63
+ bool,
64
+ None,
65
+ List[Any],
66
+ Dict[str, Any],
67
+ List[Dict[str, Any]],
68
+ ]
69
+ ] = None
70
+
71
+ model_config = ConfigDict(
72
+ arbitrary_types_allowed=True,
73
+ json_encoders={
74
+ datetime.datetime: lambda v: v.isoformat(),
75
+ },
76
+ validate_assignment=True,
77
+ frozen=False,
78
+ )
79
+
80
+
81
+ class ArangoChatAgentConfig(ChatAgentConfig):
82
+ arango_settings: ArangoSettings = ArangoSettings()
83
+ system_message: str = DEFAULT_ARANGO_CHAT_SYSTEM_MESSAGE
84
+ kg_schema: str | Dict[str, List[Dict[str, Any]]] | None = None
85
+ database_created: bool = False
86
+ prepopulate_schema: bool = True
87
+ use_functions_api: bool = True
88
+ max_num_results: int = 10 # how many results to return from AQL query
89
+ max_schema_fields: int = 500 # max fields to show in schema
90
+ max_tries: int = 10 # how many attempts to answer user question
91
+ use_tools: bool = False
92
+ schema_sample_pct: float = 0
93
+ # whether the agent is used in a continuous chat with user,
94
+ # as opposed to returning a result from the task.run()
95
+ chat_mode: bool = False
96
+ addressing_prefix: str = ""
97
+
98
+
99
+ class ArangoChatAgent(ChatAgent):
100
+ def __init__(self, config: ArangoChatAgentConfig):
101
+ super().__init__(config)
102
+ self.config: ArangoChatAgentConfig = config
103
+ self.init_state()
104
+ self._validate_config()
105
+ self._import_arango()
106
+ self._initialize_db()
107
+ self._init_tools_sys_message()
108
+
109
+ def init_state(self) -> None:
110
+ super().init_state()
111
+ self.current_retrieval_aql_query: str = ""
112
+ self.current_schema_params: ArangoSchemaTool = ArangoSchemaTool()
113
+ self.num_tries = 0 # how many attempts to answer user question
114
+
115
+ def user_response(
116
+ self,
117
+ msg: Optional[str | ChatDocument] = None,
118
+ ) -> Optional[ChatDocument]:
119
+ response = super().user_response(msg)
120
+ if response is None:
121
+ return None
122
+ response_str = response.content if response is not None else ""
123
+ if response_str != "":
124
+ self.num_tries = 0 # reset number of tries if user responds
125
+ return response
126
+
127
+ def llm_response(
128
+ self, message: Optional[str | ChatDocument] = None
129
+ ) -> Optional[ChatDocument]:
130
+ if self.num_tries > self.config.max_tries:
131
+ if self.config.chat_mode:
132
+ return self.create_llm_response(
133
+ content=f"""
134
+ {self.config.addressing_prefix}User
135
+ I give up, since I have exceeded the
136
+ maximum number of tries ({self.config.max_tries}).
137
+ Feel free to give me some hints!
138
+ """
139
+ )
140
+ else:
141
+ return self.create_llm_response(
142
+ tool_messages=[
143
+ DoneTool(
144
+ content=f"""
145
+ Exceeded maximum number of tries ({self.config.max_tries}).
146
+ """
147
+ )
148
+ ]
149
+ )
150
+
151
+ if isinstance(message, ChatDocument) and message.metadata.sender == Entity.USER:
152
+ message.content = (
153
+ message.content
154
+ + "\n"
155
+ + """
156
+ (REMEMBER, Do NOT use more than ONE TOOL/FUNCTION at a time!
157
+ you must WAIT for a helper to send you the RESULT(S) before
158
+ making another TOOL/FUNCTION call)
159
+ """
160
+ )
161
+
162
+ response = super().llm_response(message)
163
+ if (
164
+ response is not None
165
+ and self.config.chat_mode
166
+ and self.config.addressing_prefix in response.content
167
+ and self.has_tool_message_attempt(response)
168
+ ):
169
+ # response contains both a user-addressing and a tool, which
170
+ # is not allowed, so remove the user-addressing prefix
171
+ response.content = response.content.replace(
172
+ self.config.addressing_prefix, ""
173
+ )
174
+
175
+ return response
176
+
177
+ def _validate_config(self) -> None:
178
+ assert isinstance(self.config, ArangoChatAgentConfig)
179
+ if (
180
+ self.config.arango_settings.client is None
181
+ or self.config.arango_settings.db is None
182
+ ):
183
+ if not all(
184
+ [
185
+ self.config.arango_settings.url,
186
+ self.config.arango_settings.username,
187
+ self.config.arango_settings.password,
188
+ self.config.arango_settings.database,
189
+ ]
190
+ ):
191
+ raise ValueError("ArangoDB connection info must be provided")
192
+
193
+ def _import_arango(self) -> None:
194
+ global ArangoClient
195
+ try:
196
+ from arango.client import ArangoClient
197
+ except ImportError:
198
+ raise LangroidImportError("python-arango", "arango")
199
+
200
+ def _has_any_data(self) -> bool:
201
+ for c in self.db.collections(): # type: ignore
202
+ if c["name"].startswith("_"):
203
+ continue
204
+ if self.db.collection(c["name"]).count() > 0: # type: ignore
205
+ return True
206
+ return False
207
+
208
+ def _initialize_db(self) -> None:
209
+ try:
210
+ logger.info("Initializing ArangoDB client connection...")
211
+ self.client = self.config.arango_settings.client or ArangoClient(
212
+ hosts=self.config.arango_settings.url
213
+ )
214
+
215
+ logger.info("Connecting to database...")
216
+ self.db = self.config.arango_settings.db or self.client.db(
217
+ self.config.arango_settings.database,
218
+ username=self.config.arango_settings.username,
219
+ password=self.config.arango_settings.password,
220
+ )
221
+
222
+ logger.info("Checking for existing data in collections...")
223
+ # Check if any non-system collection has data
224
+ self.config.database_created = self._has_any_data()
225
+
226
+ # If database has data, get schema
227
+ if self.config.database_created:
228
+ logger.info("Database has existing data, retrieving schema...")
229
+ # this updates self.config.kg_schema
230
+ self.arango_schema_tool(None)
231
+ else:
232
+ logger.info("No existing data found in database")
233
+
234
+ except Exception as e:
235
+ logger.error(f"Database initialization failed: {e}")
236
+ raise ConnectionError(f"Failed to initialize ArangoDB connection: {e}")
237
+
238
+ def close(self) -> None:
239
+ if self.client:
240
+ self.client.close()
241
+
242
+ @staticmethod
243
+ def cleanup_graph_db(db) -> None: # type: ignore
244
+ # First delete graphs to properly handle edge collections
245
+ for graph in db.graphs():
246
+ graph_name = graph["name"]
247
+ if not graph_name.startswith("_"): # Skip system graphs
248
+ try:
249
+ db.delete_graph(graph_name)
250
+ except Exception as e:
251
+ print(f"Failed to delete graph {graph_name}: {e}")
252
+
253
+ # Clear existing collections
254
+ for collection in db.collections():
255
+ if not collection["name"].startswith("_"): # Skip system collections
256
+ try:
257
+ db.delete_collection(collection["name"])
258
+ except Exception as e:
259
+ print(f"Failed to delete collection {collection['name']}: {e}")
260
+
261
+ def with_retry(
262
+ self, func: Callable[[], T], max_retries: int = 3, delay: float = 1.0
263
+ ) -> T:
264
+ """Execute a function with retries on connection error"""
265
+ for attempt in range(max_retries):
266
+ try:
267
+ return func()
268
+ except ArangoError:
269
+ if attempt == max_retries - 1:
270
+ raise
271
+ logger.warning(
272
+ f"Connection failed (attempt {attempt + 1}/{max_retries}). "
273
+ f"Retrying in {delay} seconds..."
274
+ )
275
+ time.sleep(delay)
276
+ # Reconnect if needed
277
+ self._initialize_db()
278
+ return func() # Final attempt after loop if not raised
279
+
280
+ def read_query(
281
+ self, query: str, bind_vars: Optional[Dict[Any, Any]] = None
282
+ ) -> QueryResult:
283
+ """Execute a read query with connection retry."""
284
+ if not self.db:
285
+ return QueryResult(
286
+ success=False, data="No database connection is established."
287
+ )
288
+
289
+ def execute_read() -> QueryResult:
290
+ try:
291
+ cursor = self.db.aql.execute(query, bind_vars=bind_vars)
292
+ records = [doc for doc in cursor] # type: ignore
293
+ records = records[: self.config.max_num_results]
294
+ logger.warning(f"Records retrieved: {records}")
295
+ return QueryResult(success=True, data=records if records else [])
296
+ except Exception as e:
297
+ if isinstance(e, ServerConnectionError):
298
+ raise
299
+ logger.error(f"Failed to execute query: {query}\n{e}")
300
+ error_message = self.retry_query(e, query)
301
+ return QueryResult(success=False, data=error_message)
302
+
303
+ try:
304
+ return self.with_retry(execute_read) # type: ignore
305
+ except Exception as e:
306
+ return QueryResult(
307
+ success=False, data=f"Failed after max retries: {str(e)}"
308
+ )
309
+
310
+ def write_query(
311
+ self, query: str, bind_vars: Optional[Dict[Any, Any]] = None
312
+ ) -> QueryResult:
313
+ """Execute a write query with connection retry."""
314
+ if not self.db:
315
+ return QueryResult(
316
+ success=False, data="No database connection is established."
317
+ )
318
+
319
+ def execute_write() -> QueryResult:
320
+ try:
321
+ self.db.aql.execute(query, bind_vars=bind_vars)
322
+ return QueryResult(success=True)
323
+ except Exception as e:
324
+ if isinstance(e, ServerConnectionError):
325
+ raise
326
+ logger.error(f"Failed to execute query: {query}\n{e}")
327
+ error_message = self.retry_query(e, query)
328
+ return QueryResult(success=False, data=error_message)
329
+
330
+ try:
331
+ return self.with_retry(execute_write) # type: ignore
332
+ except Exception as e:
333
+ return QueryResult(
334
+ success=False, data=f"Failed after max retries: {str(e)}"
335
+ )
336
+
337
+ def aql_retrieval_tool(self, msg: AQLRetrievalTool) -> str:
338
+ """Handle AQL query for data retrieval"""
339
+ if not self.tried_schema:
340
+ return f"""
341
+ You need to use `{arango_schema_tool_name}` first to get the
342
+ database schema before using `{aql_retrieval_tool_name}`. This ensures
343
+ you know the correct collection names and edge definitions.
344
+ """
345
+ elif not self.config.database_created:
346
+ return """
347
+ You need to create the database first using `{aql_creation_tool_name}`.
348
+ """
349
+ self.num_tries += 1
350
+ query = msg.aql_query
351
+ if query == self.current_retrieval_aql_query:
352
+ return """
353
+ You have already tried this query, so you will get the same results again!
354
+ If you need to retry, please MODIFY the query to get different results.
355
+ """
356
+ self.current_retrieval_aql_query = query
357
+ logger.info(f"Executing AQL query: {query}")
358
+ response = self.read_query(query)
359
+
360
+ if isinstance(response.data, list) and len(response.data) == 0:
361
+ return """
362
+ No results found. Check if your collection names are correct -
363
+ they are case-sensitive. Use exact names from the schema.
364
+ Try modifying your query based on the RETRY-SUGGESTIONS
365
+ in your instructions.
366
+ """
367
+ return str(response.data)
368
+
369
+ def aql_creation_tool(self, msg: AQLCreationTool) -> str:
370
+ """Handle AQL query for creating data"""
371
+ self.num_tries += 1
372
+ query = msg.aql_query
373
+ logger.info(f"Executing AQL query: {query}")
374
+ response = self.write_query(query)
375
+
376
+ if response.success:
377
+ self.config.database_created = True
378
+ return "AQL query executed successfully"
379
+ return str(response.data)
380
+
381
+ def arango_schema_tool(
382
+ self,
383
+ msg: ArangoSchemaTool | None,
384
+ ) -> Dict[str, List[Dict[str, Any]]] | str:
385
+ """Get database schema. If collections=None, include all collections.
386
+ If properties=False, show only connection info,
387
+ else show all properties and example-docs.
388
+ """
389
+
390
+ if (
391
+ msg is not None
392
+ and msg.collections == self.current_schema_params.collections
393
+ and msg.properties == self.current_schema_params.properties
394
+ ):
395
+ return """
396
+ You have already tried this schema TOOL, so you will get the same results
397
+ again! Please MODIFY the tool params `collections` or `properties` to get
398
+ different results.
399
+ """
400
+
401
+ if msg is not None:
402
+ collections = msg.collections
403
+ properties = msg.properties
404
+ else:
405
+ collections = None
406
+ properties = True
407
+ self.tried_schema = True
408
+ if (
409
+ self.config.kg_schema is not None
410
+ and len(self.config.kg_schema) > 0
411
+ and msg is None
412
+ ):
413
+ # we are trying to pre-populate full schema before the agent runs,
414
+ # so get it if it's already available
415
+ # (Note of course that this "full schema" may actually be incomplete)
416
+ return self.config.kg_schema
417
+
418
+ # increment tries only if the LLM is asking for the schema,
419
+ # in which case msg will not be None
420
+ self.num_tries += msg is not None
421
+
422
+ try:
423
+ # Get graph schemas (keeping full graph info)
424
+ graph_schema = [
425
+ {"graph_name": g["name"], "edge_definitions": g["edge_definitions"]}
426
+ for g in self.db.graphs() # type: ignore
427
+ ]
428
+
429
+ # Get collection schemas
430
+ collection_schema = []
431
+ for collection in self.db.collections(): # type: ignore
432
+ if collection["name"].startswith("_"):
433
+ continue
434
+
435
+ col_name = collection["name"]
436
+ if collections and col_name not in collections:
437
+ continue
438
+
439
+ col_type = collection["type"]
440
+ col_size = self.db.collection(col_name).count()
441
+
442
+ if col_size == 0:
443
+ continue
444
+
445
+ if properties:
446
+ # Full property collection with sampling
447
+ lim = self.config.schema_sample_pct * col_size # type: ignore
448
+ limit_amount = ceil(lim / 100.0) or 1
449
+ sample_query = f"""
450
+ FOR doc in {col_name}
451
+ LIMIT {limit_amount}
452
+ RETURN doc
453
+ """
454
+
455
+ properties_list = []
456
+ example_doc = None
457
+
458
+ def simplify_doc(doc: Any) -> Any:
459
+ if isinstance(doc, list) and len(doc) > 0:
460
+ return [simplify_doc(doc[0])]
461
+ if isinstance(doc, dict):
462
+ return {k: simplify_doc(v) for k, v in doc.items()}
463
+ return doc
464
+
465
+ for doc in self.db.aql.execute(sample_query): # type: ignore
466
+ if example_doc is None:
467
+ example_doc = simplify_doc(doc)
468
+ for key, value in doc.items():
469
+ prop = {"name": key, "type": type(value).__name__}
470
+ if prop not in properties_list:
471
+ properties_list.append(prop)
472
+
473
+ collection_schema.append(
474
+ {
475
+ "collection_name": col_name,
476
+ "collection_type": col_type,
477
+ f"{col_type}_properties": properties_list,
478
+ f"example_{col_type}": example_doc,
479
+ }
480
+ )
481
+ else:
482
+ # Basic info + from/to for edges only
483
+ collection_info = {
484
+ "collection_name": col_name,
485
+ "collection_type": col_type,
486
+ }
487
+ if col_type == "edge":
488
+ # Get a sample edge to extract from/to fields
489
+ sample_edge = next(
490
+ self.db.aql.execute( # type: ignore
491
+ f"FOR e IN {col_name} LIMIT 1 RETURN e"
492
+ ),
493
+ None,
494
+ )
495
+ if sample_edge:
496
+ collection_info["from_collection"] = sample_edge[
497
+ "_from"
498
+ ].split("/")[0]
499
+ collection_info["to_collection"] = sample_edge["_to"].split(
500
+ "/"
501
+ )[0]
502
+
503
+ collection_schema.append(collection_info)
504
+
505
+ schema = {
506
+ "Graph Schema": graph_schema,
507
+ "Collection Schema": collection_schema,
508
+ }
509
+ schema_str = json.dumps(schema, indent=2)
510
+ logger.warning(f"Schema retrieved:\n{schema_str}")
511
+ with open("logs/arango-schema.json", "w") as f:
512
+ f.write(schema_str)
513
+ if (n_fields := count_fields(schema)) > self.config.max_schema_fields:
514
+ logger.warning(
515
+ f"""
516
+ Schema has {n_fields} fields, which exceeds the maximum of
517
+ {self.config.max_schema_fields}. Showing a trimmed version
518
+ that only includes edge info and no other properties.
519
+ """
520
+ )
521
+ schema = trim_schema(schema)
522
+ n_fields = count_fields(schema)
523
+ logger.warning(f"Schema trimmed down to {n_fields} fields.")
524
+ schema_str = (
525
+ json.dumps(schema)
526
+ + "\n"
527
+ + f"""
528
+
529
+ CAUTION: The requested schema was too large, so
530
+ the schema has been trimmed down to show only all collection names,
531
+ their types,
532
+ and edge relationships (from/to collections) without any properties.
533
+ To find out more about the schema, you can EITHER:
534
+ - Use the `{arango_schema_tool_name}` tool again with the
535
+ `properties` arg set to True, and `collections` arg set to
536
+ specific collections you want to know more about, OR
537
+ - Use the `{aql_retrieval_tool_name}` tool to learn more about
538
+ the schema by querying the database.
539
+
540
+ """
541
+ )
542
+ if msg is None:
543
+ self.config.kg_schema = schema_str
544
+ return schema_str
545
+ self.config.kg_schema = schema
546
+ return schema
547
+
548
+ except Exception as e:
549
+ logger.error(f"Schema retrieval failed: {str(e)}")
550
+ return f"Failed to retrieve schema: {str(e)}"
551
+
552
+ def _init_tools_sys_message(self) -> None:
553
+ """Initialize system msg and enable tools"""
554
+ self.tried_schema = False
555
+ message = self._format_message()
556
+ self.config.system_message = self.config.system_message.format(mode=message)
557
+
558
+ if self.config.chat_mode:
559
+ self.config.addressing_prefix = self.config.addressing_prefix or SEND_TO
560
+ self.config.system_message += ADDRESSING_INSTRUCTION.format(
561
+ prefix=self.config.addressing_prefix
562
+ )
563
+ else:
564
+ self.config.system_message += DONE_INSTRUCTION
565
+
566
+ super().__init__(self.config)
567
+ # Note we are enabling GraphSchemaTool regardless of whether
568
+ # self.config.prepopulate_schema is True or False, because
569
+ # even when schema provided, the agent may later want to get the schema,
570
+ # e.g. if the db evolves, or schema was trimmed due to size, or
571
+ # if it needs to bring in the schema into recent context.
572
+
573
+ self.enable_message(
574
+ [
575
+ ArangoSchemaTool,
576
+ AQLRetrievalTool,
577
+ AQLCreationTool,
578
+ ForwardTool,
579
+ ]
580
+ )
581
+ if not self.config.chat_mode:
582
+ self.enable_message(DoneTool)
583
+
584
+ def _format_message(self) -> str:
585
+ if self.db is None:
586
+ raise ValueError("Database connection not established")
587
+
588
+ assert isinstance(self.config, ArangoChatAgentConfig)
589
+ return (
590
+ SCHEMA_TOOLS_SYS_MSG
591
+ if not self.config.prepopulate_schema
592
+ else SCHEMA_PROVIDED_SYS_MSG.format(schema=self.arango_schema_tool(None))
593
+ )
594
+
595
+ def handle_message_fallback(
596
+ self, msg: str | ChatDocument
597
+ ) -> str | ForwardTool | None:
598
+ """When LLM sends a no-tool msg, assume user is the intended recipient,
599
+ and if in interactive mode, forward the msg to the user.
600
+ """
601
+ done_tool_name = DoneTool.default_value("request")
602
+ forward_tool_name = ForwardTool.default_value("request")
603
+ aql_retrieval_tool_instructions = AQLRetrievalTool.instructions()
604
+ # TODO the aql_retrieval_tool_instructions may be empty/minimal
605
+ # when using self.config.use_functions_api = True.
606
+ tools_instruction = f"""
607
+ For example you may want to use the TOOL
608
+ `{aql_retrieval_tool_name}` according to these instructions:
609
+ {aql_retrieval_tool_instructions}
610
+ """
611
+ if isinstance(msg, ChatDocument) and msg.metadata.sender == Entity.LLM:
612
+ if self.interactive:
613
+ return ForwardTool(agent="User")
614
+ else:
615
+ if self.config.chat_mode:
616
+ return f"""
617
+ Since you did not explicitly address the User, it is not clear
618
+ whether:
619
+ - you intend this to be the final response to the
620
+ user's query/request, in which case you must use the
621
+ `{forward_tool_name}` to indicate this.
622
+ - OR, you FORGOT to use an Appropriate TOOL,
623
+ in which case you should use the available tools to
624
+ make progress on the user's query/request.
625
+ {tools_instruction}
626
+ """
627
+ return f"""
628
+ The intent of your response is not clear:
629
+ - if you intended this to be the FINAL answer to the user's query,
630
+ then use the `{done_tool_name}` to indicate so,
631
+ with the `content` set to the answer or result.
632
+ - otherwise, use one of the available tools to make progress
633
+ to arrive at the final answer.
634
+ {tools_instruction}
635
+ """
636
+ return None
637
+
638
+ def retry_query(self, e: Exception, query: str) -> str:
639
+ """Generate error message for failed AQL query"""
640
+ logger.error(f"AQL Query failed: {query}\nException: {e}")
641
+
642
+ error_message = f"""\
643
+ {ARANGO_ERROR_MSG}: '{query}'
644
+ {str(e)}
645
+ Please try again with a corrected query.
646
+ """
647
+
648
+ return error_message