langroid 0.58.2__py3-none-any.whl → 0.59.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +39 -17
- langroid/agent/base.py-e +2216 -0
- langroid/agent/callbacks/chainlit.py +2 -1
- langroid/agent/chat_agent.py +73 -55
- langroid/agent/chat_agent.py-e +2086 -0
- langroid/agent/chat_document.py +7 -7
- langroid/agent/chat_document.py-e +513 -0
- langroid/agent/openai_assistant.py +9 -9
- langroid/agent/openai_assistant.py-e +882 -0
- langroid/agent/special/arangodb/arangodb_agent.py +10 -18
- langroid/agent/special/arangodb/arangodb_agent.py-e +648 -0
- langroid/agent/special/arangodb/tools.py +3 -3
- langroid/agent/special/doc_chat_agent.py +16 -14
- langroid/agent/special/lance_rag/critic_agent.py +2 -2
- langroid/agent/special/lance_rag/query_planner_agent.py +4 -4
- langroid/agent/special/lance_tools.py +6 -5
- langroid/agent/special/lance_tools.py-e +61 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +3 -7
- langroid/agent/special/neo4j/neo4j_chat_agent.py-e +430 -0
- langroid/agent/special/relevance_extractor_agent.py +1 -1
- langroid/agent/special/sql/sql_chat_agent.py +11 -3
- langroid/agent/task.py +9 -87
- langroid/agent/task.py-e +2418 -0
- langroid/agent/tool_message.py +33 -17
- langroid/agent/tool_message.py-e +400 -0
- langroid/agent/tools/file_tools.py +4 -2
- langroid/agent/tools/file_tools.py-e +234 -0
- langroid/agent/tools/mcp/fastmcp_client.py +19 -6
- langroid/agent/tools/mcp/fastmcp_client.py-e +584 -0
- langroid/agent/tools/orchestration.py +22 -17
- langroid/agent/tools/orchestration.py-e +301 -0
- langroid/agent/tools/recipient_tool.py +3 -3
- langroid/agent/tools/task_tool.py +22 -16
- langroid/agent/tools/task_tool.py-e +249 -0
- langroid/agent/xml_tool_message.py +90 -35
- langroid/agent/xml_tool_message.py-e +392 -0
- langroid/cachedb/base.py +1 -1
- langroid/embedding_models/base.py +2 -2
- langroid/embedding_models/models.py +3 -7
- langroid/embedding_models/models.py-e +563 -0
- langroid/exceptions.py +4 -1
- langroid/language_models/azure_openai.py +2 -2
- langroid/language_models/azure_openai.py-e +134 -0
- langroid/language_models/base.py +6 -4
- langroid/language_models/base.py-e +812 -0
- langroid/language_models/client_cache.py +64 -0
- langroid/language_models/config.py +2 -4
- langroid/language_models/config.py-e +18 -0
- langroid/language_models/model_info.py +9 -1
- langroid/language_models/model_info.py-e +483 -0
- langroid/language_models/openai_gpt.py +119 -20
- langroid/language_models/openai_gpt.py-e +2280 -0
- langroid/language_models/provider_params.py +3 -22
- langroid/language_models/provider_params.py-e +153 -0
- langroid/mytypes.py +11 -4
- langroid/mytypes.py-e +132 -0
- langroid/parsing/code_parser.py +1 -1
- langroid/parsing/file_attachment.py +1 -1
- langroid/parsing/file_attachment.py-e +246 -0
- langroid/parsing/md_parser.py +14 -4
- langroid/parsing/md_parser.py-e +574 -0
- langroid/parsing/parser.py +22 -7
- langroid/parsing/parser.py-e +410 -0
- langroid/parsing/repo_loader.py +3 -1
- langroid/parsing/repo_loader.py-e +812 -0
- langroid/parsing/search.py +1 -1
- langroid/parsing/url_loader.py +17 -51
- langroid/parsing/url_loader.py-e +683 -0
- langroid/parsing/urls.py +5 -4
- langroid/parsing/urls.py-e +279 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/pydantic_v1/__init__.py +45 -6
- langroid/pydantic_v1/__init__.py-e +36 -0
- langroid/pydantic_v1/main.py +11 -4
- langroid/pydantic_v1/main.py-e +11 -0
- langroid/utils/configuration.py +13 -11
- langroid/utils/configuration.py-e +141 -0
- langroid/utils/constants.py +1 -1
- langroid/utils/constants.py-e +32 -0
- langroid/utils/globals.py +21 -5
- langroid/utils/globals.py-e +49 -0
- langroid/utils/html_logger.py +2 -1
- langroid/utils/html_logger.py-e +825 -0
- langroid/utils/object_registry.py +1 -1
- langroid/utils/object_registry.py-e +66 -0
- langroid/utils/pydantic_utils.py +55 -28
- langroid/utils/pydantic_utils.py-e +602 -0
- langroid/utils/types.py +2 -2
- langroid/utils/types.py-e +113 -0
- langroid/vector_store/base.py +3 -3
- langroid/vector_store/lancedb.py +5 -5
- langroid/vector_store/lancedb.py-e +404 -0
- langroid/vector_store/meilisearch.py +2 -2
- langroid/vector_store/pineconedb.py +4 -4
- langroid/vector_store/pineconedb.py-e +427 -0
- langroid/vector_store/postgres.py +1 -1
- langroid/vector_store/qdrantdb.py +3 -3
- langroid/vector_store/weaviatedb.py +1 -1
- {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/METADATA +3 -2
- langroid-0.59.0b1.dist-info/RECORD +181 -0
- langroid/agent/special/doc_chat_task.py +0 -0
- langroid/mcp/__init__.py +0 -1
- langroid/mcp/server/__init__.py +0 -1
- langroid-0.58.2.dist-info/RECORD +0 -145
- {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/WHEEL +0 -0
- {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,430 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
3
|
+
|
4
|
+
from pydantic_settings import BaseSettings
|
5
|
+
from rich import print
|
6
|
+
from rich.console import Console
|
7
|
+
|
8
|
+
from pydantic import BaseModel, ConfigDict
|
9
|
+
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
import neo4j
|
12
|
+
|
13
|
+
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
|
14
|
+
from langroid.agent.chat_document import ChatDocument
|
15
|
+
from langroid.agent.special.neo4j.system_messages import (
|
16
|
+
ADDRESSING_INSTRUCTION,
|
17
|
+
DEFAULT_NEO4J_CHAT_SYSTEM_MESSAGE,
|
18
|
+
DONE_INSTRUCTION,
|
19
|
+
SCHEMA_PROVIDED_SYS_MSG,
|
20
|
+
SCHEMA_TOOLS_SYS_MSG,
|
21
|
+
)
|
22
|
+
from langroid.agent.special.neo4j.tools import (
|
23
|
+
CypherCreationTool,
|
24
|
+
CypherRetrievalTool,
|
25
|
+
GraphSchemaTool,
|
26
|
+
cypher_creation_tool_name,
|
27
|
+
cypher_retrieval_tool_name,
|
28
|
+
graph_schema_tool_name,
|
29
|
+
)
|
30
|
+
from langroid.agent.tools.orchestration import DoneTool, ForwardTool
|
31
|
+
from langroid.exceptions import LangroidImportError
|
32
|
+
from langroid.mytypes import Entity
|
33
|
+
from langroid.utils.constants import SEND_TO
|
34
|
+
|
35
|
+
# Prefix for every error message handed back to the LLM when a Cypher query
# fails (it is interpolated by `Neo4jChatAgent.retry_query`).
NEO4J_ERROR_MSG = "There was an error in your Cypher Query"

# Module-scoped logger and a rich console instance for this agent module.
logger: logging.Logger = logging.getLogger(__name__)
console = Console()
|
40
|
+
|
41
|
+
|
42
|
+
# Settings, result and config models used by the agent (tools are imported above)
|
43
|
+
|
44
|
+
|
45
|
+
class Neo4jSettings(BaseSettings):
    """Connection settings for a Neo4j server.

    Because this is a pydantic-settings model with ``env_prefix="NEO4J_"``,
    each field may be populated from an environment variable such as
    ``NEO4J_URI``, ``NEO4J_USERNAME``, ``NEO4J_PASSWORD`` or
    ``NEO4J_DATABASE``. Any field left unset is the empty string.
    """

    # Connection URI handed to `neo4j.GraphDatabase.driver`.
    uri: str = ""
    # Credentials passed as the driver's `auth=(username, password)` tuple.
    username: str = ""
    password: str = ""
    # Database name used when opening query sessions.
    database: str = ""

    model_config = ConfigDict(env_prefix="NEO4J_")
|
52
|
+
|
53
|
+
|
54
|
+
class QueryResult(BaseModel):
    """Outcome of executing a Cypher query.

    Attributes:
        success: False when the query raised or no connection was available.
        data: For a successful read, the list of record dicts (possibly
            empty); for a successful write, None; on failure, an error
            message string.
    """

    success: bool
    data: Optional[Union[List[Dict[Any, Any]], str]] = None
|
57
|
+
|
58
|
+
|
59
|
+
class Neo4jChatAgentConfig(ChatAgentConfig):
    """Configuration for `Neo4jChatAgent`."""

    # How to reach the Neo4j server (URI, credentials, database name).
    neo4j_settings: Neo4jSettings = Neo4jSettings()
    # System-message template; the agent fills its `{mode}` placeholder.
    system_message: str = DEFAULT_NEO4J_CHAT_SYSTEM_MESSAGE
    # Cached graph schema; the agent stores the fetched schema here and
    # reuses it while it is non-empty.
    kg_schema: List[Dict[str, Any]] | None = None
    # Whether the database is known to contain data (set by the agent after
    # inspecting the db or running a creation query).
    database_created: bool = False
    # When True the agent MUST use the schema tools to learn the schema,
    # i.e. the schema is NOT placed in the initial system message.
    use_schema_tools: bool = True
    # Tool-invocation mode flags; presumably these override the
    # ChatAgentConfig defaults — confirm against the base class.
    use_functions_api: bool = True
    use_tools: bool = False
    # True when the agent holds an ongoing chat with a user, rather than
    # returning a single result from task.run().
    chat_mode: bool = False
    # Prefix used to address the user in chat mode; when left empty in
    # chat mode the agent falls back to SEND_TO.
    addressing_prefix: str = ""
|
73
|
+
|
74
|
+
|
75
|
+
class Neo4jChatAgent(ChatAgent):
    """Chat agent that serves requests by running Cypher queries on Neo4j.

    It enables tools to fetch the graph schema (`GraphSchemaTool`), run
    retrieval queries (`CypherRetrievalTool`), run creation/write queries
    (`CypherCreationTool`) and finish a task (`DoneTool`).

    The Neo4j driver is created once in `__init__` and stays open for the
    agent's lifetime; call `close()` to release it.
    """

    def __init__(self, config: Neo4jChatAgentConfig):
        """Initialize the Neo4jChatAgent.

        Validates the config, imports `neo4j`, connects to the database,
        builds the system message, enables the tools and resets state.

        Args:
            config: Agent configuration including the Neo4j connection settings.

        Raises:
            ValueError: If a database name is configured without any
                username or password.
            LangroidImportError: If the `neo4j` package is not installed.
            ConnectionError: If connecting to or querying the database fails
                during initialization.
        """
        self.config: Neo4jChatAgentConfig = config
        self._validate_config()
        self._import_neo4j()
        self._initialize_db()
        self._init_tools_sys_message()
        self.init_state()

    def init_state(self) -> None:
        """Reset per-conversation state: last retrieval query and schema flag."""
        super().init_state()
        self.current_retrieval_cypher_query: str = ""
        self.tried_schema: bool = False

    def handle_message_fallback(
        self, msg: str | ChatDocument
    ) -> str | ForwardTool | None:
        """
        When LLM sends a no-tool msg, assume user is the intended recipient,
        and if in interactive mode, forward the msg to the user.

        Otherwise return a reminder telling the LLM to either signal its
        final answer or use a tool to make progress.

        Args:
            msg: The message that contained no tool call.

        Returns:
            A `ForwardTool` addressed to the user (interactive mode), a
            reminder string (non-interactive), or None when `msg` did not
            come from the LLM.
        """
        done_tool_name = DoneTool.default_value("request")
        forward_tool_name = ForwardTool.default_value("request")
        if not (
            isinstance(msg, ChatDocument) and msg.metadata.sender == Entity.LLM
        ):
            return None
        if self.interactive:
            return ForwardTool(agent="User")
        if self.config.chat_mode:
            # NOTE(review): ForwardTool is not among the tools enabled in
            # `_init_tools_sys_message`; confirm the LLM can actually use it.
            return f"""
            Since you did not explicitly address the User, it is not clear
            whether:
            - you intend this to be the final response to the
            user's query/request, in which case you must use the
            `{forward_tool_name}` to indicate this.
            - OR, you FORGOT to use an Appropriate TOOL,
            in which case you should use the available tools to
            make progress on the user's query/request.
            """
        return f"""
        The intent of your response is not clear:
        - if you intended this to be the final answer to the user's query,
        then use the `{done_tool_name}` to indicate so,
        with the `content` set to the answer or result.
        - otherwise, use one of the available tools to make progress
        to arrive at the final answer.
        """

    def _validate_config(self) -> None:
        """Validate the Neo4j settings in the config.

        Raises:
            ValueError: If a database name is configured while both the
                username and the password are empty.
        """
        assert isinstance(self.config, Neo4jChatAgentConfig)
        settings = self.config.neo4j_settings
        # The settings fields default to "" (never None), so test emptiness;
        # an `is None` test here could never trigger.
        # NOTE(review): the "database is set" clause is kept from the original
        # condition; confirm whether the intent was `not settings.database`.
        if not settings.username and not settings.password and settings.database:
            raise ValueError("Neo4j env information must be provided")

    def _import_neo4j(self) -> None:
        """Dynamically import the `neo4j` module and bind it as a module global.

        Raises:
            LangroidImportError: If `neo4j` is not installed.
        """
        global neo4j
        try:
            import neo4j
        except ImportError:
            raise LangroidImportError("neo4j", "neo4j")

    def _neo4j_session(self) -> Any:
        """Open a driver session on the configured database.

        An empty `database` setting is passed as None, which lets the driver
        pick its default database.
        """
        assert isinstance(self.config, Neo4jChatAgentConfig)
        return self.driver.session(
            database=self.config.neo4j_settings.database or None
        )

    def _initialize_db(self) -> None:
        """Connect to Neo4j and detect whether the database already has data.

        Creates `self.driver`, sets `config.database_created` to whether any
        node exists, and if so fetches the schema into `config.kg_schema`.

        Raises:
            ConnectionError: If connecting or querying fails. The driver is
                closed before raising so that it does not leak.
        """
        try:
            assert isinstance(self.config, Neo4jChatAgentConfig)
            settings = self.config.neo4j_settings
            self.driver = neo4j.GraphDatabase.driver(
                settings.uri,
                auth=(settings.username, settings.password),
            )
            # Count nodes in the same database that later queries will use.
            with self._neo4j_session() as session:
                record = session.run("MATCH (n) RETURN count(n) as count").single()
                count = record["count"] if record is not None else 0
            self.config.database_created = count > 0

            # If database has data, get schema
            if self.config.database_created:
                # this updates self.config.kg_schema
                self.graph_schema_tool(None)

        except Exception as e:
            self.close()
            raise ConnectionError(
                f"Failed to initialize Neo4j connection: {e}"
            ) from e

    def close(self) -> None:
        """Close the Neo4j driver, if one was created."""
        driver = getattr(self, "driver", None)
        if driver is not None:
            driver.close()

    def retry_query(self, e: Exception, query: str) -> str:
        """
        Generate an error message for a failed Cypher query and return it.

        Args:
            e (Exception): The exception raised during the Cypher query execution.
            query (str): The Cypher query that failed.

        Returns:
            str: The error message, asking the LLM to correct and rerun it.
        """
        logger.error(f"Cypher Query failed: {query}\nException: {e}")

        # Construct the error message
        error_message_template = f"""\
        {NEO4J_ERROR_MSG}: '{query}'
        {str(e)}
        Run a new query, correcting the errors.
        """

        return error_message_template

    def read_query(
        self, query: str, parameters: Optional[Dict[Any, Any]] = None
    ) -> QueryResult:
        """
        Execute a Cypher query with parameters on the Neo4j database.

        The driver is left open so that later queries can reuse it.

        Args:
            query (str): The Cypher query string to be executed.
            parameters (Optional[Dict[Any, Any]]): A dictionary of parameters for
                the query.

        Returns:
            QueryResult: On success, `data` is the (possibly empty) list of
            record dicts; on failure, `data` is an error message.
        """
        if not getattr(self, "driver", None):
            return QueryResult(
                success=False, data="No database connection is established."
            )

        try:
            with self._neo4j_session() as session:
                result = session.run(query, parameters)
                records = [record.data() for record in result]
            return QueryResult(success=True, data=records)
        except Exception as e:
            logger.error(f"Failed to execute query: {query}\n{e}")
            error_message = self.retry_query(e, query)
            return QueryResult(success=False, data=error_message)

    def write_query(
        self, query: str, parameters: Optional[Dict[Any, Any]] = None
    ) -> QueryResult:
        """
        Execute a write transaction using a given Cypher query on the Neo4j database.
        This method should be used for queries that modify the database.

        Args:
            query (str): The Cypher query string to be executed.
            parameters (dict, optional): A dict of parameters for the Cypher query.

        Returns:
            QueryResult: An object representing the outcome of the query execution.
            It contains a success flag and, on failure, an error message.
        """
        # Heuristic: a query containing CREATE (which covers CREATE CONSTRAINT
        # and CREATE INDEX) or MERGE is treated as populating the database.
        query_upper = query.upper()
        is_creation_query = any(kw in query_upper for kw in ("CREATE", "MERGE"))

        if not getattr(self, "driver", None):
            return QueryResult(
                success=False, data="No database connection is established."
            )

        try:
            with self._neo4j_session() as session:
                # `execute_write` replaces the deprecated `write_transaction`
                # in neo4j 5.x; fall back for older drivers.
                run_write = (
                    getattr(session, "execute_write", None)
                    or session.write_transaction
                )
                # Consume inside the transaction function so errors surface here.
                run_write(lambda tx: tx.run(query, parameters).consume())
        except Exception as e:
            logger.warning(f"An error occurred: {e}")
            error_message = self.retry_query(e, query)
            return QueryResult(success=False, data=error_message)

        # Only mark the database as created once the write actually succeeded.
        if is_creation_query:
            self.config.database_created = True
            logger.info("Detected database/collection creation query")
        return QueryResult(success=True)

    # TODO: test under enterprise edition because community edition doesn't allow
    # database creation/deletion
    def remove_database(self) -> None:
        """Delete all nodes and relationships from the current Neo4j database."""
        delete_query = """
        MATCH (n)
        DETACH DELETE n
        """
        response = self.write_query(delete_query)

        if response.success:
            print("[green]Database is deleted!")
        else:
            print("[red]Database is not deleted!")

    def cypher_retrieval_tool(self, msg: CypherRetrievalTool) -> str:
        """
        Handle a CypherRetrievalTool message by executing a Cypher query and
        returning the result.

        The query is refused (with guidance) until the schema tool has been
        used and the database is known to contain data.

        Args:
            msg (CypherRetrievalTool): The tool-message to handle.

        Returns:
            str: The result of executing the cypher_query, or guidance text.
        """
        if not self.tried_schema:
            return f"""
            You did not yet use the `{graph_schema_tool_name}` tool to get the schema
            of the neo4j knowledge-graph db. Use that tool first before using
            the `{cypher_retrieval_tool_name}` tool, to ensure you know all the correct
            node labels, relationship types, and property keys available in
            the database.
            """
        elif not self.config.database_created:
            return f"""
            You have not yet created the Neo4j database.
            Use the `{cypher_creation_tool_name}`
            tool to create the database first before using the
            `{cypher_retrieval_tool_name}` tool.
            """
        query = msg.cypher_query
        self.current_retrieval_cypher_query = query
        logger.info(f"Executing Cypher query: {query}")
        response = self.read_query(query)
        if isinstance(response.data, list) and len(response.data) == 0:
            return """
            No results found; check if your query used the right label names --
            remember these are case sensitive, so you have to use the exact label
            names you found in the schema.
            Or retry using one of the RETRY-SUGGESTIONS in your instructions.
            """
        return str(response.data)

    def cypher_creation_tool(self, msg: CypherCreationTool) -> str:
        """
        Handle a CypherCreationTool message by executing a Cypher query and
        returning the result.

        Args:
            msg (CypherCreationTool): The tool-message to handle.

        Returns:
            str: A success message, or the error message from the failed query.
        """
        query = msg.cypher_query

        logger.info(f"Executing Cypher query: {query}")
        response = self.write_query(query)
        if response.success:
            self.config.database_created = True
            return "Cypher query executed successfully"
        else:
            return str(response.data)

    # TODO: There are various ways to get the schema. The current one uses the func
    # `read_query`, which requires post processing to identify whether the response upon
    # the schema query is valid. Another way is to isolate this func from `read_query`.
    # The current query works well. But we could use the queries here:
    # https://github.com/neo4j/NaLLM/blob/1af09cd117ba0777d81075c597a5081583568f9f/api/
    # src/driver/neo4j.py#L30
    def graph_schema_tool(
        self, msg: GraphSchemaTool | None
    ) -> Optional[Union[str, List[Dict[Any, Any]]]]:
        """
        Retrieve the schema of a Neo4j graph database.

        Marks the schema as tried. A non-empty cached `config.kg_schema` is
        returned as-is; otherwise the schema is queried and cached.

        Args:
            msg (GraphSchemaTool | None): The tool message (unused), or None
                when called internally.

        Returns:
            The schema records on success, or a string describing the failure.

        Raises:
            Does not raise for query failures: `read_query` converts them into
            an unsuccessful `QueryResult`.
        """
        self.tried_schema = True
        if self.config.kg_schema is not None and len(self.config.kg_schema) > 0:
            return self.config.kg_schema
        schema_result = self.read_query("CALL db.schema.visualization()")
        if schema_result.success:
            # there is a possibility that the schema is empty, which is a valid response
            # the schema.data will be: [{"nodes": [], "relationships": []}]
            self.config.kg_schema = schema_result.data  # type: ignore
            return schema_result.data
        else:
            return f"Failed to retrieve schema: {schema_result.data}"

    def _init_tools_sys_message(self) -> None:
        """Build the system message, initialize the base agent and enable tools."""
        self.tried_schema = False
        message = self._format_message()
        # NOTE(review): this mutates the caller's config in place; reusing the
        # same config for a second agent would re-format an already-formatted
        # message and append the instructions twice.
        self.config.system_message = self.config.system_message.format(mode=message)
        if self.config.chat_mode:
            self.config.addressing_prefix = self.config.addressing_prefix or SEND_TO
            self.config.system_message += ADDRESSING_INSTRUCTION.format(
                prefix=self.config.addressing_prefix
            )
        else:
            self.config.system_message += DONE_INSTRUCTION
        super().__init__(self.config)
        # Note we are enabling GraphSchemaTool regardless of whether
        # self.config.use_schema_tools is True or False, because
        # even when schema provided, the agent may later want to get the schema,
        # e.g. if the db evolves, or if it needs to bring in the schema
        self.enable_message(
            [
                GraphSchemaTool,
                CypherRetrievalTool,
                CypherCreationTool,
                DoneTool,
            ]
        )

    def _format_message(self) -> str:
        """Return the `{mode}` text: schema-tool instructions or the inline schema.

        Raises:
            ValueError: If no database driver exists.
        """
        if self.driver is None:
            raise ValueError("Database driver None")
        assert isinstance(self.config, Neo4jChatAgentConfig)
        return (
            SCHEMA_TOOLS_SYS_MSG
            if self.config.use_schema_tools
            else SCHEMA_PROVIDED_SYS_MSG.format(schema=self.graph_schema_tool(None))
        )
|
@@ -26,7 +26,7 @@ class RelevanceExtractorAgentConfig(ChatAgentConfig):
|
|
26
26
|
llm: LLMConfig | None = OpenAIGPTConfig()
|
27
27
|
segment_length: int = 1 # number of sentences per segment
|
28
28
|
query: str = "" # query for relevance extraction
|
29
|
-
system_message = """
|
29
|
+
system_message: str = """
|
30
30
|
The user will give you a PASSAGE containing segments numbered as
|
31
31
|
<#1#>, <#2#>, <#3#>, etc.,
|
32
32
|
followed by a QUERY. Extract ONLY the segment-numbers from
|
@@ -184,7 +184,7 @@ class SQLChatAgent(ChatAgent):
|
|
184
184
|
if self.config.use_helper:
|
185
185
|
# helper_config.system_message is now the fully-populated sys msg of
|
186
186
|
# the main SQLAgent.
|
187
|
-
self.helper_config = self.config.
|
187
|
+
self.helper_config = self.config.model_copy()
|
188
188
|
self.helper_config.is_helper = True
|
189
189
|
self.helper_config.use_helper = False
|
190
190
|
self.helper_config.chat_mode = False
|
@@ -271,8 +271,16 @@ class SQLChatAgent(ChatAgent):
|
|
271
271
|
|
272
272
|
def _init_tools(self) -> None:
|
273
273
|
"""Initialize sys msg and tools."""
|
274
|
-
RunQueryTool
|
275
|
-
self.
|
274
|
+
# Create a custom RunQueryTool class with the desired max_retained_tokens
|
275
|
+
if self.config.max_retained_tokens is not None:
|
276
|
+
|
277
|
+
class CustomRunQueryTool(RunQueryTool):
|
278
|
+
_max_retained_tokens = self.config.max_retained_tokens
|
279
|
+
|
280
|
+
self.enable_message([CustomRunQueryTool, ForwardTool])
|
281
|
+
else:
|
282
|
+
self.enable_message([RunQueryTool, ForwardTool])
|
283
|
+
|
276
284
|
if self.config.use_schema_tools:
|
277
285
|
self._enable_schema_tools()
|
278
286
|
if not self.config.chat_mode:
|