langroid 0.19.5__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,183 @@
1
+ from langroid.agent.special.arangodb.tools import (
2
+ aql_creation_tool_name,
3
+ aql_retrieval_tool_name,
4
+ arango_schema_tool_name,
5
+ )
6
+ from langroid.agent.tools.orchestration import DoneTool
7
+
8
# Request name of the orchestration tool the LLM uses to signal it is done.
done_tool_name = DoneTool.default_value("request")

# Prompt snippets describing each ArangoDB tool. These are f-strings, so the
# actual tool (request) names are interpolated once, at import time.
arango_schema_tool_description = f"""
`{arango_schema_tool_name}` tool/function-call to find the schema
of the graph database, or of some SPECIFIC collections, i.e. get information on
the collections (document and edge), their attributes, and graph definitions
available in your ArangoDB database. You MUST use this tool BEFORE attempting to
use the `{aql_retrieval_tool_name}` tool/function-call, to ensure that you are
using the correct collection names and attributes in your
`{aql_retrieval_tool_name}` tool.
"""

aql_retrieval_tool_description = f"""
`{aql_retrieval_tool_name}` tool/function-call to retrieve information from
the database using AQL (ArangoDB Query Language) queries, to answer
the user's questions, OR for you to learn more about the SCHEMA of the database.
"""

aql_creation_tool_description = f"""
`{aql_creation_tool_name}` tool/function-call to execute an AQL query that creates
documents/edges in the database.
"""
29
+
30
# Worked example of a multi-step retrieval dialogue, embedded in the default
# system message. This is a plain (non-f) string, so the doubled braces survive
# interpolation into DEFAULT_ARANGO_CHAT_SYSTEM_MESSAGE. They presumably collapse
# to single braces when that message is later passed through `str.format(...)`
# (TODO confirm in the Arango agent). Every `{{` must be matched by exactly one
# `}}`, otherwise the rendered example shows malformed JSON.
aql_retrieval_query_example = """
EXAMPLE:
Suppose you are asked this question "Does Bob have a father?".
Then you will go through the following steps, where YOU indicates
the message YOU will be sending, and RESULTS indicates the RESULTS
you will receive from the helper executing the query:

1. YOU:
{{ "request": "aql_retrieval_tool",
   "aql_query": "FOR v, e, p in ... [query truncated for brevity]..."}}

2. RESULTS:
[.. results from the query...]
3. YOU: [ since results were not satisfactory, you try ANOTHER query]
{{ "request": "aql_retrieval_tool",
   "aql_query": "blah blah ... [query truncated for brevity]..."}}
4. RESULTS:
[.. results from the query...]
5. YOU: [ now you have the answer, you can generate your response ]
The answer is YES, Bob has a father, and his name is John.
"""
52
+
53
# Guidance for writing AQL that matches the discovered schema. This is a plain
# string with no placeholders, embedded verbatim into the default system message.
aql_query_instructions: str = """
When writing AQL queries:
1. Use the exact property names shown in the schema
2. Pay attention to the 'type' field of each node
3. Note that all names are case-sensitive:
   - collection names
   - property names
   - node type values
   - relationship type values
4. Always include type filters in your queries, e.g.:
   FILTER doc.type == '<type-from-schema>'

The schema shows:
- Collections (usually 'nodes' and 'edges')
- Node types in each collection
- Available properties for each node type
- Relationship types and their properties

Examine the schema carefully before writing queries to ensure:
- Correct property names
- Correct node types
- Correct relationship types

You must be smart about using the right collection names and attributes
based on the English description. If you are thinking of using a collection
or attribute that does not exist, you are probably on the wrong track,
so you should try your best to answer based on existing collections and attributes.
DO NOT assume any collections or graphs other than those above.
"""

# Reminder about tool-usage discipline, shared by every system message below.
tool_result_instruction: str = """
REMEMBER:
[1] DO NOT FORGET TO USE ONE OF THE AVAILABLE TOOLS TO ANSWER THE USER'S QUERY!!
[2] When using a TOOL/FUNCTION, you MUST WAIT for the tool result before continuing
with your response. DO NOT MAKE UP RESULTS FROM A TOOL!
[3] YOU MUST NOT ANSWER queries from your OWN KNOWLEDGE; ALWAYS RELY ON
the result of a TOOL/FUNCTION to compose your response.
[4] Use ONLY ONE TOOL/FUNCTION at a TIME!
"""
92
# System-message variant for when the schema is supplied up front, so the agent
# need not call the schema tool. The `{{schema}}` escape leaves a literal
# `{schema}` placeholder, presumably filled later via `.format(schema=...)`
# (TODO confirm in the Arango agent).
SCHEMA_PROVIDED_SYS_MSG: str = f"""You are a data scientist and expert in Graph Databases,
with expertise in answering questions by interacting with an ArangoDB database.

The schema below describes the ArangoDB database structure,
collections (document and edge),
and their attribute keys available in your ArangoDB database.

=== SCHEMA ===
{{schema}}
=== END SCHEMA ===


To help with the user's question or database update/creation request,
you have access to these tools:

- {aql_retrieval_tool_description}

- {aql_creation_tool_description}


{tool_result_instruction}
"""

# System-message variant for when the schema is NOT provided initially: the
# schema tool is listed first so the agent fetches the schema itself.
SCHEMA_TOOLS_SYS_MSG: str = f"""You are a data scientist and expert in
Arango Graph Databases,
with expertise in answering questions by querying ArangoDB database
using the Arango Query Language (AQL).
You have access to the following tools:

- {arango_schema_tool_description}

- {aql_retrieval_tool_description}

- {aql_creation_tool_description}

{tool_result_instruction}
"""
133
+
134
# Default system message for the ArangoDB chat agent. The `{{mode}}` escape
# leaves a literal `{mode}` placeholder. It is presumably filled later via
# `.format(mode=...)` with one of the schema-mode messages (TODO confirm in the
# Arango agent). The retry suggestions are lettered consecutively, (a)-(e).
DEFAULT_ARANGO_CHAT_SYSTEM_MESSAGE = f"""
{{mode}}

You do not need to be able to answer a question with just one query.
You can make a query, WAIT for the result,
THEN make ANOTHER query, WAIT for result,
THEN make ANOTHER query, and so on, until you have the answer.

{aql_query_instructions}

RETRY-SUGGESTIONS:
If you receive a null or other unexpected result,
(a) make sure you use the available TOOLs correctly,
(b) USE `{arango_schema_tool_name}` tool/function-call to get all collections,
their attributes and graph definitions available in your ArangoDB database.
(c) Collection names are CASE-SENSITIVE -- make sure you adhere to the exact
collection name you found in the schema.
(d) see if you have made an assumption in your AQL query, and try another way,
or use `{aql_retrieval_tool_name}` to explore the database contents before
submitting your final query.
(e) Try APPROXIMATE or PARTIAL MATCHES to strings in the user's query,
e.g. user may ask about "Godfather" instead of "The Godfather",
or try using CASE-INSENSITIVE MATCHES.

Start by asking what the user needs help with.

{tool_result_instruction}

{aql_retrieval_query_example}
"""
164
+
165
# Template (plain string, filled via `.format(prefix=...)`) telling the LLM how
# to address the user whenever it is not emitting a tool call.
ADDRESSING_INSTRUCTION: str = """
IMPORTANT - Whenever you are NOT writing an AQL query, make sure you address the
user using {prefix}User. You MUST use the EXACT syntax {prefix} !!!

In other words, you ALWAYS EITHER:
 - write an AQL query using one of the tools,
 - OR address the user using {prefix}User.

YOU CANNOT ADDRESS THE USER WHEN USING A TOOL!!
"""
175
+
176
# Instruction for finishing a task: the LLM must call the done tool, with the
# final answer as `content`, only once it is sure the answer is correct.
DONE_INSTRUCTION = f"""
When you are SURE you have the CORRECT answer to a user's query or request,
use the `{done_tool_name}` tool with `content` set to the answer or result.
If you DO NOT think you have the answer to the user's query or request,
you SHOULD NOT use the `{done_tool_name}` tool.
Instead, you must CONTINUE to improve your queries (tools) to get the correct answer,
and finally use the `{done_tool_name}` tool to send the correct answer to the user.
"""
@@ -0,0 +1,102 @@
1
+ from typing import List, Tuple
2
+
3
+ from langroid.agent.tool_message import ToolMessage
4
+
5
+
6
# Tool the LLM uses to run a read-only AQL query (to answer the user's question
# or to probe the schema) and then wait for its results.
# Comments are used instead of a class docstring so the pydantic model's
# schema description is left untouched.
class AQLRetrievalTool(ToolMessage):
    request: str = "aql_retrieval_tool"
    purpose: str = """
    To send an <aql_query> in response to a user's request/question,
    OR to find SCHEMA information,
    and WAIT for results of the <aql_query> BEFORE continuing with response.
    You will receive RESULTS from this tool, and ONLY THEN you can continue.
    """
    # the AQL query text to execute
    aql_query: str

    @classmethod
    def examples(cls) -> List[ToolMessage | Tuple[str, ToolMessage]]:
        """
        Few-shot examples for the tool instructions, each pairing an English
        intent with the corresponding tool call.
        """
        bobs_father = cls(
            aql_query="""
            FOR v, e, p IN 1..1 OUTBOUND 'users/Bob' GRAPH 'family_tree'
            FILTER p.edges[0].type == 'father'
            RETURN v
            """
        )
        actor_attributes = cls(
            aql_query="""
            FOR doc IN Actor
            LIMIT 1
            RETURN ATTRIBUTES(doc)
            """
        )
        return [
            ("I want to see who Bob's Father is", bobs_father),
            ("I want to know the properties of the Actor node", actor_attributes),
        ]

    @classmethod
    def instructions(cls) -> str:
        """Return extra usage guidance for this tool."""
        return """
        When using this TOOL/Function-call, you must WAIT to receive the RESULTS
        of the AQL query, before continuing your response!
        DO NOT ASSUME YOU KNOW THE RESULTs BEFORE RECEIVING THEM.
        """
49
+
50
+
51
# Request name under which the LLM invokes AQLRetrievalTool.
aql_retrieval_tool_name: str = AQLRetrievalTool.default_value("request")
52
+
53
+
54
# Tool the LLM uses to run an AQL query that inserts documents/edges, and then
# wait for the result. Comments are used instead of a class docstring so the
# pydantic model's schema description is left untouched.
class AQLCreationTool(ToolMessage):
    request: str = "aql_creation_tool"
    purpose: str = """
    To send the <aql_query> to create documents/edges in the graph database.
    IMPORTANT: YOU MUST WAIT FOR THE RESULT OF THE TOOL BEFORE CONTINUING.
    You will receive RESULTS from this tool, and ONLY THEN you can continue.
    """
    # the AQL (creation) query text to execute
    aql_query: str

    @classmethod
    def examples(cls) -> List[ToolMessage | Tuple[str, ToolMessage]]:
        """Few-shot example for the tool instructions: (intent, tool call) pairs."""
        insert_alice = cls(
            aql_query="""
            INSERT {
                "name": "Alice",
                "age": 30
            } INTO users
            """
        )
        return [("Create a new document in the collection 'users'", insert_alice)]
79
+
80
+
81
# Request name under which the LLM invokes AQLCreationTool.
aql_creation_tool_name: str = AQLCreationTool.default_value("request")
82
+
83
+
84
# Tool the LLM uses to fetch the ArangoDB schema: either all of it or selected
# collections, optionally without per-collection properties. Comments are used
# instead of a class docstring so the pydantic schema description is unchanged.
class ArangoSchemaTool(ToolMessage):
    request: str = "arango_schema_tool"
    purpose: str = """
    To get the schema of the Arango graph database,
    or some part of it. Follow these instructions:
    1. Set <properties> to True to get the properties of the collections,
    and False if you only want to see the graph structure and get only the
    from/to relations of the edges.
    2. Set <collections> to a list of collection names if you only want to see
    those collections, or leave it as None to see ALL collections.
    IMPORTANT: YOU MUST WAIT FOR THE RESULT OF THE TOOL BEFORE CONTINUING.
    You will receive RESULTS from this tool, and ONLY THEN you can continue.
    """

    # True: include collection properties; False: only graph structure and the
    # from/to relations of edges
    properties: bool = True
    # restrict the schema to these collections; None means all collections
    collections: List[str] | None = None
100
+
101
+
102
# Request name under which the LLM invokes ArangoSchemaTool.
arango_schema_tool_name: str = ArangoSchemaTool.default_value("request")
@@ -0,0 +1,36 @@
1
+ from typing import Any, Dict, List
2
+
3
+
4
def count_fields(schema: Dict[str, List[Dict[str, Any]]]) -> int:
    """
    Return a rough size measure of a schema dict.

    For every entry of ``schema["Collection Schema"]``, this adds the number of
    top-level keys in the entry plus the length of its
    ``"<collection_type>_properties"`` list (0 if absent).
    """
    return sum(
        len(coll) + len(coll.get(f"{coll['collection_type']}_properties", []))
        for coll in schema["Collection Schema"]
    )
13
+
14
+
15
def trim_schema(
    schema: Dict[str, List[Dict[str, Any]]]
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Keep only edge connection info, remove properties and examples.

    The "Graph Schema" entry is kept as-is. Each "Collection Schema" entry is
    reduced to its name and type. For edge collections whose ``example_edge``
    document carries both ``_from`` and ``_to`` handles, the source/target
    collection names (the part before the first "/") are recorded as
    ``from_collection`` / ``to_collection``.

    Args:
        schema: dict with "Graph Schema" and "Collection Schema" keys.

    Returns:
        A new, trimmed schema dict; the input is not modified.
    """
    trimmed: Dict[str, List[Dict[str, Any]]] = {
        "Graph Schema": schema["Graph Schema"],
        "Collection Schema": [],
    }
    for coll in schema["Collection Schema"]:
        col_info: Dict[str, Any] = {
            "collection_name": coll["collection_name"],
            "collection_type": coll["collection_type"],
        }
        if coll["collection_type"] == "edge":
            example = coll.get("example_edge")
            # Require BOTH handles: checking only `_from` raised KeyError on
            # `_to` for partial examples.
            if example and "_from" in example and "_to" in example:
                col_info["from_collection"] = example["_from"].split("/")[0]
                col_info["to_collection"] = example["_to"].split("/")[0]
        trimmed["Collection Schema"].append(col_info)
    return trimmed
@@ -1,11 +1,9 @@
1
- import json
2
1
  import logging
3
2
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
4
3
 
5
4
  from rich import print
6
5
  from rich.console import Console
7
6
 
8
- from langroid.agent import ToolMessage
9
7
  from langroid.pydantic_v1 import BaseModel, BaseSettings
10
8
 
11
9
  if TYPE_CHECKING:
@@ -13,13 +11,25 @@ if TYPE_CHECKING:
13
11
 
14
12
  from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
15
13
  from langroid.agent.chat_document import ChatDocument
16
- from langroid.agent.special.neo4j.utils.system_message import (
14
+ from langroid.agent.special.neo4j.system_messages import (
17
15
  ADDRESSING_INSTRUCTION,
18
16
  DEFAULT_NEO4J_CHAT_SYSTEM_MESSAGE,
19
- DEFAULT_SYS_MSG,
17
+ DONE_INSTRUCTION,
18
+ SCHEMA_PROVIDED_SYS_MSG,
20
19
  SCHEMA_TOOLS_SYS_MSG,
21
20
  )
21
+ from langroid.agent.special.neo4j.tools import (
22
+ CypherCreationTool,
23
+ CypherRetrievalTool,
24
+ GraphSchemaTool,
25
+ cypher_creation_tool_name,
26
+ cypher_retrieval_tool_name,
27
+ graph_schema_tool_name,
28
+ )
29
+ from langroid.agent.tools.orchestration import DoneTool, ForwardTool
30
+ from langroid.exceptions import LangroidImportError
22
31
  from langroid.mytypes import Entity
32
+ from langroid.utils.constants import SEND_TO
23
33
 
24
34
  logger = logging.getLogger(__name__)
25
35
 
@@ -31,25 +41,6 @@ NEO4J_ERROR_MSG = "There was an error in your Cypher Query"
31
41
  # TOOLS to be used by the agent
32
42
 
33
43
 
34
- class CypherRetrievalTool(ToolMessage):
35
- request: str = "retrieval_query"
36
- purpose: str = """Use this tool to send the Cypher query to retreive data from the
37
- graph database based provided text description and schema."""
38
- cypher_query: str
39
-
40
-
41
- class CypherCreationTool(ToolMessage):
42
- request: str = "create_query"
43
- purpose: str = """Use this tool to send the Cypher query to create
44
- entities/relationships in the graph database."""
45
- cypher_query: str
46
-
47
-
48
- class GraphSchemaTool(ToolMessage):
49
- request: str = "get_schema"
50
- purpose: str = """To get the schema of the graph database."""
51
-
52
-
53
44
  class Neo4jSettings(BaseSettings):
54
45
  uri: str = ""
55
46
  username: str = ""
@@ -65,17 +56,22 @@ class Neo4jSettings(BaseSettings):
65
56
 
66
57
# Envelope for the outcome of a Cypher query: `success` says whether it ran;
# `data` carries the returned records, or an error-message string on failure
# (defaults to None).
class QueryResult(BaseModel):
    success: bool
    data: List[Dict[Any, Any]] | str | None = None
 
70
61
 
71
62
# Configuration for Neo4jChatAgent.
class Neo4jChatAgentConfig(ChatAgentConfig):
    # connection settings for the Neo4j instance
    neo4j_settings: Neo4jSettings = Neo4jSettings()
    # template with a `{mode}` placeholder, filled in when the agent initializes
    system_message: str = DEFAULT_NEO4J_CHAT_SYSTEM_MESSAGE
    # cached graph schema; None until provided or fetched
    kg_schema: Optional[List[Dict[str, Any]]] = None
    # whether the database already holds data (or a creation query has run)
    database_created: bool = False
    # whether agent MUST use schema_tools to get schema, i.e.
    # schema is NOT initially provided
    use_schema_tools: bool = True
    use_functions_api: bool = True
    use_tools: bool = False
    # whether the agent is used in a continuous chat with user,
    # as opposed to returning a result from the task.run()
    chat_mode: bool = False
    # prefix the LLM uses to address the user (chat mode)
    addressing_prefix: str = ""
80
76
 
81
77
 
@@ -89,21 +85,48 @@ class Neo4jChatAgent(ChatAgent):
89
85
  self.config: Neo4jChatAgentConfig = config
90
86
  self._validate_config()
91
87
  self._import_neo4j()
92
- self._initialize_connection()
93
- self._init_tool_messages()
88
+ self._initialize_db()
89
+ self._init_tools_sys_message()
94
90
  self.init_state()
95
91
 
96
92
  def init_state(self) -> None:
97
93
  super().init_state()
98
94
  self.current_retrieval_cypher_query: str = ""
95
+ self.tried_schema: bool = False
99
96
 
100
97
  def handle_message_fallback(
101
98
  self, msg: str | ChatDocument
102
- ) -> str | ChatDocument | None:
103
- """When LLM sends a no-tool msg, assume user is the intended recipient."""
99
+ ) -> str | ForwardTool | None:
100
+ """
101
+ When LLM sends a no-tool msg, assume user is the intended recipient,
102
+ and if in interactive mode, forward the msg to the user.
103
+ """
104
+
105
+ done_tool_name = DoneTool.default_value("request")
106
+ forward_tool_name = ForwardTool.default_value("request")
104
107
  if isinstance(msg, ChatDocument) and msg.metadata.sender == Entity.LLM:
105
- msg.metadata.recipient = Entity.USER
106
- return msg
108
+ if self.interactive:
109
+ return ForwardTool(agent="User")
110
+ else:
111
+ if self.config.chat_mode:
112
+ return f"""
113
+ Since you did not explicitly address the User, it is not clear
114
+ whether:
115
+ - you intend this to be the final response to the
116
+ user's query/request, in which case you must use the
117
+ `{forward_tool_name}` to indicate this.
118
+ - OR, you FORGOT to use an Appropriate TOOL,
119
+ in which case you should use the available tools to
120
+ make progress on the user's query/request.
121
+ """
122
+ return f"""
123
+ The intent of your response is not clear:
124
+ - if you intended this to be the final answer to the user's query,
125
+ then use the `{done_tool_name}` to indicate so,
126
+ with the `content` set to the answer or result.
127
+ - otherwise, use one of the available tools to make progress
128
+ to arrive at the final answer.
129
+ """
107
130
  return None
108
131
 
109
132
  def _validate_config(self) -> None:
@@ -122,16 +145,9 @@ class Neo4jChatAgent(ChatAgent):
122
145
  try:
123
146
  import neo4j
124
147
  except ImportError:
125
- raise ImportError(
126
- """
127
- neo4j not installed. Please install it via:
128
- pip install neo4j.
129
- Or when installing langroid, install it with the `neo4j` extra:
130
- pip install langroid[neo4j]
131
- """
132
- )
148
+ raise LangroidImportError("neo4j", "neo4j")
133
149
 
134
- def _initialize_connection(self) -> None:
150
+ def _initialize_db(self) -> None:
135
151
  """
136
152
  Initializes a connection to the Neo4j database using the configuration settings.
137
153
  """
@@ -144,6 +160,16 @@ class Neo4jChatAgent(ChatAgent):
144
160
  self.config.neo4j_settings.password,
145
161
  ),
146
162
  )
163
+ with self.driver.session() as session:
164
+ result = session.run("MATCH (n) RETURN count(n) as count")
165
+ count = result.single()["count"] # type: ignore
166
+ self.config.database_created = count > 0
167
+
168
+ # If database has data, get schema
169
+ if self.config.database_created:
170
+ # this updates self.config.kg_schema
171
+ self.graph_schema_tool(None)
172
+
147
173
  except Exception as e:
148
174
  raise ConnectionError(f"Failed to initialize Neo4j connection: {e}")
149
175
 
@@ -226,6 +252,21 @@ class Neo4jChatAgent(ChatAgent):
226
252
  QueryResult: An object representing the outcome of the query execution.
227
253
  It contains a success flag and an optional error message.
228
254
  """
255
+ # Check if query contains database/collection creation patterns
256
+ query_upper = query.upper()
257
+ is_creation_query = any(
258
+ [
259
+ "CREATE" in query_upper,
260
+ "MERGE" in query_upper,
261
+ "CREATE CONSTRAINT" in query_upper,
262
+ "CREATE INDEX" in query_upper,
263
+ ]
264
+ )
265
+
266
+ if is_creation_query:
267
+ self.config.database_created = True
268
+ logger.info("Detected database/collection creation query")
269
+
229
270
  if not self.driver:
230
271
  return QueryResult(
231
272
  success=False, data="No database connection is established."
@@ -260,7 +301,7 @@ class Neo4jChatAgent(ChatAgent):
260
301
  else:
261
302
  print("[red]Database is not deleted!")
262
303
 
263
- def retrieval_query(self, msg: CypherRetrievalTool) -> str:
304
+ def cypher_retrieval_tool(self, msg: CypherRetrievalTool) -> str:
264
305
  """ "
265
306
  Handle a CypherRetrievalTool message by executing a Cypher query and
266
307
  returning the result.
@@ -271,13 +312,20 @@ class Neo4jChatAgent(ChatAgent):
271
312
  str: The result of executing the cypher_query.
272
313
  """
273
314
  if not self.tried_schema:
274
- return """
275
- You did not yet use the `get_schema` tool to get the schema
315
+ return f"""
316
+ You did not yet use the `{graph_schema_tool_name}` tool to get the schema
276
317
  of the neo4j knowledge-graph db. Use that tool first before using
277
- the `retrieval_query` tool, to ensure you know all the correct
318
+ the `{cypher_retrieval_tool_name}` tool, to ensure you know all the correct
278
319
  node labels, relationship types, and property keys available in
279
320
  the database.
280
321
  """
322
+ elif not self.config.database_created:
323
+ return f"""
324
+ You have not yet created the Neo4j database.
325
+ Use the `{cypher_creation_tool_name}`
326
+ tool to create the database first before using the
327
+ `{cypher_retrieval_tool_name}` tool.
328
+ """
281
329
  query = msg.cypher_query
282
330
  self.current_retrieval_cypher_query = query
283
331
  logger.info(f"Executing Cypher query: {query}")
@@ -291,7 +339,7 @@ class Neo4jChatAgent(ChatAgent):
291
339
  """
292
340
  return str(response.data)
293
341
 
294
- def create_query(self, msg: CypherCreationTool) -> str:
342
+ def cypher_creation_tool(self, msg: CypherCreationTool) -> str:
295
343
  """ "
296
344
  Handle a CypherCreationTool message by executing a Cypher query and
297
345
  returning the result.
@@ -306,6 +354,7 @@ class Neo4jChatAgent(ChatAgent):
306
354
  logger.info(f"Executing Cypher query: {query}")
307
355
  response = self.write_query(query)
308
356
  if response.success:
357
+ self.config.database_created = True
309
358
  return "Cypher query executed successfully"
310
359
  else:
311
360
  return str(response.data)
@@ -316,7 +365,9 @@ class Neo4jChatAgent(ChatAgent):
316
365
  # The current query works well. But we could use the queries here:
317
366
  # https://github.com/neo4j/NaLLM/blob/1af09cd117ba0777d81075c597a5081583568f9f/api/
318
367
  # src/driver/neo4j.py#L30
319
- def get_schema(self, msg: GraphSchemaTool | None) -> str:
368
+ def graph_schema_tool(
369
+ self, msg: GraphSchemaTool | None
370
+ ) -> str | Optional[Union[str, List[Dict[Any, Any]]]]:
320
371
  """
321
372
  Retrieves the schema of a Neo4j graph database.
322
373
 
@@ -334,27 +385,42 @@ class Neo4jChatAgent(ChatAgent):
334
385
  to database connectivity or query execution.
335
386
  """
336
387
  self.tried_schema = True
388
+ if self.config.kg_schema is not None and len(self.config.kg_schema) > 0:
389
+ return self.config.kg_schema
337
390
  schema_result = self.read_query("CALL db.schema.visualization()")
338
391
  if schema_result.success:
339
- # ther is a possibility that the schema is empty, which is a valid response
392
+ # there is a possibility that the schema is empty, which is a valid response
340
393
  # the schema.data will be: [{"nodes": [], "relationships": []}]
341
- return json.dumps(schema_result.data)
394
+ self.config.kg_schema = schema_result.data # type: ignore
395
+ return schema_result.data
342
396
  else:
343
397
  return f"Failed to retrieve schema: {schema_result.data}"
344
398
 
345
- def _init_tool_messages(self) -> None:
399
+ def _init_tools_sys_message(self) -> None:
346
400
  """Initialize message tools used for chatting."""
347
401
  self.tried_schema = False
348
402
  message = self._format_message()
349
403
  self.config.system_message = self.config.system_message.format(mode=message)
350
- if self.config.addressing_prefix != "":
404
+ if self.config.chat_mode:
405
+ self.config.addressing_prefix = self.config.addressing_prefix or SEND_TO
351
406
  self.config.system_message += ADDRESSING_INSTRUCTION.format(
352
407
  prefix=self.config.addressing_prefix
353
408
  )
409
+ else:
410
+ self.config.system_message += DONE_INSTRUCTION
354
411
  super().__init__(self.config)
355
- self.enable_message(CypherRetrievalTool)
356
- self.enable_message(CypherCreationTool)
357
- self.enable_message(GraphSchemaTool)
412
+ # Note we are enabling GraphSchemaTool regardless of whether
413
+ # self.config.use_schema_tools is True or False, because
414
+ # even when schema provided, the agent may later want to get the schema,
415
+ # e.g. if the db evolves, or if it needs to bring in the schema
416
+ self.enable_message(
417
+ [
418
+ GraphSchemaTool,
419
+ CypherRetrievalTool,
420
+ CypherCreationTool,
421
+ DoneTool,
422
+ ]
423
+ )
358
424
 
359
425
  def _format_message(self) -> str:
360
426
  if self.driver is None:
@@ -363,5 +429,5 @@ class Neo4jChatAgent(ChatAgent):
363
429
  return (
364
430
  SCHEMA_TOOLS_SYS_MSG
365
431
  if self.config.use_schema_tools
366
- else DEFAULT_SYS_MSG.format(schema=self.get_schema(None))
432
+ else SCHEMA_PROVIDED_SYS_MSG.format(schema=self.graph_schema_tool(None))
367
433
  )