langroid 0.1.139__py3-none-any.whl → 0.1.219__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +70 -0
- langroid/agent/__init__.py +22 -0
- langroid/agent/base.py +120 -33
- langroid/agent/batch.py +134 -35
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +608 -0
- langroid/agent/chat_agent.py +164 -100
- langroid/agent/chat_document.py +19 -2
- langroid/agent/openai_assistant.py +20 -10
- langroid/agent/special/__init__.py +33 -10
- langroid/agent/special/doc_chat_agent.py +521 -108
- langroid/agent/special/lance_doc_chat_agent.py +258 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +136 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
- langroid/agent/special/lance_tools.py +44 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
- langroid/agent/special/neo4j/utils/__init__.py +0 -0
- langroid/agent/special/neo4j/utils/system_message.py +46 -0
- langroid/agent/special/relevance_extractor_agent.py +23 -7
- langroid/agent/special/retriever_agent.py +29 -174
- langroid/agent/special/sql/__init__.py +7 -0
- langroid/agent/special/sql/sql_chat_agent.py +47 -23
- langroid/agent/special/sql/utils/__init__.py +11 -0
- langroid/agent/special/sql/utils/description_extractors.py +95 -46
- langroid/agent/special/sql/utils/populate_metadata.py +28 -21
- langroid/agent/special/table_chat_agent.py +43 -9
- langroid/agent/task.py +423 -114
- langroid/agent/tool_message.py +67 -10
- langroid/agent/tools/__init__.py +8 -0
- langroid/agent/tools/duckduckgo_search_tool.py +66 -0
- langroid/agent/tools/google_search_tool.py +11 -0
- langroid/agent/tools/metaphor_search_tool.py +67 -0
- langroid/agent/tools/recipient_tool.py +6 -24
- langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
- langroid/cachedb/__init__.py +6 -0
- langroid/embedding_models/__init__.py +24 -0
- langroid/embedding_models/base.py +9 -1
- langroid/embedding_models/models.py +117 -17
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/language_models/__init__.py +22 -0
- langroid/language_models/azure_openai.py +47 -4
- langroid/language_models/base.py +26 -10
- langroid/language_models/config.py +5 -0
- langroid/language_models/openai_gpt.py +407 -121
- langroid/language_models/prompt_formatter/__init__.py +9 -0
- langroid/language_models/prompt_formatter/base.py +4 -6
- langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
- langroid/language_models/utils.py +10 -9
- langroid/mytypes.py +10 -4
- langroid/parsing/__init__.py +33 -1
- langroid/parsing/document_parser.py +259 -63
- langroid/parsing/image_text.py +32 -0
- langroid/parsing/parse_json.py +143 -0
- langroid/parsing/parser.py +20 -7
- langroid/parsing/repo_loader.py +108 -46
- langroid/parsing/search.py +8 -0
- langroid/parsing/table_loader.py +44 -0
- langroid/parsing/url_loader.py +59 -13
- langroid/parsing/urls.py +18 -9
- langroid/parsing/utils.py +130 -9
- langroid/parsing/web_search.py +73 -0
- langroid/prompts/__init__.py +7 -0
- langroid/prompts/chat-gpt4-system-prompt.md +68 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/utils/__init__.py +10 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/configuration.py +0 -1
- langroid/utils/constants.py +4 -0
- langroid/utils/logging.py +2 -5
- langroid/utils/output/__init__.py +15 -2
- langroid/utils/output/status.py +33 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +446 -4
- langroid/utils/system.py +36 -1
- langroid/vector_store/__init__.py +34 -2
- langroid/vector_store/base.py +33 -2
- langroid/vector_store/chromadb.py +42 -13
- langroid/vector_store/lancedb.py +226 -60
- langroid/vector_store/meilisearch.py +7 -6
- langroid/vector_store/momento.py +3 -2
- langroid/vector_store/qdrantdb.py +82 -11
- {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/METADATA +190 -129
- langroid-0.1.219.dist-info/RECORD +127 -0
- langroid/agent/special/recipient_validator_agent.py +0 -157
- langroid/parsing/json.py +0 -64
- langroid/utils/web/selenium_login.py +0 -36
- langroid-0.1.139.dist-info/RECORD +0 -103
- {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
- {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/WHEEL +0 -0
@@ -0,0 +1,44 @@
|
|
1
|
+
import logging
|
2
|
+
|
3
|
+
from pydantic import BaseModel
|
4
|
+
|
5
|
+
from langroid.agent.tool_message import ToolMessage
|
6
|
+
|
7
|
+
logger = logging.getLogger(__name__)
|
8
|
+
|
9
|
+
|
10
|
+
class QueryPlan(BaseModel):
    """Structured plan for answering a user's query."""

    # The user's query as originally asked, kept for reference.
    original_query: str
    # Query text (possibly rephrased) used to match document content.
    query: str
    # Filter condition; an empty string means no filter is needed.
    filter: str
    # Pandas-dataframe calculation/aggregation string; empty when no
    # calculation is needed.
    dataframe_calc: str = ""
|
15
|
+
|
16
|
+
|
17
|
+
class QueryPlanTool(ToolMessage):
    """Tool message with which the LLM submits a `QueryPlan` for a user query."""

    # `request` and `purpose` carry explicit `str` annotations so these field
    # overrides are declared the same way as in the other ToolMessage
    # subclasses of this package.
    request: str = "query_plan"  # the agent method name that handles this tool
    purpose: str = """
    Given a user's query, generate a query <plan> consisting of:
    - <original_query> - the original query for reference
    - <filter> condition if needed (or empty string if no filter is needed)
    - <query> - a possibly rephrased query that can be used to match the CONTENT
    of the documents (can be same as <original_query> if no rephrasing is needed)
    - <dataframe_calc> - a Pandas-dataframe calculation/aggregation string
    that can be used to calculate the answer
    (or empty string if no calculation is needed).
    """
    # The generated plan.
    plan: QueryPlan
|
30
|
+
|
31
|
+
|
32
|
+
class QueryPlanAnswerTool(ToolMessage):
    """Tool message that bundles a `QueryPlan` together with its answer."""

    # Explicit `str` annotations keep these field overrides declared the same
    # way as in the other ToolMessage subclasses of this package.
    request: str = "query_plan_answer"  # the agent method name that handles this tool
    purpose: str = """
    Assemble query <plan> and <answer>
    """
    # The plan that was used.
    plan: QueryPlan
    # The answer associated with the plan.
    answer: str
|
39
|
+
|
40
|
+
|
41
|
+
class QueryPlanFeedbackTool(ToolMessage):
    """Tool message used to give feedback on a query plan."""

    # Explicit `str` annotations keep these field overrides declared the same
    # way as in the other ToolMessage subclasses of this package.
    request: str = "query_plan_feedback"
    purpose: str = "To give <feedback> regarding the query plan."
    # Free-text feedback about the plan.
    feedback: str
|
File without changes
|
@@ -0,0 +1,174 @@
|
|
1
|
+
from typing import List, Optional
|
2
|
+
|
3
|
+
import pandas as pd
|
4
|
+
import typer
|
5
|
+
|
6
|
+
from langroid.agent.special.neo4j.neo4j_chat_agent import (
|
7
|
+
Neo4jChatAgent,
|
8
|
+
Neo4jChatAgentConfig,
|
9
|
+
)
|
10
|
+
from langroid.agent.tool_message import ToolMessage
|
11
|
+
from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig
|
12
|
+
from langroid.parsing.table_loader import read_tabular_data
|
13
|
+
from langroid.utils.output import status
|
14
|
+
from langroid.vector_store.base import VectorStoreConfig
|
15
|
+
|
16
|
+
app = typer.Typer()
|
17
|
+
|
18
|
+
|
19
|
+
BUILD_KG_INSTRUCTIONS = """
|
20
|
+
Your task is to build a knowledge graph based on a CSV file.
|
21
|
+
|
22
|
+
You need to generate the graph database based on this
|
23
|
+
header:
|
24
|
+
|
25
|
+
{header}
|
26
|
+
|
27
|
+
and these sample rows:
|
28
|
+
|
29
|
+
{sample_rows}.
|
30
|
+
|
31
|
+
Leverage the above information to:
|
32
|
+
- Define node labels and their properties
|
33
|
+
- Infer relationships
|
34
|
+
- Infer constraints
|
35
|
+
ASK me if you need further information to figure out the schema.
|
36
|
+
You can use the tool/function `pandas_to_kg` to display and confirm
|
37
|
+
the nodes and relationships.
|
38
|
+
"""
|
39
|
+
|
40
|
+
DEFAULT_CSV_KG_CHAT_SYSTEM_MESSAGE = """
|
41
|
+
You are an expert in Knowledge Graphs and analyzing them using Neo4j.
|
42
|
+
You will be asked to answer questions based on the knowledge graph.
|
43
|
+
"""
|
44
|
+
|
45
|
+
|
46
|
+
def _preprocess_dataframe_for_neo4j(
    df: pd.DataFrame, default_value: Optional[str] = None, remove_null_rows: bool = True
) -> pd.DataFrame:
    """
    Preprocess a DataFrame for Neo4j import by fixing mismatched quotes in string
    columns and handling null or missing values.

    The input DataFrame is never modified; a processed copy is returned.

    Args:
        df (DataFrame): The DataFrame to be preprocessed.
        default_value (str, optional): The default value to replace null values.
            This is ignored if remove_null_rows is True. Defaults to None.
        remove_null_rows (bool, optional): If True, rows with any null values will
            be removed. If False, null values are filled with default_value
            (and left untouched when default_value is None). Defaults to True.

    Returns:
        DataFrame: The preprocessed DataFrame ready for Neo4j import.
    """
    # Work on a copy: assigning columns below would otherwise rewrite the
    # caller's DataFrame as a side effect.
    df = df.copy()

    def _balance_quotes(value):  # type: ignore[no-untyped-def]
        # An odd number of double quotes means one is unmatched; close it.
        if isinstance(value, str) and value.count('"') % 2 != 0:
            return value + '"'
        return value

    # Fix mismatched quotes in string (object-dtype) columns
    for column in df.select_dtypes(include=["object"]):
        df[column] = df[column].apply(_balance_quotes)

    # Handle null or missing values
    if remove_null_rows:
        return df.dropna()
    if default_value is not None:
        return df.fillna(default_value)
    return df
|
79
|
+
|
80
|
+
|
81
|
+
class CSVGraphAgentConfig(Neo4jChatAgentConfig):
    """Configuration for `CSVGraphAgent`."""

    # Base system message; CSVGraphAgent appends build instructions to it
    # when tabular data is loaded from a file or URL.
    system_message: str = DEFAULT_CSV_KG_CHAT_SYSTEM_MESSAGE
    data: str | pd.DataFrame | None  # data file, URL, or DataFrame
    separator: None | str = None  # separator for data file
    vecdb: None | VectorStoreConfig = None
    # LLM settings; the chat model defaults to GPT4_TURBO.
    llm: OpenAIGPTConfig = OpenAIGPTConfig(
        chat_model=OpenAIChatModel.GPT4_TURBO,
    )
|
89
|
+
|
90
|
+
|
91
|
+
class PandasToKGTool(ToolMessage):
    """Tool message carrying a Cypher query and the CSV headers that fill it.

    Handled by the agent method `pandas_to_kg`.
    """

    request: str = "pandas_to_kg"
    purpose: str = """Use this tool to create ONLY nodes and their relationships based
    on the created model.
    Take into account that the Cypher query will be executed while iterating
    over the rows in the CSV file (e.g. `index, row in df.iterrows()`),
    so there NO NEED to load the CSV.
    Make sure you send me the cypher query in this format:
    - placeholders in <cypherQuery> should be based on the CSV header.
    - <args> an array wherein each element corresponds to a placeholder in the
    <cypherQuery> and provided in the same order as the headers.
    SO the <args> should be the result of: `[row_dict[header] for header in headers]`
    """
    # Cypher statement containing `$placeholder` parameters.
    cypherQuery: str
    # Header names used as parameter names for the query.
    args: list[str]

    @classmethod
    def examples(cls) -> List["ToolMessage"]:
        """Return a single illustrative instance of this tool."""
        sample_query = """MERGE (employee:Employee {name: $employeeName,
        id: $employeeId})\n
        MERGE (department:Department {name: $departmentName})\n
        MERGE (employee)-[:WORKS_IN]->(department)\n
        SET employee.email = $employeeEmail"""
        sample_args = ["employeeName", "employeeId", "departmentName", "employeeEmail"]
        return [cls(cypherQuery=sample_query, args=sample_args)]
|
119
|
+
|
120
|
+
|
121
|
+
class CSVGraphAgent(Neo4jChatAgent):
    """Neo4j chat agent that builds a knowledge graph from tabular data."""

    def __init__(self, config: CSVGraphAgentConfig):
        """Load the tabular data (if any) and set up the underlying Neo4j agent.

        Args:
            config (CSVGraphAgentConfig): Agent configuration. `config.data` may
                be a DataFrame, a file path/URL, or empty.
        """
        formatted_build_instr = ""
        # Always define `df`, so `pandas_to_kg` can test it even when no data
        # was configured (previously the attribute was missing in that case).
        self.df: Optional[pd.DataFrame] = None

        if isinstance(config.data, pd.DataFrame):
            # NOTE(review): a DataFrame passed directly is used as-is: it is not
            # preprocessed and no build instructions are added -- confirm this
            # asymmetry with the file/URL path is intended.
            self.df = config.data
        elif config.data:
            df = read_tabular_data(config.data, config.separator)
            df_cleaned = _preprocess_dataframe_for_neo4j(df)

            # Strip column names and collapse runs of spaces into underscores.
            df_cleaned.columns = df_cleaned.columns.str.strip().str.replace(
                " +", "_", regex=True
            )

            self.df = df_cleaned

            formatted_build_instr = BUILD_KG_INSTRUCTIONS.format(
                header=self.df.columns, sample_rows=self.df.head(3)
            )

        config.system_message = config.system_message + formatted_build_instr
        super().__init__(config)

        self.config: Neo4jChatAgentConfig = config

        self.enable_message(PandasToKGTool)

    def pandas_to_kg(self, msg: PandasToKGTool) -> str:
        """
        Creates nodes and relationships in the graph database based on the data in
        a CSV file, by running the given Cypher query once per row.

        Args:
            msg (PandasToKGTool): An instance of the PandasToKGTool class containing
                the necessary information for generating nodes.

        Returns:
            str: A string indicating the success or failure of the operation.
        """
        with status("[cyan]Generating graph database..."):
            if self.df is None or not hasattr(self.df, "iterrows"):
                # Nothing to iterate over: report it instead of claiming success.
                return "No tabular data is loaded, so no graph database was generated"

            # Report unknown headers as a message rather than failing with a
            # bare KeyError while building the query parameters.
            missing = [header for header in msg.args if header not in self.df.columns]
            if missing:
                return (
                    f"Cypher query was not executed: {missing} are not column "
                    f"names of the data. Available columns: "
                    f"{list(self.df.columns)}"
                )

            for counter, (_, row) in enumerate(self.df.iterrows()):
                row_dict = row.to_dict()
                response = self.write_query(
                    msg.cypherQuery,
                    parameters={header: row_dict[header] for header in msg.args},
                )
                # there is a possibility the generated cypher query is not correct
                # so we need to check the response before continuing to the
                # iteration (only the first row's response is checked)
                if counter == 0 and not response.success:
                    return str(response.data)
        return "Graph database successfully generated"
|
@@ -0,0 +1,370 @@
|
|
1
|
+
import json
|
2
|
+
import logging
|
3
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
4
|
+
|
5
|
+
from pydantic import BaseModel, BaseSettings
|
6
|
+
from rich import print
|
7
|
+
from rich.console import Console
|
8
|
+
|
9
|
+
from langroid.agent import ToolMessage
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
import neo4j
|
13
|
+
|
14
|
+
|
15
|
+
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
|
16
|
+
from langroid.agent.chat_document import ChatDocMetaData, ChatDocument
|
17
|
+
from langroid.agent.special.neo4j.utils.system_message import (
|
18
|
+
DEFAULT_NEO4J_CHAT_SYSTEM_MESSAGE,
|
19
|
+
DEFAULT_SYS_MSG,
|
20
|
+
SCHEMA_TOOLS_SYS_MSG,
|
21
|
+
)
|
22
|
+
from langroid.mytypes import Entity
|
23
|
+
|
24
|
+
logger = logging.getLogger(__name__)
|
25
|
+
|
26
|
+
console = Console()
|
27
|
+
|
28
|
+
NEO4J_ERROR_MSG = "There was an error in your Cypher Query"
|
29
|
+
|
30
|
+
|
31
|
+
# TOOLS to be used by the agent
|
32
|
+
|
33
|
+
|
34
|
+
class CypherRetrievalTool(ToolMessage):
    """Tool message carrying a read-only Cypher query.

    Handled by the agent method `retrieval_query`.
    """

    request: str = "retrieval_query"
    # Typos in this LLM-facing text were fixed ("retreive", "based provided").
    purpose: str = """Use this tool to send the Cypher query to retrieve data from the
    graph database based on the provided text description and schema."""
    # The Cypher query to execute.
    cypher_query: str
|
39
|
+
|
40
|
+
|
41
|
+
class CypherCreationTool(ToolMessage):
    """Tool message carrying a Cypher query that writes to the graph.

    Handled by the agent method `create_query`.
    """

    request: str = "create_query"
    purpose: str = """Use this tool to send the Cypher query to create
    entities/relationships in the graph database."""
    # The Cypher query to execute.
    cypher_query: str
|
46
|
+
|
47
|
+
|
48
|
+
class GraphSchemaTool(ToolMessage):
    """Argument-less tool message requesting the graph database schema.

    Handled by the agent method `get_schema`.
    """

    request: str = "get_schema"
    purpose: str = """To get the schema of the graph database."""
|
51
|
+
|
52
|
+
|
53
|
+
class Neo4jSettings(BaseSettings):
    """Connection settings for a Neo4j database; every field defaults to ""."""

    uri: str = ""
    username: str = ""
    password: str = ""
    database: str = ""

    class Config:
        # This enables the use of environment variables to set the settings,
        # e.g. NEO4J_URI, NEO4J_USERNAME, etc.,
        # which can either be set in a .env file or in the shell via export cmds.
        env_prefix = "NEO4J_"
|
64
|
+
|
65
|
+
|
66
|
+
class QueryResult(BaseModel):
    """Outcome of running a Cypher query."""

    # True when the query ran without raising.
    success: bool
    # Either a message string (used for failures) or the returned records as
    # a list of dicts; None when there is nothing to report.
    data: Optional[Union[str, List[Dict[Any, Any]]]] = None
|
69
|
+
|
70
|
+
|
71
|
+
class Neo4jChatAgentConfig(ChatAgentConfig):
    """Configuration for `Neo4jChatAgent`."""

    # Connection settings used to create the database driver.
    neo4j_settings: Neo4jSettings = Neo4jSettings()
    # System message template; the agent fills it in via `.format(mode=...)`.
    system_message: str = DEFAULT_NEO4J_CHAT_SYSTEM_MESSAGE
    # NOTE(review): not read by the agent code in this module -- presumably a
    # pre-computed schema; confirm usage elsewhere.
    kg_schema: Optional[List[Dict[str, Any]]]
    # NOTE(review): not read by the agent code in this module.
    database_created: bool = False
    # When True the system message describes the schema tools; when False the
    # schema itself is fetched and embedded in the system message.
    use_schema_tools: bool = True
    use_functions_api: bool = True
    use_tools: bool = False
|
79
|
+
|
80
|
+
|
81
|
+
class Neo4jChatAgent(ChatAgent):
    """Chat agent that reads from and writes to a Neo4j database via Cypher."""

    def __init__(self, config: Neo4jChatAgentConfig):
        """Initialize the Neo4jChatAgent.

        Raises:
            ValueError: If database information is not provided in the config.
            ImportError: If the `neo4j` package is not installed.
            ConnectionError: If the database driver cannot be created.
        """
        self.config = config
        self._validate_config()
        self._import_neo4j()
        self._initialize_connection()
        self._init_tool_messages()

    def _validate_config(self) -> None:
        """Validate the configuration to ensure all necessary fields are present.

        Raises:
            ValueError: If no Neo4j URI is configured.
        """
        assert isinstance(self.config, Neo4jChatAgentConfig)
        # The settings are `str` fields defaulting to "", so a comparison with
        # None can never detect a missing value; test for emptiness instead.
        # Only the URI is required here: without it no driver can be created.
        if not self.config.neo4j_settings.uri:
            raise ValueError("Neo4j env information must be provided")

    def _import_neo4j(self) -> None:
        """Dynamically imports the Neo4j module and sets it as a global variable."""
        global neo4j
        try:
            import neo4j
        except ImportError as e:
            raise ImportError(
                """
                neo4j not installed. Please install it via:
                pip install neo4j.
                Or when installing langroid, install it with the `neo4j` extra:
                pip install langroid[neo4j]
                """
            ) from e

    def _initialize_connection(self) -> None:
        """
        Initializes a connection to the Neo4j database using the configuration settings.

        Raises:
            ConnectionError: If the driver cannot be created.
        """
        assert isinstance(self.config, Neo4jChatAgentConfig)
        settings = self.config.neo4j_settings
        try:
            self.driver = neo4j.GraphDatabase.driver(
                settings.uri,
                auth=(settings.username, settings.password),
            )
        except Exception as e:
            # Chain the cause so the original driver error stays visible.
            raise ConnectionError(f"Failed to initialize Neo4j connection: {e}") from e

    def close(self) -> None:
        """close the connection"""
        if self.driver:
            self.driver.close()

    def retry_query(self, e: Exception, query: str) -> str:
        """
        Generate an error message for a failed Cypher query and return it.

        Args:
            e (Exception): The exception raised during the Cypher query execution.
            query (str): The Cypher query that failed.

        Returns:
            str: The error message.
        """
        logger.error(f"Cypher Query failed: {query}\nException: {e}")

        # The message starts with NEO4J_ERROR_MSG, which `agent_response`
        # looks for in order to detect a failed query.
        error_message_template = f"""\
        {NEO4J_ERROR_MSG}: '{query}'
        {str(e)}
        Run a new query, correcting the errors.
        """

        return error_message_template

    def read_query(
        self, query: str, parameters: Optional[Dict[Any, Any]] = None
    ) -> QueryResult:
        """
        Executes a given Cypher query with parameters on the Neo4j database.

        Args:
            query (str): The Cypher query string to be executed.
            parameters (Optional[Dict[Any, Any]]): A dictionary of parameters for
                the query.

        Returns:
            QueryResult: An object representing the outcome of the query execution.
                On success `data` is the list of records (possibly empty);
                on failure it is an error message.
        """
        if not self.driver:
            return QueryResult(
                success=False, data="No database connection is established."
            )

        try:
            assert isinstance(self.config, Neo4jChatAgentConfig)
            with self.driver.session(
                database=self.config.neo4j_settings.database
            ) as session:
                result = session.run(query, parameters)
                # Iterating a result that has no records yields an empty list.
                records = [record.data() for record in result]
                return QueryResult(success=True, data=records)
        except Exception as e:
            logger.error(f"Failed to execute query: {query}\n{e}")
            error_message = self.retry_query(e, query)
            return QueryResult(success=False, data=error_message)
        finally:
            # NOTE(review): this closes the driver itself (not just the session)
            # after every query, yet the same driver is reused by later
            # queries -- confirm this is intended for the driver version in use.
            self.close()

    def write_query(
        self, query: str, parameters: Optional[Dict[Any, Any]] = None
    ) -> QueryResult:
        """
        Executes a write transaction using a given Cypher query on the Neo4j database.
        This method should be used for queries that modify the database.

        Args:
            query (str): The Cypher query string to be executed.
            parameters (dict, optional): A dict of parameters for the Cypher query.

        Returns:
            QueryResult: An object representing the outcome of the query execution.
                It contains a success flag and an optional error message.
        """
        if not self.driver:
            return QueryResult(
                success=False, data="No database connection is established."
            )

        try:
            assert isinstance(self.config, Neo4jChatAgentConfig)
            with self.driver.session(
                database=self.config.neo4j_settings.database
            ) as session:
                session.write_transaction(lambda tx: tx.run(query, parameters))
                return QueryResult(success=True)
        except Exception as e:
            # Use the module logger (not the root logger), like `read_query`.
            logger.warning(f"An error occurred: {e}")
            error_message = self.retry_query(e, query)
            return QueryResult(success=False, data=error_message)
        finally:
            # NOTE(review): closes the driver after every query; see the same
            # note in `read_query`.
            self.close()

    # TODO: test under enterprise edition because community edition doesn't allow
    # database creation/deletion
    def remove_database(self) -> None:
        """Deletes all nodes and relationships from the current Neo4j database."""
        delete_query = """
                MATCH (n)
                DETACH DELETE n
            """
        response = self.write_query(delete_query)

        if response.success:
            print("[green]Database is deleted!")
        else:
            print("[red]Database is not deleted!")

    def retrieval_query(self, msg: CypherRetrievalTool) -> str:
        """
        Handle a CypherRetrievalTool message by executing a Cypher query and
        returning the result.

        Args:
            msg (CypherRetrievalTool): The tool-message to handle.

        Returns:
            str: The result of executing the cypher_query: the records as a JSON
                string on success, otherwise the error message.
        """
        query = msg.cypher_query

        logger.info(f"Executing Cypher query: {query}")
        response = self.read_query(query)
        if response.success:
            # `default=str` renders values json cannot encode natively as
            # strings, instead of failing the whole response with a TypeError.
            return json.dumps(response.data, default=str)
        else:
            return str(response.data)

    def create_query(self, msg: CypherCreationTool) -> str:
        """
        Handle a CypherCreationTool message by executing a Cypher query and
        returning the result.

        Args:
            msg (CypherCreationTool): The tool-message to handle.

        Returns:
            str: A success message, or the error message if the query failed.
        """
        query = msg.cypher_query

        logger.info(f"Executing Cypher query: {query}")
        response = self.write_query(query)
        if response.success:
            return "Cypher query executed successfully"
        else:
            return str(response.data)

    # TODO: There are various ways to get the schema. The current one uses the func
    # `read_query`, which requires post processing to identify whether the response upon
    # the schema query is valid. Another way is to isolate this func from `read_query`.
    # The current query works well. But we could use the queries here:
    # https://github.com/neo4j/NaLLM/blob/1af09cd117ba0777d81075c597a5081583568f9f/api/
    # src/driver/neo4j.py#L30
    def get_schema(self, msg: GraphSchemaTool | None) -> str:
        """
        Retrieves the schema of a Neo4j graph database.

        Args:
            msg (GraphSchemaTool): The tool-message that triggered the request;
                its contents are not used, and None is accepted.

        Returns:
            str: The database schema as a JSON string, or a message stating that
                the schema could not be retrieved.
        """
        schema_result = self.read_query("CALL db.schema.visualization()")
        if schema_result.success:
            # there is a possibility that the schema is empty, which is a valid
            # response; the schema.data will be: [{"nodes": [], "relationships": []}]
            return json.dumps(schema_result.data)
        else:
            return f"Failed to retrieve schema: {schema_result.data}"

    def _init_tool_messages(self) -> None:
        """Initialize message tools used for chatting."""
        message = self._format_message()
        # Fill the `{mode}` placeholder of the system message before the base
        # class is initialized with the config.
        self.config.system_message = self.config.system_message.format(mode=message)
        super().__init__(self.config)
        self.enable_message(CypherRetrievalTool)
        self.enable_message(CypherCreationTool)
        self.enable_message(GraphSchemaTool)

    def _format_message(self) -> str:
        """Return the text that is substituted for `{mode}` in the system message.

        Raises:
            ValueError: If the database driver is None.
        """
        if self.driver is None:
            raise ValueError("Database driver None")
        assert isinstance(self.config, Neo4jChatAgentConfig)
        if self.config.use_schema_tools:
            return SCHEMA_TOOLS_SYS_MSG
        # Without schema tools, the schema itself is embedded in the message.
        return DEFAULT_SYS_MSG.format(schema=self.get_schema(None))

    def agent_response(
        self,
        msg: Optional[str | ChatDocument] = None,
    ) -> Optional[ChatDocument]:
        """Handle `msg` with the agent's tools and wrap the result.

        Args:
            msg (str | ChatDocument, optional): The message to handle.

        Returns:
            Optional[ChatDocument]: The handler result as a ChatDocument, or
                None when there is no message or no handler result.
        """
        if msg is None:
            return None

        results = self.handle_message(msg)
        if results is None:
            return None

        # Extract the text first, so the error check and the console output
        # work on the message content whether or not the handler returned a
        # ChatDocument.
        if isinstance(results, ChatDocument):
            content = results.content
            recipient = results.metadata.recipient
        else:
            content = results
            recipient = ""

        # Only the console output is shortened on errors; the returned
        # document keeps the full error text.
        output = content
        if NEO4J_ERROR_MSG in output:
            output = "There was an error in the Cypher Query. Press enter to retry."

        console.print(f"[red]{self.indent}", end="")
        print(f"[red]Agent: {output}")
        sender_name = self.config.name
        if isinstance(msg, ChatDocument) and msg.function_call is not None:
            sender_name = msg.function_call.name

        return ChatDocument(
            content=content,
            metadata=ChatDocMetaData(
                # source=Entity.AGENT,
                sender=Entity.LLM,
                sender_name=sender_name,
                recipient=recipient,
            ),
        )
|
File without changes
|
@@ -0,0 +1,46 @@
|
|
1
|
+
# System message used when the schema is embedded directly in the prompt;
# `{schema}` is filled in by the agent.
DEFAULT_SYS_MSG = """You are a data scientist and expert in Knowledge Graphs,
with expertise in answering questions by interacting with a Neo4j graph database.

The schema maps the Neo4j database structure: node labels, relationship types,
and property keys available in your Neo4j database.
{schema}
Do not make assumptions about the database schema before using the tools.
Use the tool/function to learn more about the database schema."""

# System message used when the LLM is expected to discover the schema through
# the tools (typos "retreive infomration" fixed).
SCHEMA_TOOLS_SYS_MSG = """You are a data scientist and expert in Knowledge Graphs,
with expertise in answering questions by querying Neo4j database.
You have access to the following tools:
- `retrieval_query` tool/function-call to retrieve information from the graph database
to answer questions.

- `create_query` tool/function-call to execute cypher query that creates
entities/relationships in the graph database.

- `get_schema` tool/function-call to get all the node labels, relationship
types, and property keys available in your Neo4j database.

You must be smart about using the right node labels, relationship types, and property
keys based on the english description. If you are thinking of using a node label,
relationship type, or property key that does not exist, you are probably on the wrong
track, so you should try your best to answer based on an existing table or column.
DO NOT assume any nodes or relationships other than those above."""

# Outer template of the agent's system message; `{mode}` is replaced by one of
# the two messages above.
DEFAULT_NEO4J_CHAT_SYSTEM_MESSAGE = """
{mode}

You do not need to attempt answering a question with just one query.
You could make a sequence of Neo4j queries to help you write the final query.
Also if you receive a null or other unexpected result,
(a) make sure you use the available TOOLs correctly, and
(b) see if you have made an assumption in your Neo4j query, and try another way,
or use `retrieval_query` to explore the database contents before submitting your
final query.
(c) USE `create_query` tool/function-call to execute cypher query that creates
entities/relationships in the graph database.

(d) USE `get_schema` tool/function-call to get all the node labels, relationship
types, and property keys available in your Neo4j database.

Start by asking what I would like to know about the data.

"""
|