langroid 0.33.6__py3-none-any.whl → 0.33.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +106 -0
- langroid/agent/__init__.py +41 -0
- langroid/agent/base.py +1983 -0
- langroid/agent/batch.py +398 -0
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +598 -0
- langroid/agent/chat_agent.py +1899 -0
- langroid/agent/chat_document.py +454 -0
- langroid/agent/openai_assistant.py +882 -0
- langroid/agent/special/__init__.py +59 -0
- langroid/agent/special/arangodb/__init__.py +0 -0
- langroid/agent/special/arangodb/arangodb_agent.py +656 -0
- langroid/agent/special/arangodb/system_messages.py +186 -0
- langroid/agent/special/arangodb/tools.py +107 -0
- langroid/agent/special/arangodb/utils.py +36 -0
- langroid/agent/special/doc_chat_agent.py +1466 -0
- langroid/agent/special/lance_doc_chat_agent.py +262 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +198 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +82 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +260 -0
- langroid/agent/special/lance_tools.py +61 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +433 -0
- langroid/agent/special/neo4j/system_messages.py +120 -0
- langroid/agent/special/neo4j/tools.py +32 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +56 -0
- langroid/agent/special/sql/__init__.py +17 -0
- langroid/agent/special/sql/sql_chat_agent.py +654 -0
- langroid/agent/special/sql/utils/__init__.py +21 -0
- langroid/agent/special/sql/utils/description_extractors.py +190 -0
- langroid/agent/special/sql/utils/populate_metadata.py +85 -0
- langroid/agent/special/sql/utils/system_message.py +35 -0
- langroid/agent/special/sql/utils/tools.py +64 -0
- langroid/agent/special/table_chat_agent.py +263 -0
- langroid/agent/task.py +2095 -0
- langroid/agent/tool_message.py +393 -0
- langroid/agent/tools/__init__.py +38 -0
- langroid/agent/tools/duckduckgo_search_tool.py +50 -0
- langroid/agent/tools/file_tools.py +234 -0
- langroid/agent/tools/google_search_tool.py +39 -0
- langroid/agent/tools/metaphor_search_tool.py +68 -0
- langroid/agent/tools/orchestration.py +303 -0
- langroid/agent/tools/recipient_tool.py +235 -0
- langroid/agent/tools/retrieval_tool.py +32 -0
- langroid/agent/tools/rewind_tool.py +137 -0
- langroid/agent/tools/segment_extract_tool.py +41 -0
- langroid/agent/xml_tool_message.py +382 -0
- langroid/cachedb/__init__.py +17 -0
- langroid/cachedb/base.py +58 -0
- langroid/cachedb/momento_cachedb.py +108 -0
- langroid/cachedb/redis_cachedb.py +153 -0
- langroid/embedding_models/__init__.py +39 -0
- langroid/embedding_models/base.py +74 -0
- langroid/embedding_models/models.py +461 -0
- langroid/embedding_models/protoc/__init__.py +0 -0
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/exceptions.py +71 -0
- langroid/language_models/__init__.py +53 -0
- langroid/language_models/azure_openai.py +153 -0
- langroid/language_models/base.py +678 -0
- langroid/language_models/config.py +18 -0
- langroid/language_models/mock_lm.py +124 -0
- langroid/language_models/openai_gpt.py +1964 -0
- langroid/language_models/prompt_formatter/__init__.py +16 -0
- langroid/language_models/prompt_formatter/base.py +40 -0
- langroid/language_models/prompt_formatter/hf_formatter.py +132 -0
- langroid/language_models/prompt_formatter/llama2_formatter.py +75 -0
- langroid/language_models/utils.py +151 -0
- langroid/mytypes.py +84 -0
- langroid/parsing/__init__.py +52 -0
- langroid/parsing/agent_chats.py +38 -0
- langroid/parsing/code_parser.py +121 -0
- langroid/parsing/document_parser.py +718 -0
- langroid/parsing/para_sentence_split.py +62 -0
- langroid/parsing/parse_json.py +155 -0
- langroid/parsing/parser.py +313 -0
- langroid/parsing/repo_loader.py +790 -0
- langroid/parsing/routing.py +36 -0
- langroid/parsing/search.py +275 -0
- langroid/parsing/spider.py +102 -0
- langroid/parsing/table_loader.py +94 -0
- langroid/parsing/url_loader.py +111 -0
- langroid/parsing/urls.py +273 -0
- langroid/parsing/utils.py +373 -0
- langroid/parsing/web_search.py +156 -0
- langroid/prompts/__init__.py +9 -0
- langroid/prompts/dialog.py +17 -0
- langroid/prompts/prompts_config.py +5 -0
- langroid/prompts/templates.py +141 -0
- langroid/pydantic_v1/__init__.py +10 -0
- langroid/pydantic_v1/main.py +4 -0
- langroid/utils/__init__.py +19 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +98 -0
- langroid/utils/constants.py +30 -0
- langroid/utils/git_utils.py +252 -0
- langroid/utils/globals.py +49 -0
- langroid/utils/logging.py +135 -0
- langroid/utils/object_registry.py +66 -0
- langroid/utils/output/__init__.py +20 -0
- langroid/utils/output/citations.py +41 -0
- langroid/utils/output/printing.py +99 -0
- langroid/utils/output/status.py +40 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +602 -0
- langroid/utils/system.py +286 -0
- langroid/utils/types.py +93 -0
- langroid/vector_store/__init__.py +50 -0
- langroid/vector_store/base.py +359 -0
- langroid/vector_store/chromadb.py +214 -0
- langroid/vector_store/lancedb.py +406 -0
- langroid/vector_store/meilisearch.py +299 -0
- langroid/vector_store/momento.py +278 -0
- langroid/vector_store/qdrantdb.py +468 -0
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/METADATA +95 -94
- langroid-0.33.7.dist-info/RECORD +127 -0
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/WHEEL +1 -1
- langroid-0.33.6.dist-info/RECORD +0 -7
- langroid-0.33.6.dist-info/entry_points.txt +0 -4
- pyproject.toml +0 -356
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,654 @@
|
|
1
|
+
"""
|
2
|
+
Agent that allows interaction with an SQL database using SQLAlchemy library.
|
3
|
+
The agent can execute SQL queries in the database and return the result.
|
4
|
+
|
5
|
+
Functionality includes:
|
6
|
+
- adding table and column context
|
7
|
+
- asking a question about a SQL schema
|
8
|
+
"""
|
9
|
+
|
10
|
+
import logging
|
11
|
+
from typing import Any, Dict, List, Optional, Sequence, Union
|
12
|
+
|
13
|
+
from rich.console import Console
|
14
|
+
|
15
|
+
from langroid.exceptions import LangroidImportError
|
16
|
+
from langroid.utils.constants import SEND_TO
|
17
|
+
|
18
|
+
try:
|
19
|
+
from sqlalchemy import MetaData, Row, create_engine, inspect, text
|
20
|
+
from sqlalchemy.engine import Engine
|
21
|
+
from sqlalchemy.exc import ResourceClosedError, SQLAlchemyError
|
22
|
+
from sqlalchemy.orm import Session, sessionmaker
|
23
|
+
except ImportError as e:
|
24
|
+
raise LangroidImportError(extra="sql", error=str(e))
|
25
|
+
|
26
|
+
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
|
27
|
+
from langroid.agent.chat_document import ChatDocument
|
28
|
+
from langroid.agent.special.sql.utils.description_extractors import (
|
29
|
+
extract_schema_descriptions,
|
30
|
+
)
|
31
|
+
from langroid.agent.special.sql.utils.populate_metadata import (
|
32
|
+
populate_metadata,
|
33
|
+
populate_metadata_with_schema_tools,
|
34
|
+
)
|
35
|
+
from langroid.agent.special.sql.utils.system_message import (
|
36
|
+
DEFAULT_SYS_MSG,
|
37
|
+
SCHEMA_TOOLS_SYS_MSG,
|
38
|
+
)
|
39
|
+
from langroid.agent.special.sql.utils.tools import (
|
40
|
+
GetColumnDescriptionsTool,
|
41
|
+
GetTableNamesTool,
|
42
|
+
GetTableSchemaTool,
|
43
|
+
RunQueryTool,
|
44
|
+
)
|
45
|
+
from langroid.agent.tools.orchestration import (
|
46
|
+
DoneTool,
|
47
|
+
ForwardTool,
|
48
|
+
PassTool,
|
49
|
+
)
|
50
|
+
from langroid.vector_store.base import VectorStoreConfig
|
51
|
+
|
52
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Shared rich Console instance for this module.
console = Console()

# Default system-message template for SQLChatAgentConfig.system_message.
# The `{mode}` placeholder is filled by SQLChatAgent._init_system_message()
# with the text returned by SQLChatAgent._format_message().
DEFAULT_SQL_CHAT_SYSTEM_MESSAGE = """
{mode}

You do not need to attempt answering a question with just one query.
You could make a sequence of SQL queries to help you write the final query.
Also if you receive a null or other unexpected result,
(a) make sure you use the available TOOLs correctly, and
(b) see if you have made an assumption in your SQL query, and try another way,
or use `run_query` to explore the database table contents before submitting your
final query. For example when searching for "males" you may have used "gender= 'M'",
in your query, because you did not know that the possible genders in the table
are "Male" and "Female".

Start by asking what I would like to know about the data.

"""

# Appended to the system message when config.chat_mode is True.
# `{prefix}` is filled with config.addressing_prefix via str.format().
ADDRESSING_INSTRUCTION = """
IMPORTANT - Whenever you are NOT writing a SQL query, make sure you address the user
using {prefix}User. You MUST use the EXACT syntax {prefix} !!!

In other words, you ALWAYS write EITHER:
- a SQL query using the `run_query` tool,
- OR address the user using {prefix}User
"""

# Appended to the system message when config.chat_mode is False.
# This is an f-string, so DoneTool.name() is evaluated once at import time.
DONE_INSTRUCTION = f"""
When you are SURE you have the CORRECT answer to a user's query or request,
use the `{DoneTool.name()}` with `content` set to the answer or result.
If you DO NOT think you have the answer to the user's query or request,
you SHOULD NOT use the `{DoneTool.name()}` tool.
Instead, you must CONTINUE to improve your queries (tools) to get the correct answer,
and finally use the `{DoneTool.name()}` tool to send the correct answer to the user.
"""


# Leading text of the message built by SQLChatAgent.retry_query().
SQL_ERROR_MSG = "There was an error in your SQL Query"
|
93
|
+
|
94
|
+
|
95
|
+
class SQLChatAgentConfig(ChatAgentConfig):
    # Configuration for SQLChatAgent. SQLChatAgent.__init__ also copies this
    # config (with is_helper=True, use_helper=False) to build its SQLHelperAgent.

    # Template containing a `{mode}` placeholder; SQLChatAgent fills it in and
    # overwrites this field during initialization.
    system_message: str = DEFAULT_SQL_CHAT_SYSTEM_MESSAGE
    user_message: None | str = None
    cache: bool = True  # cache results
    debug: bool = False
    # When True, SQLChatAgent.__init__ creates a SQLHelperAgent.
    use_helper: bool = True
    # Set to True on the copy of the config handed to the SQLHelperAgent.
    is_helper: bool = False
    stream: bool = True  # allow streaming where needed
    database_uri: str = ""  # Database URI
    database_session: None | Session = None  # Database session
    vecdb: None | VectorStoreConfig = None

    # Optional, but strongly recommended, context descriptions for tables,
    # columns, and relationships. It should be a dictionary where each key is a
    # table name and its value is another dictionary.
    #
    # In this inner dictionary:
    # - The 'description' key corresponds to a string description of the table.
    # - The 'columns' key corresponds to another dictionary where each key is a
    #   column name and its value is a string description of that column.
    # - The 'relationships' key corresponds to another dictionary where each key
    #   is another table name and the value is a description of the relationship
    #   to that table.
    #
    # If multi_schema support is enabled, the table names in the description
    # should be of the form 'schema_name.table_name'.
    #
    # For example:
    # {
    #     'table1': {
    #         'description': 'description of table1',
    #         'columns': {
    #             'column1': 'description of column1 in table1',
    #             'column2': 'description of column2 in table1'
    #         }
    #     },
    #     'table2': {
    #         'description': 'description of table2',
    #         'columns': {
    #             'column3': 'description of column3 in table2',
    #             'column4': 'description of column4 in table2'
    #         }
    #     }
    # }
    context_descriptions: Dict[str, Dict[str, Union[str, Dict[str, str]]]] = {}
    # When True, the schema tools are enabled and the system message is built
    # from SCHEMA_TOOLS_SYS_MSG instead of embedding the schema.
    use_schema_tools: bool = False
    # When True, metadata is reflected separately for every schema.
    multi_schema: bool = False
    # whether the agent is used in a continuous chat with user,
    # as opposed to returning a result from the task.run()
    chat_mode: bool = False
    # Prefix used to address the user in chat mode; when empty, SEND_TO is
    # substituted by SQLChatAgent._init_system_message().
    addressing_prefix: str = ""
|
148
|
+
|
149
|
+
|
150
|
+
class SQLChatAgent(ChatAgent):
|
151
|
+
"""
|
152
|
+
Agent for chatting with a SQL database
|
153
|
+
"""
|
154
|
+
|
155
|
+
used_run_query: bool = False
|
156
|
+
llm_responded: bool = False
|
157
|
+
|
158
|
+
def __init__(self, config: "SQLChatAgentConfig") -> None:
    """Initialize the SQLChatAgent.

    Validates the config, sets up the database engine/session, reflects the
    metadata, builds the system message, and only then initializes the base
    ChatAgent and enables the tools. When `config.use_helper` is set, a
    SQLHelperAgent is also created from a copy of the config.

    Args:
        config: Agent configuration; its `system_message` is rewritten in
            place by `_init_system_message()`.

    Raises:
        ValueError: If database information is not provided in the config.
    """
    self._validate_config(config)
    self.config: SQLChatAgentConfig = config
    self._init_database()
    self._init_metadata()
    self._init_table_metadata()
    # Empty by default; SQLHelperAgent._init_system_message() replaces it.
    self.final_instructions = ""

    # Caution - this updates the self.config.system_message!
    # It must run before super().__init__(config) so the base class sees
    # the populated message.
    self._init_system_message()
    super().__init__(config)
    self._init_tools()
    if self.config.is_helper:
        # Helper agents get their extra instructions appended after the
        # tool-format instructions.
        self.system_tool_format_instructions += self.final_instructions

    if self.config.use_helper:
        # helper_config.system_message is now the fully-populated sys msg of
        # the main SQLAgent.
        self.helper_config = self.config.copy()
        self.helper_config.is_helper = True
        # The helper must not create a helper of its own.
        self.helper_config.use_helper = False
        self.helper_agent = SQLHelperAgent(self.helper_config)
|
185
|
+
|
186
|
+
def _validate_config(self, config: "SQLChatAgentConfig") -> None:
|
187
|
+
"""Validate the configuration to ensure all necessary fields are present."""
|
188
|
+
if config.database_session is None and config.database_uri is None:
|
189
|
+
raise ValueError("Database information must be provided")
|
190
|
+
|
191
|
+
def _init_database(self) -> None:
|
192
|
+
"""Initialize the database engine and session."""
|
193
|
+
if self.config.database_session:
|
194
|
+
self.Session = self.config.database_session
|
195
|
+
self.engine = self.Session.bind
|
196
|
+
else:
|
197
|
+
self.engine = create_engine(self.config.database_uri)
|
198
|
+
self.Session = sessionmaker(bind=self.engine)()
|
199
|
+
|
200
|
+
def _init_metadata(self) -> None:
|
201
|
+
"""Initialize the database metadata."""
|
202
|
+
if self.engine is None:
|
203
|
+
raise ValueError("Database engine is None")
|
204
|
+
self.metadata: MetaData | List[MetaData] = []
|
205
|
+
|
206
|
+
if self.config.multi_schema:
|
207
|
+
logger.info(
|
208
|
+
"Initializing SQLChatAgent with database: %s",
|
209
|
+
self.engine,
|
210
|
+
)
|
211
|
+
|
212
|
+
self.metadata = []
|
213
|
+
inspector = inspect(self.engine)
|
214
|
+
|
215
|
+
for schema in inspector.get_schema_names():
|
216
|
+
metadata = MetaData(schema=schema)
|
217
|
+
metadata.reflect(self.engine)
|
218
|
+
self.metadata.append(metadata)
|
219
|
+
|
220
|
+
logger.info(
|
221
|
+
"Initializing SQLChatAgent with database: %s, schema: %s, "
|
222
|
+
"and tables: %s",
|
223
|
+
self.engine,
|
224
|
+
schema,
|
225
|
+
metadata.tables,
|
226
|
+
)
|
227
|
+
else:
|
228
|
+
self.metadata = MetaData()
|
229
|
+
self.metadata.reflect(self.engine)
|
230
|
+
logger.info(
|
231
|
+
"SQLChatAgent initialized with database: %s and tables: %s",
|
232
|
+
self.engine,
|
233
|
+
self.metadata.tables,
|
234
|
+
)
|
235
|
+
|
236
|
+
def _init_table_metadata(self) -> None:
    """Build `self.table_metadata` from the reflected metadata.

    If the config carries no context descriptions and the engine is an
    `Engine`, the descriptions are first extracted from the database and
    stored back on the config.
    """
    cfg = self.config
    if not cfg.context_descriptions and isinstance(self.engine, Engine):
        cfg.context_descriptions = extract_schema_descriptions(
            self.engine, cfg.multi_schema
        )

    # Pick the populate function that matches the schema-tools setting.
    populate = (
        populate_metadata_with_schema_tools
        if cfg.use_schema_tools
        else populate_metadata
    )
    self.table_metadata = populate(self.metadata, cfg.context_descriptions)
|
251
|
+
|
252
|
+
def _init_system_message(self) -> None:
    """Fill in and extend `self.config.system_message` in place.

    The `{mode}` placeholder is replaced with the output of
    `_format_message()`. In chat mode the addressing instruction is then
    appended (defaulting the prefix to SEND_TO); otherwise the done-tool
    instruction is appended.
    """
    mode_text = self._format_message()
    cfg = self.config
    cfg.system_message = cfg.system_message.format(mode=mode_text)

    if not cfg.chat_mode:
        cfg.system_message += DONE_INSTRUCTION
        return

    cfg.addressing_prefix = cfg.addressing_prefix or SEND_TO
    cfg.system_message += ADDRESSING_INSTRUCTION.format(
        prefix=cfg.addressing_prefix
    )
|
264
|
+
|
265
|
+
def _init_tools(self) -> None:
    """Enable the tools this agent uses.

    RunQueryTool and ForwardTool are always enabled; the schema tools only
    when `use_schema_tools` is set; DoneTool only outside chat mode.
    """
    cfg = self.config
    self.enable_message([RunQueryTool, ForwardTool])
    if cfg.use_schema_tools:
        self._enable_schema_tools()
    if cfg.chat_mode:
        return
    self.enable_message(DoneTool)
|
272
|
+
|
273
|
+
def _format_message(self) -> str:
|
274
|
+
if self.engine is None:
|
275
|
+
raise ValueError("Database engine is None")
|
276
|
+
|
277
|
+
"""Format the system message based on the engine and table metadata."""
|
278
|
+
return (
|
279
|
+
SCHEMA_TOOLS_SYS_MSG.format(dialect=self.engine.dialect.name)
|
280
|
+
if self.config.use_schema_tools
|
281
|
+
else DEFAULT_SYS_MSG.format(
|
282
|
+
dialect=self.engine.dialect.name, schema_dict=self.table_metadata
|
283
|
+
)
|
284
|
+
)
|
285
|
+
|
286
|
+
def _enable_schema_tools(self) -> None:
    """Enable the three schema-exploration tools, one at a time."""
    schema_tools = (
        GetTableNamesTool,
        GetTableSchemaTool,
        GetColumnDescriptionsTool,
    )
    for schema_tool in schema_tools:
        self.enable_message(schema_tool)
|
291
|
+
|
292
|
+
def llm_response(
    self, message: Optional[str | ChatDocument] = None
) -> Optional[ChatDocument]:
    """Mark that the LLM is the responder, then defer to the base class.

    Sets `llm_responded` (checked by `handle_message_fallback`) and clears
    `used_run_query` before delegating.
    """
    self.used_run_query, self.llm_responded = False, True
    return super().llm_response(message)
|
298
|
+
|
299
|
+
def user_response(
    self,
    msg: Optional[str | ChatDocument] = None,
) -> Optional[ChatDocument]:
    """Mark that the user is the responder, then defer to the base class.

    Clears both `llm_responded` and `used_run_query` before delegating.
    """
    self.llm_responded = self.used_run_query = False
    return super().user_response(msg)
|
306
|
+
|
307
|
+
def _clarify_answer_instruction(self) -> str:
    """
    Return the instruction telling the LLM how to mark an already-generated
    response as its final answer: via DoneTool outside chat mode, via
    ForwardTool to "User" in chat mode.
    """
    if not self.config.chat_mode:
        return f"""
you must use the `{DoneTool.name()}` with the `content`
set to the answer or result
"""

    return f"""
you must use the `{ForwardTool.name()}` with the `agent`
parameter set to "User"
"""
|
322
|
+
|
323
|
+
def _clarifying_message(self) -> str:
    """
    Build the message sent back to the LLM when the intent of its response
    is unclear, reminding it how to give a final answer and which tools it
    can use to make progress.
    """
    hint_parts = [
        f"""
For example you may want to use the TOOL
`{RunQueryTool.name()}` to further explore the database contents
"""
    ]
    if self.config.use_schema_tools:
        # Mention the schema tools only when they are configured.
        hint_parts.append(
            """
OR you may want to use one of the schema tools to
explore the database schema
"""
        )
    tool_hint = "".join(hint_parts)

    return f"""
The intent of your response is not clear:
- if you intended this to be the FINAL answer to the user's query,
{self._clarify_answer_instruction()}
- otherwise, use one of the available tools to make progress
to arrive at the final answer.
{tool_hint}
"""
|
341
|
+
|
342
|
+
def handle_message_fallback(
    self, message: str | ChatDocument
) -> str | ForwardTool | ChatDocument | None:
    """
    Handle the scenario where current msg is not a tool.
    Special handling is only needed if the message was from the LLM
    (as indicated by self.llm_responded).

    Args:
        message: The non-tool message being handled.

    Returns:
        None when the message did not come from the LLM; a ForwardTool
        addressed to "User" when the agent is interactive; the result of a
        strict-output retry when `_json_schema_available()` is true;
        otherwise the helper agent's response if it contains a tool, or
        else the clarification message.
    """
    if not self.llm_responded:
        return None
    if self.interactive:
        # self.interactive will be set to True by the Task,
        # when chat_mode=True, so in this case
        # we send any Non-tool msg to the user
        return ForwardTool(agent="User")

    if self._json_schema_available():
        # Force the next LLM response to be one of the enabled tools.
        AnyTool = self._get_any_tool_message(optional=False)
        self.set_output_format(
            AnyTool,
            force_tools=True,
            use=True,
            handle=True,
            instructions=True,
        )
        recovery_message = self._strict_recovery_instructions(
            AnyTool, optional=False
        )
        return self.llm_response(recovery_message)

    # Agent intent not clear => use the helper agent to
    # do what this agent should have done, e.g. generate tool, etc.
    # This is likelier to succeed since this agent has no "baggage" of
    # prior conversation, other than the system msg, and special
    # "Intent-interpretation" instructions.
    helper = getattr(self, "helper_agent", None)
    if helper is None:
        # __init__ only creates `helper_agent` when config.use_helper is
        # True (it is always False for the helper itself), so fall back on
        # the clarification message instead of raising AttributeError.
        return self._clarifying_message()

    response = helper.llm_response(message)
    if self.try_get_tool_messages(response):
        return response
    # fall back on the clarification message
    return self._clarifying_message()
|
383
|
+
|
384
|
+
def retry_query(self, e: Exception, query: str) -> str:
    """
    Log a failed SQL query and build the error text to hand back to the LLM.

    Parameters:
        e (Exception): The exception raised during the SQL query execution.
        query (str): The SQL query that failed.

    Returns:
        str: The error message, followed by the context descriptions when
            schema tools are not in use.
    """
    logger.error(f"SQL Query failed: {query}\nException: {e}")

    # The schema reminder is only included when schema tools are disabled.
    if self.config.use_schema_tools:
        schema_hint = ""
    else:
        schema_hint = f"""\
This JSON schema maps SQL database structure. It outlines tables, each
with a description and columns. Each table is identified by a key, and holds
a description and a dictionary of columns, with column
names as keys and their descriptions as values.

```json
{self.config.context_descriptions}
```"""

    return f"""\
{SQL_ERROR_MSG}: '{query}'
{str(e)}
Run a new query, correcting the errors.
{schema_hint}"""
|
418
|
+
|
419
|
+
def _available_tool_names(self) -> str:
    """Return the comma-separated names of the query tools the LLM may use.

    RunQueryTool is always listed. The schema tools are listed only when
    `config.use_schema_tools` is set, which is the same condition under
    which `_init_tools()` enables them, so the prompt never advertises a
    tool that was not enabled.
    """
    tools = [RunQueryTool]
    if self.config.use_schema_tools:
        tools += [
            GetTableNamesTool,
            GetTableSchemaTool,
            GetColumnDescriptionsTool,
        ]
    return ",".join(tool.name() for tool in tools)  # type: ignore
|
429
|
+
|
430
|
+
def _tool_result_llm_answer_prompt(self) -> str:
|
431
|
+
"""
|
432
|
+
Prompt to use at end of tool result,
|
433
|
+
to guide LLM, for the case where it wants to answer the user's query
|
434
|
+
"""
|
435
|
+
if self.config.chat_mode:
|
436
|
+
assert self.config.addressing_prefix != ""
|
437
|
+
return """
|
438
|
+
You must EXPLICITLY address the User with
|
439
|
+
the addressing prefix according to your instructions,
|
440
|
+
to convey your answer to the User.
|
441
|
+
"""
|
442
|
+
else:
|
443
|
+
return f"""
|
444
|
+
you must use the `{DoneTool.name()}` with the `content`
|
445
|
+
set to the answer or result
|
446
|
+
"""
|
447
|
+
|
448
|
+
def run_query(self, msg: RunQueryTool) -> str:
    """
    Execute the SQL carried by a RunQueryTool message and report the outcome.

    The statement is executed and committed on `self.Session`. Fetched rows
    are formatted with `_format_rows`; if fetching raises
    ResourceClosedError the affected-row count is reported instead. On any
    SQLAlchemyError the session is rolled back and the text from
    `retry_query` is used. The session is closed in every case.

    Args:
        msg (RunQueryTool): The tool-message to handle.

    Returns:
        str: The outcome wrapped in guidance on what the LLM should do next.
    """
    sql = msg.query
    db_session = self.Session
    self.used_run_query = True
    try:
        logger.info(f"Executing SQL query: {sql}")

        executed = db_session.execute(text(sql))
        db_session.commit()
        try:
            # attempt to fetch results: should work for normal SELECT queries
            fetched = executed.fetchall()
            outcome = self._format_rows(fetched)
        except ResourceClosedError:
            # If we get here, it's a non-SELECT query (UPDATE, INSERT, DELETE)
            n_affected = executed.rowcount  # type: ignore
            outcome = f"""
Non-SELECT query executed successfully.
Rows affected: {n_affected}
"""

    except SQLAlchemyError as err:
        db_session.rollback()
        logger.error(f"Failed to execute query: {sql}\n{err}")
        outcome = self.retry_query(err, sql)
    finally:
        db_session.close()

    return f"""
Below is the result from your use of the TOOL `{RunQueryTool.name()}`:
==== result ====
{outcome}
================

If you are READY to ANSWER the ORIGINAL QUERY:
{self._tool_result_llm_answer_prompt()}
OTHERWISE:
continue using one of your available TOOLs:
{self._available_tool_names()}
"""
|
498
|
+
|
499
|
+
def _format_rows(self, rows: Sequence[Row[Any]]) -> str:
    """
    Render fetched rows as text, one row per line.

    Args:
        rows (list): List of rows fetched from the query result.

    Returns:
        str: The `str()` of each row joined by ",\\n", or a success notice
            when there are no rows.
    """
    # TODO: UPDATE FORMATTING
    if not rows:
        return "Query executed successfully."
    return ",\n".join(map(str, rows))
|
515
|
+
|
516
|
+
def get_table_names(self, msg: GetTableNamesTool) -> str:
    """
    Handle a GetTableNamesTool message by listing every reflected table.

    Returns:
        str: The table names joined by ", ". When `self.metadata` is a list
            (one entry per schema), each entry's names are joined first and
            the per-schema strings are then joined by ", ".
    """
    metadata = self.metadata
    if not isinstance(metadata, list):
        return ", ".join(metadata.tables.keys())

    names_per_schema = (", ".join(md.tables.keys()) for md in metadata)
    return ", ".join(names_per_schema)
|
529
|
+
|
530
|
+
def get_table_schema(self, msg: GetTableSchemaTool) -> str:
    """
    Handle a GetTableSchemaTool message by reporting the stored schema of
    each requested table.

    Returns:
        str: One line per requested table: its entry from
            `self.table_metadata`, or a notice that the name is not valid.
    """
    report_lines = []
    for requested in msg.tables:
        entry = self.table_metadata.get(requested)
        if entry is None:
            report_lines.append(f"{requested} is not a valid table name.\n")
        else:
            report_lines.append(f"{requested}: {entry}\n")
    return "".join(report_lines)
|
547
|
+
|
548
|
+
def get_column_descriptions(self, msg: GetColumnDescriptionsTool) -> str:
    """
    Handle a GetColumnDescriptionsTool message by returning the descriptions of
    the requested columns of one table.

    Unknown tables and unknown columns are reported in the returned text
    rather than raising, so the LLM can correct its request.

    Args:
        msg (GetColumnDescriptionsTool): Carries `table` and a
            comma-separated `columns` string.

    Returns:
        str: A "TABLE: <name>" header followed by one "<col> => <description>"
            line per requested column, or a notice that the table name is
            not valid.
    """
    table = msg.table
    # Accept "a, b" as well as "a,b" and ignore empty entries.
    columns = [name.strip() for name in msg.columns.split(",") if name.strip()]

    descriptions = self.config.context_descriptions.get(table)
    if descriptions is None:
        return f"{table} is not a valid table name.\n"

    column_docs = descriptions.get("columns")
    if not isinstance(column_docs, dict):
        # The table has no per-column descriptions recorded.
        column_docs = {}

    lines = [f"\nTABLE: {table}"]
    for col in columns:
        if col in column_docs:
            lines.append(f"\n{col} => {column_docs[col]}")
        else:
            lines.append(f"\n{col} => no description available for this column")
    return "".join(lines)
|
564
|
+
|
565
|
+
|
566
|
+
class SQLHelperAgent(SQLChatAgent):
|
567
|
+
|
568
|
+
def _clarifying_message(self) -> str:
    """
    Helper-agent variant of the clarification message: it speaks about the
    main Agent's unclear response rather than addressing that Agent directly.
    """
    hint_parts = [
        f"""
For example the Agent may have forgotten to use the TOOL
`{RunQueryTool.name()}` to further explore the database contents
"""
    ]
    if self.config.use_schema_tools:
        # Mention the schema tools only when they are configured.
        hint_parts.append(
            """
OR the agent may have forgotten to use one of the schema tools to
explore the database schema
"""
        )
    tool_hint = "".join(hint_parts)

    return f"""
The intent of the Agent's response is not clear:
- if you think the Agent intended this as ANSWER to the
user's query,
{self._clarify_answer_instruction()}
- otherwise, the Agent may have forgotten to
use one of the available tools to make progress
to arrive at the final answer.
{tool_hint}
"""
|
589
|
+
|
590
|
+
def _init_system_message(self) -> None:
    """Set up the helper's system message and final instructions.

    Wraps the system message already present on the config inside
    intent-interpretation instructions, and sets `self.final_instructions`
    describing how the helper should respond.
    """

    # Note that self.config.system_message is already set to the
    # parent SQLAgent's system_message
    self.config.system_message = f"""
Your role is to help INTERPRET the INTENT of an
AI agent in a conversation. This Agent was supposed to generate
a TOOL/Function-call but forgot to do so, and this is where
you can help, by trying to generate the appropriate TOOL
based on your best guess of the Agent's INTENT.

Below are the instructions that were given to this Agent:
===== AGENT INSTRUCTIONS =====
{self.config.system_message}
===== END OF AGENT INSTRUCTIONS =====
"""

    # note that the initial msg in chat history will contain:
    # - system message
    # - tool instructions
    # so the final_instructions will be at the end of this initial msg

    self.final_instructions = f"""
You must take note especially of the TOOLs that are
available to the Agent. Your reasoning process should be as follows:

- If the Agent's message appears to be an ANSWER to the original query,
{self._clarify_answer_instruction()}.
CAUTION - You must be absolutely sure that the Agent's message is
an ACTUAL ANSWER to the user's query, and not a failed attempt to use
a TOOL without JSON, e.g. something like "run_query" or "done_tool"
without any actual JSON formatting.

- Else, if you think the Agent intended to use some type of SQL
query tool to READ or UPDATE the table(s),
AND it is clear WHICH TOOL is intended as well as the
TOOL PARAMETERS, then you must generate the JSON-Formatted
TOOL with the parameters set based on your understanding.
Note that the `{RunQueryTool.name()}` is not ONLY for querying the tables,
but also for UPDATING the tables.

- Else, use the `{PassTool.name()}` to pass the message unchanged.
CAUTION - ONLY use `{PassTool.name()}` if you think the Agent's response
is NEITHER an ANSWER, nor an intended SQL QUERY.
"""
|
636
|
+
|
637
|
+
def llm_response(
    self, message: Optional[str | ChatDocument] = None
) -> Optional[ChatDocument]:
    """Ask the helper LLM to interpret the intent of the SQL Agent's message.

    Returns None when no message is given. Otherwise the message text is
    wrapped together with `self.final_instructions` and sent through
    `llm_response_forget`.
    """
    if message is None:
        return None

    if isinstance(message, str):
        agent_text = message
    else:
        agent_text = message.content

    prompt = f"""
Below is the MESSAGE from the SQL Agent.
Remember your instructions on how to respond based on your understanding
of the INTENT of this message:
{self.final_instructions}

=== AGENT MESSAGE =========
{agent_text}
=== END OF AGENT MESSAGE ===
"""
    # use llm_response_forget to avoid accumulating the chat history
    return super().llm_response_forget(prompt)
|
@@ -0,0 +1,21 @@
|
|
1
|
+
from . import tools
|
2
|
+
from . import description_extractors
|
3
|
+
from . import populate_metadata
|
4
|
+
from . import system_message
|
5
|
+
from .tools import (
|
6
|
+
RunQueryTool,
|
7
|
+
GetTableNamesTool,
|
8
|
+
GetTableSchemaTool,
|
9
|
+
GetColumnDescriptionsTool,
|
10
|
+
)
|
11
|
+
|
12
|
+
__all__ = [
|
13
|
+
"RunQueryTool",
|
14
|
+
"GetTableNamesTool",
|
15
|
+
"GetTableSchemaTool",
|
16
|
+
"GetColumnDescriptionsTool",
|
17
|
+
"description_extractors",
|
18
|
+
"populate_metadata",
|
19
|
+
"system_message",
|
20
|
+
"tools",
|
21
|
+
]
|