langroid 0.19.5__py3-none-any.whl → 0.20.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +29 -14
- langroid/agent/special/arangodb/arangodb_agent.py +649 -0
- langroid/agent/special/arangodb/system_messages.py +183 -0
- langroid/agent/special/arangodb/tools.py +102 -0
- langroid/agent/special/arangodb/utils.py +36 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +120 -54
- langroid/agent/special/neo4j/system_messages.py +120 -0
- langroid/agent/special/neo4j/tools.py +32 -0
- langroid/agent/special/sql/sql_chat_agent.py +8 -3
- langroid/agent/task.py +78 -56
- langroid/agent/tools/orchestration.py +7 -7
- langroid/parsing/parser.py +6 -0
- {langroid-0.19.5.dist-info → langroid-0.20.1.dist-info}/METADATA +5 -1
- {langroid-0.19.5.dist-info → langroid-0.20.1.dist-info}/RECORD +18 -13
- pyproject.toml +7 -1
- langroid/agent/special/neo4j/utils/system_message.py +0 -64
- /langroid/agent/special/{neo4j/utils → arangodb}/__init__.py +0 -0
- {langroid-0.19.5.dist-info → langroid-0.20.1.dist-info}/LICENSE +0 -0
- {langroid-0.19.5.dist-info → langroid-0.20.1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,183 @@
|
|
1
|
+
from langroid.agent.special.arangodb.tools import (
|
2
|
+
aql_creation_tool_name,
|
3
|
+
aql_retrieval_tool_name,
|
4
|
+
arango_schema_tool_name,
|
5
|
+
)
|
6
|
+
from langroid.agent.tools.orchestration import DoneTool
|
7
|
+
|
8
|
+
# `request` name of the DoneTool; interpolated into DONE_INSTRUCTION below.
done_tool_name = DoneTool.default_value("request")

# Short descriptions of each tool, interpolated into the system-message
# templates below so the LLM knows when to use which tool.
arango_schema_tool_description = f"""
`{arango_schema_tool_name}` tool/function-call to find the schema
of the graph database, or for some SPECIFIC collections, i.e. get information on
collections (document and edge), their attributes, and graph definitions available
in your ArangoDB database. You MUST use this tool BEFORE attempting to use the
`{aql_retrieval_tool_name}` tool/function-call, to ensure that you are using the
correct collection names and attributes in your `{aql_retrieval_tool_name}` tool.
"""

aql_retrieval_tool_description = f"""
`{aql_retrieval_tool_name}` tool/function-call to retrieve information from
the database using AQL (ArangoDB Query Language) queries, to answer
the user's questions, OR for you to learn more about the SCHEMA of the database.
"""

aql_creation_tool_description = f"""
`{aql_creation_tool_name}` tool/function-call to execute an AQL query that creates
documents/edges in the database.
"""
|
29
|
+
|
30
|
+
# Worked example of the iterative query -> results -> refine loop; it is
# embedded in DEFAULT_ARANGO_CHAT_SYSTEM_MESSAGE below. Although this is a
# plain (non-f) string, it ends up inside a template that still contains a
# `{mode}` placeholder, so it is presumably passed through str.format() later.
# That is why literal braces are doubled here, and every `{{` must be
# balanced by exactly one `}}`.
aql_retrieval_query_example = """
EXAMPLE:
Suppose you are asked this question "Does Bob have a father?".
Then you will go through the following steps, where YOU indicates
the message YOU will be sending, and RESULTS indicates the RESULTS
you will receive from the helper executing the query:

1. YOU:
{{ "request": "aql_retrieval_tool",
"aql_query": "FOR v, e, p in ... [query truncated for brevity]..."}}

2. RESULTS:
[.. results from the query...]
3. YOU: [ since results were not satisfactory, you try ANOTHER query]
{{ "request": "aql_retrieval_tool",
"aql_query": "blah blah ... [query truncated for brevity]..."}}
4. RESULTS:
[.. results from the query...]
5. YOU: [ now you have the answer, you can generate your response ]
The answer is YES, Bob has a father, and his name is John.
"""
|
52
|
+
|
53
|
+
# Generic guidance on writing AQL queries; interpolated into
# DEFAULT_ARANGO_CHAT_SYSTEM_MESSAGE below.
# NOTE(review): items 2 and 4 (and "usually 'nodes' and 'edges'") assume the
# data is modelled with a `type` attribute on documents in 'nodes'/'edges'
# collections; confirm this matches the databases this agent targets, since
# a forced `FILTER doc.type == ...` returns nothing on other data models.
aql_query_instructions = """
When writing AQL queries:
1. Use the exact property names shown in the schema
2. Pay attention to the 'type' field of each node
3. Note that all names are case-sensitive:
- collection names
- property names
- node type values
- relationship type values
4. Always include type filters in your queries, e.g.:
FILTER doc.type == '<type-from-schema>'

The schema shows:
- Collections (usually 'nodes' and 'edges')
- Node types in each collection
- Available properties for each node type
- Relationship types and their properties

Examine the schema carefully before writing queries to ensure:
- Correct property names
- Correct node types
- Correct relationship types

You must be smart about using the right collection names and attributes
based on the English description. If you are thinking of using a collection
or attribute that does not exist, you are probably on the wrong track,
so you should try your best to answer based on existing collections and attributes.
DO NOT assume any collections or graphs other than those above.
"""

# Reminder interpolated into the system-message templates below: always use a
# tool, wait for its result, never answer from own knowledge, one tool at a time.
tool_result_instruction = """
REMEMBER:
[1] DO NOT FORGET TO USE ONE OF THE AVAILABLE TOOLS TO ANSWER THE USER'S QUERY!!
[2] When using a TOOL/FUNCTION, you MUST WAIT for the tool result before continuing
with your response. DO NOT MAKE UP RESULTS FROM A TOOL!
[3] YOU MUST NOT ANSWER queries from your OWN KNOWLEDGE; ALWAYS RELY ON
the result of a TOOL/FUNCTION to compose your response.
[4] Use ONLY ONE TOOL/FUNCTION at a TIME!
"""
|
92
|
+
# sys msg to use when schema already provided initially,
# so agent should not use schema tool
# (the doubled braces leave a literal `{schema}` placeholder in the result;
# presumably filled in later via str.format(schema=...) -- confirm in agent code)
SCHEMA_PROVIDED_SYS_MSG = f"""You are a data scientist and expert in Graph Databases,
with expertise in answering questions by interacting with an ArangoDB database.

The schema below describes the ArangoDB database structure,
collections (document and edge),
and their attribute keys available in your ArangoDB database.

=== SCHEMA ===
{{schema}}
=== END SCHEMA ===


To help with the user's question or database update/creation request,
you have access to these tools:

- {aql_retrieval_tool_description}

- {aql_creation_tool_description}


{tool_result_instruction}
"""

# sys msg to use when schema is not initially provided,
# and we want agent to use schema tool to get schema
# (unlike SCHEMA_PROVIDED_SYS_MSG, this also advertises the schema tool)
SCHEMA_TOOLS_SYS_MSG = f"""You are a data scientist and expert in
Arango Graph Databases,
with expertise in answering questions by querying ArangoDB database
using the Arango Query Language (AQL).
You have access to the following tools:

- {arango_schema_tool_description}

- {aql_retrieval_tool_description}

- {aql_creation_tool_description}

{tool_result_instruction}
"""
|
133
|
+
|
134
|
+
# Main system-message template. `{{mode}}` survives this f-string as a literal
# `{mode}` placeholder; presumably the agent fills it (e.g. with one of the
# SCHEMA_*_SYS_MSG variants above) via str.format(mode=...) -- confirm in the
# agent code. Because of that later format() pass, every interpolated piece
# must either be brace-free or use doubled braces.
DEFAULT_ARANGO_CHAT_SYSTEM_MESSAGE = f"""
{{mode}}

You do not need to be able to answer a question with just one query.
You can make a query, WAIT for the result,
THEN make ANOTHER query, WAIT for result,
THEN make ANOTHER query, and so on, until you have the answer.

{aql_query_instructions}

RETRY-SUGGESTIONS:
If you receive a null or other unexpected result,
(a) make sure you use the available TOOLs correctly,
(b) USE `{arango_schema_tool_name}` tool/function-call to get all collections,
their attributes and graph definitions available in your ArangoDB database.
(c) Collection names are CASE-SENSITIVE -- make sure you adhere to the exact
collection name you found in the schema.
(d) see if you have made an assumption in your AQL query, and try another way,
or use `{aql_retrieval_tool_name}` to explore the database contents before
submitting your final query.
(e) Try APPROXIMATE or PARTIAL MATCHES to strings in the user's query,
e.g. user may ask about "Godfather" instead of "The Godfather",
or try using CASE-INSENSITIVE MATCHES.

Start by asking what the user needs help with.

{tool_result_instruction}

{aql_retrieval_query_example}
"""
|
164
|
+
|
165
|
+
# Instruction telling the LLM to address the user explicitly whenever it is
# not issuing an AQL tool call. This is a plain string (not an f-string):
# `{prefix}` is a str.format placeholder, presumably filled with the agent's
# addressing prefix -- confirm in agent code.
ADDRESSING_INSTRUCTION = """
IMPORTANT - Whenever you are NOT writing an AQL query, make sure you address the
user using {prefix}User. You MUST use the EXACT syntax {prefix} !!!

In other words, you ALWAYS EITHER:
- write an AQL query using one of the tools,
- OR address the user using {prefix}User.

YOU CANNOT ADDRESS THE USER WHEN USING A TOOL!!
"""

# Instruction telling the LLM to finish by emitting the DoneTool, with the
# answer in its `content` field, only once it is sure of the answer.
DONE_INSTRUCTION = f"""
When you are SURE you have the CORRECT answer to a user's query or request,
use the `{done_tool_name}` with `content` set to the answer or result.
If you DO NOT think you have the answer to the user's query or request,
you SHOULD NOT use the `{done_tool_name}` tool.
Instead, you must CONTINUE to improve your queries (tools) to get the correct answer,
and finally use the `{done_tool_name}` tool to send the correct answer to the user.
"""
|
@@ -0,0 +1,102 @@
|
|
1
|
+
from typing import List, Tuple
|
2
|
+
|
3
|
+
from langroid.agent.tool_message import ToolMessage
|
4
|
+
|
5
|
+
|
6
|
+
class AQLRetrievalTool(ToolMessage):
    # Tool the LLM emits to run an AQL query that retrieves data (or explores
    # the schema); the LLM must wait for the results before continuing.
    request: str = "aql_retrieval_tool"
    purpose: str = """
    To send an <aql_query> in response to a user's request/question,
    OR to find SCHEMA information,
    and WAIT for results of the <aql_query> BEFORE continuing with response.
    You will receive RESULTS from this tool, and ONLY THEN you can continue.
    """
    # the AQL query text to execute
    aql_query: str

    @classmethod
    def examples(cls) -> List[ToolMessage | Tuple[str, ToolMessage]]:
        """Few-shot examples to include in tool instructions.

        Each example pairs a natural-language request with the tool call
        that serves it: a one-hop graph traversal, and a query that lists
        the attribute names of a collection's documents.
        """
        return [
            (
                "I want to see who Bob's Father is",
                cls(
                    aql_query="""
                    FOR v, e, p IN 1..1 OUTBOUND 'users/Bob' GRAPH 'family_tree'
                    FILTER p.edges[0].type == 'father'
                    RETURN v
                    """
                ),
            ),
            (
                "I want to know the properties of the Actor node",
                cls(
                    aql_query="""
                    FOR doc IN Actor
                    LIMIT 1
                    RETURN ATTRIBUTES(doc)
                    """
                ),
            ),
        ]

    @classmethod
    def instructions(cls) -> str:
        """Extra usage guidance: wait for the query results before responding."""
        return """
        When using this TOOL/Function-call, you must WAIT to receive the RESULTS
        of the AQL query, before continuing your response!
        DO NOT ASSUME YOU KNOW THE RESULTs BEFORE RECEIVING THEM.
        """


# `request` name of AQLRetrievalTool; interpolated into the system messages.
aql_retrieval_tool_name = AQLRetrievalTool.default_value("request")
|
52
|
+
|
53
|
+
|
54
|
+
class AQLCreationTool(ToolMessage):
    # Tool the LLM emits to run an AQL query that creates documents/edges;
    # the LLM must wait for the result before continuing.
    request: str = "aql_creation_tool"
    purpose: str = """
    To send the <aql_query> to create documents/edges in the graph database.
    IMPORTANT: YOU MUST WAIT FOR THE RESULT OF THE TOOL BEFORE CONTINUING.
    You will receive RESULTS from this tool, and ONLY THEN you can continue.
    """
    # the AQL (write) query text to execute
    aql_query: str

    @classmethod
    def examples(cls) -> List[ToolMessage | Tuple[str, ToolMessage]]:
        """Few-shot examples to include in tool instructions.

        A single example: inserting one document into the `users` collection.
        """
        return [
            (
                "Create a new document in the collection 'users'",
                cls(
                    aql_query="""
                    INSERT {
                        "name": "Alice",
                        "age": 30
                    } INTO users
                    """
                ),
            ),
        ]


# `request` name of AQLCreationTool; interpolated into the system messages.
aql_creation_tool_name = AQLCreationTool.default_value("request")
|
82
|
+
|
83
|
+
|
84
|
+
class ArangoSchemaTool(ToolMessage):
    # Tool the LLM emits to fetch the ArangoDB schema (whole or partial).
    # NOTE: documented with comments rather than a class docstring, since
    # langroid derives tool/parameter descriptions from ToolMessage docstrings.
    request: str = "arango_schema_tool"
    purpose: str = """
    To get the schema of the Arango graph database,
    or some part of it. Follow these instructions:
    1. Set <properties> to True to get the properties of the collections,
    and False if you only want to see the graph structure and get only the
    from/to relations of the edges.
    2. Set <collections> to a list of collection names if you want to see
    only those collections, or leave it as None to see ALL collections.
    IMPORTANT: YOU MUST WAIT FOR THE RESULT OF THE TOOL BEFORE CONTINUING.
    You will receive RESULTS from this tool, and ONLY THEN you can continue.
    """

    # True: include collection properties; False: only the graph structure
    # and the from/to relations of edges (as described in `purpose`)
    properties: bool = True
    # restrict the schema to these collections; None means all collections
    collections: List[str] | None = None


# `request` name of ArangoSchemaTool; interpolated into the system messages.
arango_schema_tool_name = ArangoSchemaTool.default_value("request")
|
@@ -0,0 +1,36 @@
|
|
1
|
+
from typing import Any, Dict, List
|
2
|
+
|
3
|
+
|
4
|
+
def count_fields(schema: Dict[str, List[Dict[str, Any]]]) -> int:
    """Return a rough size measure of an ArangoDB schema dict.

    For every entry of ``schema["Collection Schema"]`` this adds the number of
    top-level keys in the entry plus the length of its
    ``"<collection_type>_properties"`` value (nothing if that key is absent).

    Raises:
        KeyError: if ``"Collection Schema"`` is missing, or an entry lacks
            ``"collection_type"``.
    """
    return sum(
        len(entry) + len(entry.get(f"{entry['collection_type']}_properties", []))
        for entry in schema["Collection Schema"]
    )
|
13
|
+
|
14
|
+
|
15
|
+
def trim_schema(
    schema: Dict[str, List[Dict[str, Any]]]
) -> Dict[str, List[Dict[str, Any]]]:
    """Return a slimmed-down version of an ArangoDB schema dict.

    Every collection keeps only its ``collection_name`` and ``collection_type``.
    For edge collections whose ``example_edge`` document is truthy and has a
    ``_from`` field, ``from_collection``/``to_collection`` are derived from the
    ``_from``/``_to`` handles (the part before the first ``/``). Properties and
    example documents are dropped. ``"Graph Schema"`` is passed through as the
    very same object.
    """
    graph_part = schema["Graph Schema"]
    slim_collections: List[Dict[str, Any]] = []
    for entry in schema["Collection Schema"]:
        slim: Dict[str, Any] = {
            "collection_name": entry["collection_name"],
            "collection_type": entry["collection_type"],
        }
        # only edge collections carry from/to connection info
        sample = (
            entry.get("example_edge") if slim["collection_type"] == "edge" else None
        )
        if sample and "_from" in sample:
            slim["from_collection"] = sample["_from"].split("/")[0]
            slim["to_collection"] = sample["_to"].split("/")[0]
        slim_collections.append(slim)
    return {"Graph Schema": graph_part, "Collection Schema": slim_collections}
|
@@ -1,11 +1,9 @@
|
|
1
|
-
import json
|
2
1
|
import logging
|
3
2
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
4
3
|
|
5
4
|
from rich import print
|
6
5
|
from rich.console import Console
|
7
6
|
|
8
|
-
from langroid.agent import ToolMessage
|
9
7
|
from langroid.pydantic_v1 import BaseModel, BaseSettings
|
10
8
|
|
11
9
|
if TYPE_CHECKING:
|
@@ -13,13 +11,25 @@ if TYPE_CHECKING:
|
|
13
11
|
|
14
12
|
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
|
15
13
|
from langroid.agent.chat_document import ChatDocument
|
16
|
-
from langroid.agent.special.neo4j.
|
14
|
+
from langroid.agent.special.neo4j.system_messages import (
|
17
15
|
ADDRESSING_INSTRUCTION,
|
18
16
|
DEFAULT_NEO4J_CHAT_SYSTEM_MESSAGE,
|
19
|
-
|
17
|
+
DONE_INSTRUCTION,
|
18
|
+
SCHEMA_PROVIDED_SYS_MSG,
|
20
19
|
SCHEMA_TOOLS_SYS_MSG,
|
21
20
|
)
|
21
|
+
from langroid.agent.special.neo4j.tools import (
|
22
|
+
CypherCreationTool,
|
23
|
+
CypherRetrievalTool,
|
24
|
+
GraphSchemaTool,
|
25
|
+
cypher_creation_tool_name,
|
26
|
+
cypher_retrieval_tool_name,
|
27
|
+
graph_schema_tool_name,
|
28
|
+
)
|
29
|
+
from langroid.agent.tools.orchestration import DoneTool, ForwardTool
|
30
|
+
from langroid.exceptions import LangroidImportError
|
22
31
|
from langroid.mytypes import Entity
|
32
|
+
from langroid.utils.constants import SEND_TO
|
23
33
|
|
24
34
|
logger = logging.getLogger(__name__)
|
25
35
|
|
@@ -31,25 +41,6 @@ NEO4J_ERROR_MSG = "There was an error in your Cypher Query"
|
|
31
41
|
# TOOLS to be used by the agent
|
32
42
|
|
33
43
|
|
34
|
-
class CypherRetrievalTool(ToolMessage):
|
35
|
-
request: str = "retrieval_query"
|
36
|
-
purpose: str = """Use this tool to send the Cypher query to retreive data from the
|
37
|
-
graph database based provided text description and schema."""
|
38
|
-
cypher_query: str
|
39
|
-
|
40
|
-
|
41
|
-
class CypherCreationTool(ToolMessage):
|
42
|
-
request: str = "create_query"
|
43
|
-
purpose: str = """Use this tool to send the Cypher query to create
|
44
|
-
entities/relationships in the graph database."""
|
45
|
-
cypher_query: str
|
46
|
-
|
47
|
-
|
48
|
-
class GraphSchemaTool(ToolMessage):
|
49
|
-
request: str = "get_schema"
|
50
|
-
purpose: str = """To get the schema of the graph database."""
|
51
|
-
|
52
|
-
|
53
44
|
class Neo4jSettings(BaseSettings):
|
54
45
|
uri: str = ""
|
55
46
|
username: str = ""
|
@@ -65,17 +56,22 @@ class Neo4jSettings(BaseSettings):
|
|
65
56
|
|
66
57
|
class QueryResult(BaseModel):
    """Outcome of executing a Cypher query.

    `success` says whether the query ran without error. `data` holds the
    returned records on success; on failure the agent stores an error or
    diagnostic message string there instead.
    """

    success: bool
    data: Optional[Union[List[Dict[Any, Any]], str]] = None
|
69
60
|
|
70
61
|
|
71
62
|
class Neo4jChatAgentConfig(ChatAgentConfig):
    """Configuration for `Neo4jChatAgent`."""

    neo4j_settings: Neo4jSettings = Neo4jSettings()
    # template with a `{mode}` placeholder, filled in by the agent when it
    # builds its final system message
    system_message: str = DEFAULT_NEO4J_CHAT_SYSTEM_MESSAGE
    # cached knowledge-graph schema; filled in by the agent once it fetches
    # the schema
    kg_schema: Optional[List[Dict[str, Any]]] = None
    # whether the db already holds data (or this agent has written to it)
    database_created: bool = False
    # whether agent MUST use schema_tools to get schema, i.e.
    # schema is NOT initially provided
    use_schema_tools: bool = True
    use_functions_api: bool = True
    use_tools: bool = False
    # whether the agent is used in a continuous chat with user,
    # as opposed to returning a result from the task.run()
    chat_mode: bool = False
    # prefix the LLM uses to address the user in chat_mode; when left empty
    # the agent falls back to SEND_TO
    addressing_prefix: str = ""
|
80
76
|
|
81
77
|
|
@@ -89,21 +85,48 @@ class Neo4jChatAgent(ChatAgent):
|
|
89
85
|
self.config: Neo4jChatAgentConfig = config
|
90
86
|
self._validate_config()
|
91
87
|
self._import_neo4j()
|
92
|
-
self.
|
93
|
-
self.
|
88
|
+
self._initialize_db()
|
89
|
+
self._init_tools_sys_message()
|
94
90
|
self.init_state()
|
95
91
|
|
96
92
|
def init_state(self) -> None:
    """
    (Re)initialize per-conversation state of the agent.

    `tried_schema` gates `cypher_retrieval_tool`, which refuses to run a
    query until the schema is known. When the schema is injected into the
    system message up front (`use_schema_tools=False`), the agent already
    has it, so the gate starts open. Otherwise the LLM must first call the
    schema tool. Previously this was unconditionally reset to False here,
    after `__init__` had already built a schema-provided system message.
    That made the agent reject the first query even though the schema was
    in its prompt.
    """
    super().init_state()
    self.current_retrieval_cypher_query: str = ""
    # `config` may not be set yet if a base-class constructor calls
    # init_state() early; fall back to the conservative default (gate closed).
    config = getattr(self, "config", None)
    use_schema_tools = getattr(config, "use_schema_tools", True)
    self.tried_schema: bool = not use_schema_tools
|
99
96
|
|
100
97
|
def handle_message_fallback(
    self, msg: str | ChatDocument
) -> str | ForwardTool | None:
    """
    When LLM sends a no-tool msg, assume user is the intended recipient,
    and if in interactive mode, forward the msg to the user.

    In non-interactive mode the LLM is instead nudged, via the returned
    string, either to signal a final answer or to use a tool:
    - in `chat_mode`, a user-directed answer is signalled by addressing the
      User with the configured addressing prefix;
    - otherwise, by using the DoneTool.

    Args:
        msg: the message that was not handled by any tool.

    Returns:
        `ForwardTool(agent="User")` in interactive mode, a corrective
        instruction string in non-interactive mode, or None if the message
        did not come from the LLM.
    """
    done_tool_name = DoneTool.default_value("request")
    if isinstance(msg, ChatDocument) and msg.metadata.sender == Entity.LLM:
        if self.interactive:
            return ForwardTool(agent="User")
        if self.config.chat_mode:
            # ForwardTool is not among the tools enabled for the LLM (see
            # `_init_tools_sys_message`), so don't tell the LLM to use it. In
            # chat_mode the LLM reaches the user by addressing them with the
            # addressing prefix, which defaults to SEND_TO.
            prefix = self.config.addressing_prefix or SEND_TO
            return f"""
            Since you did not explicitly address the User, it is not clear
            whether:
            - you intend this to be the final response to the
              user's query/request, in which case you must address the
              User using `{prefix}User` to indicate this.
            - OR, you FORGOT to use an Appropriate TOOL,
              in which case you should use the available tools to
              make progress on the user's query/request.
            """
        return f"""
        The intent of your response is not clear:
        - if you intended this to be the final answer to the user's query,
            then use the `{done_tool_name}` to indicate so,
            with the `content` set to the answer or result.
        - otherwise, use one of the available tools to make progress
            to arrive at the final answer.
        """
    return None
|
108
131
|
|
109
132
|
def _validate_config(self) -> None:
|
@@ -122,16 +145,9 @@ class Neo4jChatAgent(ChatAgent):
|
|
122
145
|
try:
|
123
146
|
import neo4j
|
124
147
|
except ImportError:
|
125
|
-
raise
|
126
|
-
"""
|
127
|
-
neo4j not installed. Please install it via:
|
128
|
-
pip install neo4j.
|
129
|
-
Or when installing langroid, install it with the `neo4j` extra:
|
130
|
-
pip install langroid[neo4j]
|
131
|
-
"""
|
132
|
-
)
|
148
|
+
raise LangroidImportError("neo4j", "neo4j")
|
133
149
|
|
134
|
-
def
|
150
|
+
def _initialize_db(self) -> None:
|
135
151
|
"""
|
136
152
|
Initializes a connection to the Neo4j database using the configuration settings.
|
137
153
|
"""
|
@@ -144,6 +160,16 @@ class Neo4jChatAgent(ChatAgent):
|
|
144
160
|
self.config.neo4j_settings.password,
|
145
161
|
),
|
146
162
|
)
|
163
|
+
with self.driver.session() as session:
|
164
|
+
result = session.run("MATCH (n) RETURN count(n) as count")
|
165
|
+
count = result.single()["count"] # type: ignore
|
166
|
+
self.config.database_created = count > 0
|
167
|
+
|
168
|
+
# If database has data, get schema
|
169
|
+
if self.config.database_created:
|
170
|
+
# this updates self.config.kg_schema
|
171
|
+
self.graph_schema_tool(None)
|
172
|
+
|
147
173
|
except Exception as e:
|
148
174
|
raise ConnectionError(f"Failed to initialize Neo4j connection: {e}")
|
149
175
|
|
@@ -226,6 +252,21 @@ class Neo4jChatAgent(ChatAgent):
|
|
226
252
|
QueryResult: An object representing the outcome of the query execution.
|
227
253
|
It contains a success flag and an optional error message.
|
228
254
|
"""
|
255
|
+
# Check if query contains database/collection creation patterns
|
256
|
+
query_upper = query.upper()
|
257
|
+
is_creation_query = any(
|
258
|
+
[
|
259
|
+
"CREATE" in query_upper,
|
260
|
+
"MERGE" in query_upper,
|
261
|
+
"CREATE CONSTRAINT" in query_upper,
|
262
|
+
"CREATE INDEX" in query_upper,
|
263
|
+
]
|
264
|
+
)
|
265
|
+
|
266
|
+
if is_creation_query:
|
267
|
+
self.config.database_created = True
|
268
|
+
logger.info("Detected database/collection creation query")
|
269
|
+
|
229
270
|
if not self.driver:
|
230
271
|
return QueryResult(
|
231
272
|
success=False, data="No database connection is established."
|
@@ -260,7 +301,7 @@ class Neo4jChatAgent(ChatAgent):
|
|
260
301
|
else:
|
261
302
|
print("[red]Database is not deleted!")
|
262
303
|
|
263
|
-
def
|
304
|
+
def cypher_retrieval_tool(self, msg: CypherRetrievalTool) -> str:
|
264
305
|
""" "
|
265
306
|
Handle a CypherRetrievalTool message by executing a Cypher query and
|
266
307
|
returning the result.
|
@@ -271,13 +312,20 @@ class Neo4jChatAgent(ChatAgent):
|
|
271
312
|
str: The result of executing the cypher_query.
|
272
313
|
"""
|
273
314
|
if not self.tried_schema:
|
274
|
-
return """
|
275
|
-
You did not yet use the `
|
315
|
+
return f"""
|
316
|
+
You did not yet use the `{graph_schema_tool_name}` tool to get the schema
|
276
317
|
of the neo4j knowledge-graph db. Use that tool first before using
|
277
|
-
the `
|
318
|
+
the `{cypher_retrieval_tool_name}` tool, to ensure you know all the correct
|
278
319
|
node labels, relationship types, and property keys available in
|
279
320
|
the database.
|
280
321
|
"""
|
322
|
+
elif not self.config.database_created:
|
323
|
+
return f"""
|
324
|
+
You have not yet created the Neo4j database.
|
325
|
+
Use the `{cypher_creation_tool_name}`
|
326
|
+
tool to create the database first before using the
|
327
|
+
`{cypher_retrieval_tool_name}` tool.
|
328
|
+
"""
|
281
329
|
query = msg.cypher_query
|
282
330
|
self.current_retrieval_cypher_query = query
|
283
331
|
logger.info(f"Executing Cypher query: {query}")
|
@@ -291,7 +339,7 @@ class Neo4jChatAgent(ChatAgent):
|
|
291
339
|
"""
|
292
340
|
return str(response.data)
|
293
341
|
|
294
|
-
def
|
342
|
+
def cypher_creation_tool(self, msg: CypherCreationTool) -> str:
|
295
343
|
""" "
|
296
344
|
Handle a CypherCreationTool message by executing a Cypher query and
|
297
345
|
returning the result.
|
@@ -306,6 +354,7 @@ class Neo4jChatAgent(ChatAgent):
|
|
306
354
|
logger.info(f"Executing Cypher query: {query}")
|
307
355
|
response = self.write_query(query)
|
308
356
|
if response.success:
|
357
|
+
self.config.database_created = True
|
309
358
|
return "Cypher query executed successfully"
|
310
359
|
else:
|
311
360
|
return str(response.data)
|
@@ -316,7 +365,9 @@ class Neo4jChatAgent(ChatAgent):
|
|
316
365
|
# The current query works well. But we could use the queries here:
|
317
366
|
# https://github.com/neo4j/NaLLM/blob/1af09cd117ba0777d81075c597a5081583568f9f/api/
|
318
367
|
# src/driver/neo4j.py#L30
|
319
|
-
def
|
368
|
+
def graph_schema_tool(
|
369
|
+
self, msg: GraphSchemaTool | None
|
370
|
+
) -> str | Optional[Union[str, List[Dict[Any, Any]]]]:
|
320
371
|
"""
|
321
372
|
Retrieves the schema of a Neo4j graph database.
|
322
373
|
|
@@ -334,27 +385,42 @@ class Neo4jChatAgent(ChatAgent):
|
|
334
385
|
to database connectivity or query execution.
|
335
386
|
"""
|
336
387
|
self.tried_schema = True
|
388
|
+
if self.config.kg_schema is not None and len(self.config.kg_schema) > 0:
|
389
|
+
return self.config.kg_schema
|
337
390
|
schema_result = self.read_query("CALL db.schema.visualization()")
|
338
391
|
if schema_result.success:
|
339
|
-
#
|
392
|
+
# there is a possibility that the schema is empty, which is a valid response
|
340
393
|
# the schema.data will be: [{"nodes": [], "relationships": []}]
|
341
|
-
|
394
|
+
self.config.kg_schema = schema_result.data # type: ignore
|
395
|
+
return schema_result.data
|
342
396
|
else:
|
343
397
|
return f"Failed to retrieve schema: {schema_result.data}"
|
344
398
|
|
345
|
-
def
|
399
|
+
def _init_tools_sys_message(self) -> None:
    """
    Initialize message tools used for chatting.

    Fills the `{mode}` placeholder of the configured system message with the
    text returned by `_format_message`. Then appends either the addressing
    instructions (in `chat_mode`) or the done-tool instructions. Re-runs the
    base-class constructor with the updated config and enables the schema,
    Cypher retrieval/creation, and Done tools.
    """
    # reset before _format_message, which may call graph_schema_tool
    # (setting tried_schema to True) when the schema is provided up front
    self.tried_schema = False
    message = self._format_message()
    self.config.system_message = self.config.system_message.format(mode=message)
    if self.config.chat_mode:
        # chat mode: the LLM must address the user explicitly; default the
        # prefix to SEND_TO when none was configured
        self.config.addressing_prefix = self.config.addressing_prefix or SEND_TO
        self.config.system_message += ADDRESSING_INSTRUCTION.format(
            prefix=self.config.addressing_prefix
        )
    else:
        # task mode: the LLM ends the task via the DoneTool
        self.config.system_message += DONE_INSTRUCTION
    # NOTE(review): re-invokes the base constructor after the system message
    # was rebuilt -- presumably so the new system message takes effect; confirm
    # this does not discard state set earlier in __init__.
    super().__init__(self.config)
    # Note we are enabling GraphSchemaTool regardless of whether
    # self.config.use_schema_tools is True or False, because
    # even when schema provided, the agent may later want to get the schema,
    # e.g. if the db evolves, or if it needs to bring in the schema
    self.enable_message(
        [
            GraphSchemaTool,
            CypherRetrievalTool,
            CypherCreationTool,
            DoneTool,
        ]
    )
|
358
424
|
|
359
425
|
def _format_message(self) -> str:
|
360
426
|
if self.driver is None:
|
@@ -363,5 +429,5 @@ class Neo4jChatAgent(ChatAgent):
|
|
363
429
|
return (
|
364
430
|
SCHEMA_TOOLS_SYS_MSG
|
365
431
|
if self.config.use_schema_tools
|
366
|
-
else
|
432
|
+
else SCHEMA_PROVIDED_SYS_MSG.format(schema=self.graph_schema_tool(None))
|
367
433
|
)
|