langroid 0.1.139__py3-none-any.whl → 0.1.219__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +70 -0
- langroid/agent/__init__.py +22 -0
- langroid/agent/base.py +120 -33
- langroid/agent/batch.py +134 -35
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +608 -0
- langroid/agent/chat_agent.py +164 -100
- langroid/agent/chat_document.py +19 -2
- langroid/agent/openai_assistant.py +20 -10
- langroid/agent/special/__init__.py +33 -10
- langroid/agent/special/doc_chat_agent.py +521 -108
- langroid/agent/special/lance_doc_chat_agent.py +258 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +136 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
- langroid/agent/special/lance_tools.py +44 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
- langroid/agent/special/neo4j/utils/__init__.py +0 -0
- langroid/agent/special/neo4j/utils/system_message.py +46 -0
- langroid/agent/special/relevance_extractor_agent.py +23 -7
- langroid/agent/special/retriever_agent.py +29 -174
- langroid/agent/special/sql/__init__.py +7 -0
- langroid/agent/special/sql/sql_chat_agent.py +47 -23
- langroid/agent/special/sql/utils/__init__.py +11 -0
- langroid/agent/special/sql/utils/description_extractors.py +95 -46
- langroid/agent/special/sql/utils/populate_metadata.py +28 -21
- langroid/agent/special/table_chat_agent.py +43 -9
- langroid/agent/task.py +423 -114
- langroid/agent/tool_message.py +67 -10
- langroid/agent/tools/__init__.py +8 -0
- langroid/agent/tools/duckduckgo_search_tool.py +66 -0
- langroid/agent/tools/google_search_tool.py +11 -0
- langroid/agent/tools/metaphor_search_tool.py +67 -0
- langroid/agent/tools/recipient_tool.py +6 -24
- langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
- langroid/cachedb/__init__.py +6 -0
- langroid/embedding_models/__init__.py +24 -0
- langroid/embedding_models/base.py +9 -1
- langroid/embedding_models/models.py +117 -17
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/language_models/__init__.py +22 -0
- langroid/language_models/azure_openai.py +47 -4
- langroid/language_models/base.py +26 -10
- langroid/language_models/config.py +5 -0
- langroid/language_models/openai_gpt.py +407 -121
- langroid/language_models/prompt_formatter/__init__.py +9 -0
- langroid/language_models/prompt_formatter/base.py +4 -6
- langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
- langroid/language_models/utils.py +10 -9
- langroid/mytypes.py +10 -4
- langroid/parsing/__init__.py +33 -1
- langroid/parsing/document_parser.py +259 -63
- langroid/parsing/image_text.py +32 -0
- langroid/parsing/parse_json.py +143 -0
- langroid/parsing/parser.py +20 -7
- langroid/parsing/repo_loader.py +108 -46
- langroid/parsing/search.py +8 -0
- langroid/parsing/table_loader.py +44 -0
- langroid/parsing/url_loader.py +59 -13
- langroid/parsing/urls.py +18 -9
- langroid/parsing/utils.py +130 -9
- langroid/parsing/web_search.py +73 -0
- langroid/prompts/__init__.py +7 -0
- langroid/prompts/chat-gpt4-system-prompt.md +68 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/utils/__init__.py +10 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/configuration.py +0 -1
- langroid/utils/constants.py +4 -0
- langroid/utils/logging.py +2 -5
- langroid/utils/output/__init__.py +15 -2
- langroid/utils/output/status.py +33 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +446 -4
- langroid/utils/system.py +36 -1
- langroid/vector_store/__init__.py +34 -2
- langroid/vector_store/base.py +33 -2
- langroid/vector_store/chromadb.py +42 -13
- langroid/vector_store/lancedb.py +226 -60
- langroid/vector_store/meilisearch.py +7 -6
- langroid/vector_store/momento.py +3 -2
- langroid/vector_store/qdrantdb.py +82 -11
- {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/METADATA +190 -129
- langroid-0.1.219.dist-info/RECORD +127 -0
- langroid/agent/special/recipient_validator_agent.py +0 -157
- langroid/parsing/json.py +0 -64
- langroid/utils/web/selenium_login.py +0 -36
- langroid-0.1.139.dist-info/RECORD +0 -103
- {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
- {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/WHEEL +0 -0
@@ -3,6 +3,7 @@ Agent to retrieve relevant segments from a body of text,
|
|
3
3
|
that are relevant to a query.
|
4
4
|
|
5
5
|
"""
|
6
|
+
|
6
7
|
import logging
|
7
8
|
from typing import Optional, no_type_check
|
8
9
|
|
@@ -11,16 +12,18 @@ from rich.console import Console
|
|
11
12
|
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
|
12
13
|
from langroid.agent.chat_document import ChatDocument
|
13
14
|
from langroid.agent.tools.segment_extract_tool import SegmentExtractTool
|
15
|
+
from langroid.language_models.base import LLMConfig
|
14
16
|
from langroid.language_models.openai_gpt import OpenAIGPTConfig
|
17
|
+
from langroid.mytypes import Entity
|
15
18
|
from langroid.parsing.utils import extract_numbered_segments, number_segments
|
16
|
-
from langroid.utils.constants import NO_ANSWER
|
19
|
+
from langroid.utils.constants import DONE, NO_ANSWER
|
17
20
|
|
18
21
|
console = Console()
|
19
22
|
logger = logging.getLogger(__name__)
|
20
23
|
|
21
24
|
|
22
25
|
class RelevanceExtractorAgentConfig(ChatAgentConfig):
|
23
|
-
llm:
|
26
|
+
llm: LLMConfig | None = OpenAIGPTConfig()
|
24
27
|
segment_length: int = 1 # number of sentences per segment
|
25
28
|
query: str = "" # query for relevance extraction
|
26
29
|
system_message = """
|
@@ -28,6 +31,7 @@ class RelevanceExtractorAgentConfig(ChatAgentConfig):
|
|
28
31
|
<#1#>, <#2#>, <#3#>, etc.,
|
29
32
|
followed by a QUERY. Extract ONLY the segment-numbers from
|
30
33
|
the PASSAGE that are RELEVANT to the QUERY.
|
34
|
+
Present the extracted segment-numbers using the `extract_segments` tool/function.
|
31
35
|
"""
|
32
36
|
|
33
37
|
|
@@ -101,11 +105,23 @@ class RelevanceExtractorAgent(ChatAgent):
|
|
101
105
|
"""Method to handle a segmentExtractTool message from LLM"""
|
102
106
|
spec = msg.segment_list
|
103
107
|
if len(self.message_history) == 0:
|
104
|
-
return NO_ANSWER
|
105
|
-
if spec is None or spec.strip()
|
106
|
-
return NO_ANSWER
|
108
|
+
return DONE + " " + NO_ANSWER
|
109
|
+
if spec is None or spec.strip() in ["", NO_ANSWER]:
|
110
|
+
return DONE + " " + NO_ANSWER
|
107
111
|
assert self.numbered_passage is not None, "No numbered passage"
|
108
112
|
# assume this has numbered segments
|
109
|
-
|
113
|
+
try:
|
114
|
+
extracts = extract_numbered_segments(self.numbered_passage, spec)
|
115
|
+
except Exception:
|
116
|
+
return DONE + " " + NO_ANSWER
|
110
117
|
# this response ends the task by saying DONE
|
111
|
-
return
|
118
|
+
return DONE + " " + extracts
|
119
|
+
|
120
|
+
def handle_message_fallback(
|
121
|
+
self, msg: str | ChatDocument
|
122
|
+
) -> str | ChatDocument | None:
|
123
|
+
"""Handle case where LLM forgets to use SegmentExtractTool"""
|
124
|
+
if isinstance(msg, ChatDocument) and msg.metadata.sender == Entity.LLM:
|
125
|
+
return DONE + " " + NO_ANSWER
|
126
|
+
else:
|
127
|
+
return None
|
@@ -1,201 +1,56 @@
|
|
1
1
|
"""
|
2
|
-
|
2
|
+
Deprecated: use DocChatAgent instead, with DocChatAgentConfig.retrieve_only=True,
|
3
|
+
and if you want to retrieve FULL relevant doc-contents rather than just extracts,
|
4
|
+
then set DocChatAgentConfig.extraction_granularity=-1
|
5
|
+
|
6
|
+
This is an agent to retrieve relevant extracts from a vector store,
|
3
7
|
where the LLM is used to filter for "true" relevance after retrieval from the
|
4
8
|
vector store.
|
9
|
+
This is essentially the same as DocChatAgent, except that instead of
|
10
|
+
generating final summary answer based on relevant extracts, it just returns
|
11
|
+
those extracts.
|
5
12
|
See test_retriever_agent.py for example usage.
|
6
13
|
"""
|
14
|
+
|
7
15
|
import logging
|
8
|
-
from
|
9
|
-
from typing import List, Optional, Sequence
|
16
|
+
from typing import Sequence
|
10
17
|
|
11
|
-
from rich import print
|
12
18
|
from rich.console import Console
|
13
19
|
|
14
|
-
from langroid.agent.chat_document import ChatDocMetaData, ChatDocument
|
15
20
|
from langroid.agent.special.doc_chat_agent import DocChatAgent, DocChatAgentConfig
|
16
|
-
from langroid.
|
17
|
-
from langroid.language_models.base import StreamingIfAllowed
|
18
|
-
from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig
|
19
|
-
from langroid.mytypes import DocMetaData, Document, Entity
|
20
|
-
from langroid.parsing.parser import ParsingConfig, Splitter
|
21
|
-
from langroid.prompts.prompts_config import PromptsConfig
|
22
|
-
from langroid.utils.constants import NO_ANSWER
|
23
|
-
from langroid.vector_store.base import VectorStoreConfig
|
24
|
-
from langroid.vector_store.qdrantdb import QdrantDBConfig
|
21
|
+
from langroid.mytypes import DocMetaData, Document
|
25
22
|
|
26
23
|
console = Console()
|
27
24
|
logger = logging.getLogger(__name__)
|
28
25
|
|
26
|
+
# for backwards compatibility:
|
27
|
+
RecordMetadata = DocMetaData
|
28
|
+
RecordDoc = Document
|
29
|
+
RetrieverAgentConfig = DocChatAgentConfig
|
29
30
|
|
30
|
-
class RecordMetadata(DocMetaData):
|
31
|
-
id: None | str = None
|
32
|
-
|
33
|
-
|
34
|
-
class RecordDoc(Document):
|
35
|
-
metadata: RecordMetadata
|
36
|
-
|
37
|
-
|
38
|
-
class RetrieverAgentConfig(DocChatAgentConfig):
|
39
|
-
n_matches: int = 3
|
40
|
-
debug: bool = False
|
41
|
-
max_context_tokens = 500
|
42
|
-
conversation_mode = True
|
43
|
-
cache: bool = True # cache results
|
44
|
-
gpt4: bool = True # use GPT-4
|
45
|
-
stream: bool = True # allow streaming where needed
|
46
|
-
max_tokens: int = 10000
|
47
|
-
vecdb: VectorStoreConfig = QdrantDBConfig(
|
48
|
-
collection_name=None,
|
49
|
-
storage_path=".qdrant/data/",
|
50
|
-
embedding=OpenAIEmbeddingsConfig(
|
51
|
-
model_type="openai",
|
52
|
-
model_name="text-embedding-ada-002",
|
53
|
-
dims=1536,
|
54
|
-
),
|
55
|
-
)
|
56
|
-
|
57
|
-
llm: OpenAIGPTConfig = OpenAIGPTConfig(
|
58
|
-
type="openai",
|
59
|
-
chat_model=OpenAIChatModel.GPT4,
|
60
|
-
)
|
61
|
-
parsing: ParsingConfig = ParsingConfig(
|
62
|
-
splitter=Splitter.TOKENS,
|
63
|
-
chunk_size=100,
|
64
|
-
n_similar_docs=5,
|
65
|
-
)
|
66
|
-
|
67
|
-
prompts: PromptsConfig = PromptsConfig(
|
68
|
-
max_tokens=1000,
|
69
|
-
)
|
70
31
|
|
71
|
-
|
72
|
-
class RetrieverAgent(DocChatAgent, ABC):
|
32
|
+
class RetrieverAgent(DocChatAgent):
|
73
33
|
"""
|
74
|
-
Agent for retrieving
|
34
|
+
Agent for just retrieving chunks/docs/extracts matching a query
|
75
35
|
"""
|
76
36
|
|
77
|
-
def __init__(self, config:
|
37
|
+
def __init__(self, config: DocChatAgentConfig):
|
78
38
|
super().__init__(config)
|
79
|
-
self.config:
|
39
|
+
self.config: DocChatAgentConfig = config
|
40
|
+
logger.warning(
|
41
|
+
"""
|
42
|
+
`RetrieverAgent` is deprecated. Use `DocChatAgent` instead, with
|
43
|
+
`DocChatAgentConfig.retrieve_only=True`, and if you want to retrieve
|
44
|
+
FULL relevant doc-contents rather than just extracts, then set
|
45
|
+
`DocChatAgentConfig.extraction_granularity=-1`
|
46
|
+
"""
|
47
|
+
)
|
80
48
|
|
81
|
-
|
82
|
-
|
83
|
-
pass
|
49
|
+
def get_records(self) -> Sequence[Document]:
|
50
|
+
raise NotImplementedError
|
84
51
|
|
85
52
|
def ingest(self) -> None:
|
86
53
|
records = self.get_records()
|
87
54
|
if self.vecdb is None:
|
88
55
|
raise ValueError("No vector store specified")
|
89
56
|
self.vecdb.add_documents(records)
|
90
|
-
|
91
|
-
def llm_response(
|
92
|
-
self,
|
93
|
-
query: None | str | ChatDocument = None,
|
94
|
-
) -> Optional[ChatDocument]:
|
95
|
-
if not self.llm_can_respond(query):
|
96
|
-
return None
|
97
|
-
if query is None:
|
98
|
-
return super().llm_response(None) # type: ignore
|
99
|
-
if isinstance(query, ChatDocument):
|
100
|
-
query_str = query.content
|
101
|
-
else:
|
102
|
-
query_str = query
|
103
|
-
docs = self.get_relevant_extracts(query_str)
|
104
|
-
if len(docs) == 0:
|
105
|
-
return None
|
106
|
-
content = "\n\n".join([d.content for d in docs])
|
107
|
-
print(f"[green]{content}")
|
108
|
-
meta = dict(
|
109
|
-
sender=Entity.LLM,
|
110
|
-
)
|
111
|
-
meta.update(docs[0].metadata)
|
112
|
-
|
113
|
-
return ChatDocument(
|
114
|
-
content=content,
|
115
|
-
metadata=ChatDocMetaData(**meta),
|
116
|
-
)
|
117
|
-
|
118
|
-
def get_relevant_extracts(self, query: str) -> List[Document]:
|
119
|
-
"""
|
120
|
-
Given a query, get the records/docs whose contents are most relevant to the
|
121
|
-
query. First get nearest docs from vector store, then select the best
|
122
|
-
matches according to the LLM.
|
123
|
-
Args:
|
124
|
-
query (str): query string
|
125
|
-
|
126
|
-
Returns:
|
127
|
-
List[Document]: list of Document objects
|
128
|
-
"""
|
129
|
-
response = Document(
|
130
|
-
content=NO_ANSWER,
|
131
|
-
metadata=DocMetaData(
|
132
|
-
source="None",
|
133
|
-
),
|
134
|
-
)
|
135
|
-
nearest_docs = self.get_relevant_chunks(query)
|
136
|
-
if len(nearest_docs) == 0:
|
137
|
-
return [response]
|
138
|
-
if self.llm is None:
|
139
|
-
logger.warning("No LLM specified")
|
140
|
-
return nearest_docs
|
141
|
-
with console.status("LLM selecting relevant docs from retrieved ones..."):
|
142
|
-
with StreamingIfAllowed(self.llm, False):
|
143
|
-
doc_list = self.llm_select_relevant_docs(query, nearest_docs)
|
144
|
-
|
145
|
-
return doc_list
|
146
|
-
|
147
|
-
def llm_select_relevant_docs(
|
148
|
-
self, query: str, docs: List[Document]
|
149
|
-
) -> List[Document]:
|
150
|
-
"""
|
151
|
-
Given a query and a list of docs, select the docs whose contents match best,
|
152
|
-
according to the LLM. Use the doc IDs to select the docs from the vector
|
153
|
-
store.
|
154
|
-
Args:
|
155
|
-
query: query string
|
156
|
-
docs: list of Document objects
|
157
|
-
Returns:
|
158
|
-
list of Document objects
|
159
|
-
"""
|
160
|
-
doc_contents = "\n\n".join(
|
161
|
-
[f"DOC: ID={d.id()}, CONTENT: {d.content}" for d in docs]
|
162
|
-
)
|
163
|
-
prompt = f"""
|
164
|
-
Given the following QUERY:
|
165
|
-
{query}
|
166
|
-
and the following DOCS with IDs and contents
|
167
|
-
{doc_contents}
|
168
|
-
|
169
|
-
Find at most {self.config.n_matches} DOCs that are most relevant to the QUERY.
|
170
|
-
Return your answer as a sequence of DOC IDS ONLY, for example:
|
171
|
-
"id1 id2 id3..."
|
172
|
-
If there are no relevant docs, simply say {NO_ANSWER}.
|
173
|
-
Even if there is only one relevant doc, return it as a single ID.
|
174
|
-
Do not give any explanations or justifications.
|
175
|
-
"""
|
176
|
-
default_response = Document(
|
177
|
-
content=NO_ANSWER,
|
178
|
-
metadata=DocMetaData(
|
179
|
-
source="None",
|
180
|
-
),
|
181
|
-
)
|
182
|
-
|
183
|
-
if self.llm is None:
|
184
|
-
logger.warning("No LLM specified")
|
185
|
-
return [default_response]
|
186
|
-
response = self.llm.generate(
|
187
|
-
prompt, max_tokens=self.config.llm.max_output_tokens
|
188
|
-
)
|
189
|
-
if response.message == NO_ANSWER:
|
190
|
-
return [default_response]
|
191
|
-
ids = response.message.split()
|
192
|
-
if len(ids) == 0:
|
193
|
-
return [default_response]
|
194
|
-
if self.vecdb is None:
|
195
|
-
logger.warning("No vector store specified")
|
196
|
-
return [default_response]
|
197
|
-
docs = self.vecdb.get_documents_by_ids(ids)
|
198
|
-
return [
|
199
|
-
Document(content=d.content, metadata=DocMetaData(source="LLM"))
|
200
|
-
for d in docs
|
201
|
-
]
|
@@ -6,12 +6,13 @@ Functionality includes:
|
|
6
6
|
- adding table and column context
|
7
7
|
- asking a question about a SQL schema
|
8
8
|
"""
|
9
|
+
|
9
10
|
import logging
|
10
|
-
from typing import Any, Dict, Optional, Sequence, Union
|
11
|
+
from typing import Any, Dict, List, Optional, Sequence, Union
|
11
12
|
|
12
13
|
from rich import print
|
13
14
|
from rich.console import Console
|
14
|
-
from sqlalchemy import MetaData, Row, create_engine, text
|
15
|
+
from sqlalchemy import MetaData, Row, create_engine, inspect, text
|
15
16
|
from sqlalchemy.engine import Engine
|
16
17
|
from sqlalchemy.exc import SQLAlchemyError
|
17
18
|
from sqlalchemy.orm import Session, sessionmaker
|
@@ -35,9 +36,7 @@ from langroid.agent.special.sql.utils.tools import (
|
|
35
36
|
GetTableSchemaTool,
|
36
37
|
RunQueryTool,
|
37
38
|
)
|
38
|
-
from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig
|
39
39
|
from langroid.mytypes import Entity
|
40
|
-
from langroid.prompts.prompts_config import PromptsConfig
|
41
40
|
from langroid.vector_store.base import VectorStoreConfig
|
42
41
|
|
43
42
|
logger = logging.getLogger(__name__)
|
@@ -67,7 +66,6 @@ SQL_ERROR_MSG = "There was an error in your SQL Query"
|
|
67
66
|
class SQLChatAgentConfig(ChatAgentConfig):
|
68
67
|
system_message: str = DEFAULT_SQL_CHAT_SYSTEM_MESSAGE
|
69
68
|
user_message: None | str = None
|
70
|
-
max_context_tokens: int = 1000
|
71
69
|
cache: bool = True # cache results
|
72
70
|
debug: bool = False
|
73
71
|
stream: bool = True # allow streaming where needed
|
@@ -76,6 +74,7 @@ class SQLChatAgentConfig(ChatAgentConfig):
|
|
76
74
|
vecdb: None | VectorStoreConfig = None
|
77
75
|
context_descriptions: Dict[str, Dict[str, Union[str, Dict[str, str]]]] = {}
|
78
76
|
use_schema_tools: bool = False
|
77
|
+
multi_schema: bool = False
|
79
78
|
|
80
79
|
"""
|
81
80
|
Optional, but strongly recommended, context descriptions for tables, columns,
|
@@ -90,6 +89,9 @@ class SQLChatAgentConfig(ChatAgentConfig):
|
|
90
89
|
is another table name and the value is a description of the relationship to
|
91
90
|
that table.
|
92
91
|
|
92
|
+
If multi_schema support is enabled, the tables names in the description
|
93
|
+
should be of the form 'schema_name.table_name'.
|
94
|
+
|
93
95
|
For example:
|
94
96
|
{
|
95
97
|
'table1': {
|
@@ -109,15 +111,6 @@ class SQLChatAgentConfig(ChatAgentConfig):
|
|
109
111
|
}
|
110
112
|
"""
|
111
113
|
|
112
|
-
llm: OpenAIGPTConfig = OpenAIGPTConfig(
|
113
|
-
type="openai",
|
114
|
-
chat_model=OpenAIChatModel.GPT4,
|
115
|
-
completion_model=OpenAIChatModel.GPT4,
|
116
|
-
)
|
117
|
-
prompts: PromptsConfig = PromptsConfig(
|
118
|
-
max_tokens=1000,
|
119
|
-
)
|
120
|
-
|
121
114
|
|
122
115
|
class SQLChatAgent(ChatAgent):
|
123
116
|
"""
|
@@ -155,19 +148,44 @@ class SQLChatAgent(ChatAgent):
|
|
155
148
|
"""Initialize the database metadata."""
|
156
149
|
if self.engine is None:
|
157
150
|
raise ValueError("Database engine is None")
|
151
|
+
self.metadata: MetaData | List[MetaData] = []
|
158
152
|
|
159
|
-
self.
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
153
|
+
if self.config.multi_schema:
|
154
|
+
logger.info(
|
155
|
+
"Initializing SQLChatAgent with database: %s",
|
156
|
+
self.engine,
|
157
|
+
)
|
158
|
+
|
159
|
+
self.metadata = []
|
160
|
+
inspector = inspect(self.engine)
|
161
|
+
|
162
|
+
for schema in inspector.get_schema_names():
|
163
|
+
metadata = MetaData(schema=schema)
|
164
|
+
metadata.reflect(self.engine)
|
165
|
+
self.metadata.append(metadata)
|
166
|
+
|
167
|
+
logger.info(
|
168
|
+
"Initializing SQLChatAgent with database: %s, schema: %s, "
|
169
|
+
"and tables: %s",
|
170
|
+
self.engine,
|
171
|
+
schema,
|
172
|
+
metadata.tables,
|
173
|
+
)
|
174
|
+
else:
|
175
|
+
self.metadata = MetaData()
|
176
|
+
self.metadata.reflect(self.engine)
|
177
|
+
logger.info(
|
178
|
+
"SQLChatAgent initialized with database: %s and tables: %s",
|
179
|
+
self.engine,
|
180
|
+
self.metadata.tables,
|
181
|
+
)
|
166
182
|
|
167
183
|
def _init_table_metadata(self) -> None:
|
168
184
|
"""Initialize metadata for the tables present in the database."""
|
169
185
|
if not self.config.context_descriptions and isinstance(self.engine, Engine):
|
170
|
-
self.config.context_descriptions = extract_schema_descriptions(
|
186
|
+
self.config.context_descriptions = extract_schema_descriptions(
|
187
|
+
self.engine, self.config.multi_schema
|
188
|
+
)
|
171
189
|
|
172
190
|
if self.config.use_schema_tools:
|
173
191
|
self.table_metadata = populate_metadata_with_schema_tools(
|
@@ -228,8 +246,10 @@ class SQLChatAgent(ChatAgent):
|
|
228
246
|
if isinstance(msg, ChatDocument) and msg.function_call is not None:
|
229
247
|
sender_name = msg.function_call.name
|
230
248
|
|
249
|
+
content = results.content if isinstance(results, ChatDocument) else results
|
250
|
+
|
231
251
|
return ChatDocument(
|
232
|
-
content=
|
252
|
+
content=content,
|
233
253
|
metadata=ChatDocMetaData(
|
234
254
|
source=Entity.AGENT,
|
235
255
|
sender=Entity.AGENT,
|
@@ -329,6 +349,10 @@ class SQLChatAgent(ChatAgent):
|
|
329
349
|
Returns:
|
330
350
|
str: The names of all tables in the database.
|
331
351
|
"""
|
352
|
+
if isinstance(self.metadata, list):
|
353
|
+
table_names = [", ".join(md.tables.keys()) for md in self.metadata]
|
354
|
+
return ", ".join(table_names)
|
355
|
+
|
332
356
|
return ", ".join(self.metadata.tables.keys())
|
333
357
|
|
334
358
|
def get_table_schema(self, msg: GetTableSchemaTool) -> str:
|
@@ -9,3 +9,14 @@ from . import description_extractors
|
|
9
9
|
from . import populate_metadata
|
10
10
|
from . import system_message
|
11
11
|
from . import tools
|
12
|
+
|
13
|
+
__all__ = [
|
14
|
+
"RunQueryTool",
|
15
|
+
"GetTableNamesTool",
|
16
|
+
"GetTableSchemaTool",
|
17
|
+
"GetColumnDescriptionsTool",
|
18
|
+
"description_extractors",
|
19
|
+
"populate_metadata",
|
20
|
+
"system_message",
|
21
|
+
"tools",
|
22
|
+
]
|
@@ -1,10 +1,13 @@
|
|
1
|
-
from typing import Any, Dict, List
|
1
|
+
from typing import Any, Dict, List, Optional
|
2
2
|
|
3
3
|
from sqlalchemy import inspect, text
|
4
4
|
from sqlalchemy.engine import Engine
|
5
5
|
|
6
6
|
|
7
|
-
def extract_postgresql_descriptions(
|
7
|
+
def extract_postgresql_descriptions(
|
8
|
+
engine: Engine,
|
9
|
+
multi_schema: bool = False,
|
10
|
+
) -> Dict[str, Dict[str, Any]]:
|
8
11
|
"""
|
9
12
|
Extracts descriptions for tables and columns from a PostgreSQL database.
|
10
13
|
|
@@ -13,6 +16,7 @@ def extract_postgresql_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]
|
|
13
16
|
|
14
17
|
Args:
|
15
18
|
engine (Engine): SQLAlchemy engine connected to a PostgreSQL database.
|
19
|
+
multi_schema (bool): Generate descriptions for all schemas in the database.
|
16
20
|
|
17
21
|
Returns:
|
18
22
|
Dict[str, Dict[str, Any]]: A dictionary mapping table names to a
|
@@ -20,36 +24,53 @@ def extract_postgresql_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]
|
|
20
24
|
column descriptions.
|
21
25
|
"""
|
22
26
|
inspector = inspect(engine)
|
23
|
-
table_names: List[str] = inspector.get_table_names()
|
24
|
-
|
25
27
|
result: Dict[str, Dict[str, Any]] = {}
|
26
28
|
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
29
|
+
def gen_schema_descriptions(schema: Optional[str] = None) -> None:
|
30
|
+
table_names: List[str] = inspector.get_table_names(schema=schema)
|
31
|
+
with engine.connect() as conn:
|
32
|
+
for table in table_names:
|
33
|
+
if schema is None:
|
34
|
+
table_name = table
|
35
|
+
else:
|
36
|
+
table_name = f"{schema}.{table}"
|
35
37
|
|
36
|
-
|
37
|
-
col_data = inspector.get_columns(table)
|
38
|
-
for idx, col in enumerate(col_data, start=1):
|
39
|
-
col_comment = (
|
38
|
+
table_comment = (
|
40
39
|
conn.execute(
|
41
|
-
text(f"SELECT
|
40
|
+
text(f"SELECT obj_description('{table_name}'::regclass)")
|
42
41
|
).scalar()
|
43
42
|
or ""
|
44
43
|
)
|
45
|
-
columns[col["name"]] = col_comment
|
46
44
|
|
47
|
-
|
45
|
+
columns = {}
|
46
|
+
col_data = inspector.get_columns(table, schema=schema)
|
47
|
+
for idx, col in enumerate(col_data, start=1):
|
48
|
+
col_comment = (
|
49
|
+
conn.execute(
|
50
|
+
text(
|
51
|
+
f"SELECT col_description('{table_name}'::regclass, "
|
52
|
+
f"{idx})"
|
53
|
+
)
|
54
|
+
).scalar()
|
55
|
+
or ""
|
56
|
+
)
|
57
|
+
columns[col["name"]] = col_comment
|
58
|
+
|
59
|
+
result[table_name] = {"description": table_comment, "columns": columns}
|
60
|
+
|
61
|
+
if multi_schema:
|
62
|
+
for schema in inspector.get_schema_names():
|
63
|
+
gen_schema_descriptions(schema)
|
64
|
+
else:
|
65
|
+
gen_schema_descriptions()
|
48
66
|
|
49
67
|
return result
|
50
68
|
|
51
69
|
|
52
|
-
def extract_mysql_descriptions(
|
70
|
+
def extract_mysql_descriptions(
|
71
|
+
engine: Engine,
|
72
|
+
multi_schema: bool = False,
|
73
|
+
) -> Dict[str, Dict[str, Any]]:
|
53
74
|
"""Extracts descriptions for tables and columns from a MySQL database.
|
54
75
|
|
55
76
|
This method retrieves the descriptions of tables and their columns
|
@@ -57,6 +78,7 @@ def extract_mysql_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
|
|
57
78
|
|
58
79
|
Args:
|
59
80
|
engine (Engine): SQLAlchemy engine connected to a MySQL database.
|
81
|
+
multi_schema (bool): Generate descriptions for all schemas in the database.
|
60
82
|
|
61
83
|
Returns:
|
62
84
|
Dict[str, Dict[str, Any]]: A dictionary mapping table names to a
|
@@ -64,31 +86,45 @@ def extract_mysql_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
|
|
64
86
|
column descriptions.
|
65
87
|
"""
|
66
88
|
inspector = inspect(engine)
|
67
|
-
table_names: List[str] = inspector.get_table_names()
|
68
|
-
|
69
89
|
result: Dict[str, Dict[str, Any]] = {}
|
70
90
|
|
71
|
-
|
72
|
-
|
73
|
-
query = text(
|
74
|
-
"SELECT table_comment FROM information_schema.tables WHERE"
|
75
|
-
" table_schema = :schema AND table_name = :table"
|
76
|
-
)
|
77
|
-
table_result = conn.execute(
|
78
|
-
query, {"schema": engine.url.database, "table": table}
|
79
|
-
)
|
80
|
-
table_comment = table_result.scalar() or ""
|
91
|
+
def gen_schema_descriptions(schema: Optional[str] = None) -> None:
|
92
|
+
table_names: List[str] = inspector.get_table_names(schema=schema)
|
81
93
|
|
82
|
-
|
83
|
-
for
|
84
|
-
|
94
|
+
with engine.connect() as conn:
|
95
|
+
for table in table_names:
|
96
|
+
if schema is None:
|
97
|
+
table_name = table
|
98
|
+
else:
|
99
|
+
table_name = f"{schema}.{table}"
|
100
|
+
|
101
|
+
query = text(
|
102
|
+
"SELECT table_comment FROM information_schema.tables WHERE"
|
103
|
+
" table_schema = :schema AND table_name = :table"
|
104
|
+
)
|
105
|
+
table_result = conn.execute(
|
106
|
+
query, {"schema": engine.url.database, "table": table_name}
|
107
|
+
)
|
108
|
+
table_comment = table_result.scalar() or ""
|
109
|
+
|
110
|
+
columns = {}
|
111
|
+
for col in inspector.get_columns(table, schema=schema):
|
112
|
+
columns[col["name"]] = col.get("comment", "")
|
113
|
+
|
114
|
+
result[table_name] = {"description": table_comment, "columns": columns}
|
85
115
|
|
86
|
-
|
116
|
+
if multi_schema:
|
117
|
+
for schema in inspector.get_schema_names():
|
118
|
+
gen_schema_descriptions(schema)
|
119
|
+
else:
|
120
|
+
gen_schema_descriptions()
|
87
121
|
|
88
122
|
return result
|
89
123
|
|
90
124
|
|
91
|
-
def extract_default_descriptions(
|
125
|
+
def extract_default_descriptions(
|
126
|
+
engine: Engine, multi_schema: bool = False
|
127
|
+
) -> Dict[str, Dict[str, Any]]:
|
92
128
|
"""Extracts default descriptions for tables and columns from a database.
|
93
129
|
|
94
130
|
This method retrieves the table and column names from the given database
|
@@ -96,6 +132,7 @@ def extract_default_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
|
|
96
132
|
|
97
133
|
Args:
|
98
134
|
engine (Engine): SQLAlchemy engine connected to a database.
|
135
|
+
multi_schema (bool): Generate descriptions for all schemas in the database.
|
99
136
|
|
100
137
|
Returns:
|
101
138
|
Dict[str, Dict[str, Any]]: A dictionary mapping table names to a
|
@@ -103,26 +140,36 @@ def extract_default_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
|
|
103
140
|
empty column descriptions.
|
104
141
|
"""
|
105
142
|
inspector = inspect(engine)
|
106
|
-
table_names: List[str] = inspector.get_table_names()
|
107
|
-
|
108
143
|
result: Dict[str, Dict[str, Any]] = {}
|
109
144
|
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
145
|
+
def gen_schema_descriptions(schema: Optional[str] = None) -> None:
|
146
|
+
table_names: List[str] = inspector.get_table_names(schema=schema)
|
147
|
+
|
148
|
+
for table in table_names:
|
149
|
+
columns = {}
|
150
|
+
for col in inspector.get_columns(table):
|
151
|
+
columns[col["name"]] = ""
|
152
|
+
|
153
|
+
result[table] = {"description": "", "columns": columns}
|
114
154
|
|
115
|
-
|
155
|
+
if multi_schema:
|
156
|
+
for schema in inspector.get_schema_names():
|
157
|
+
gen_schema_descriptions(schema)
|
158
|
+
else:
|
159
|
+
gen_schema_descriptions()
|
116
160
|
|
117
161
|
return result
|
118
162
|
|
119
163
|
|
120
|
-
def extract_schema_descriptions(
|
164
|
+
def extract_schema_descriptions(
|
165
|
+
engine: Engine, multi_schema: bool = False
|
166
|
+
) -> Dict[str, Dict[str, Any]]:
|
121
167
|
"""
|
122
168
|
Extracts the schema descriptions from the database connected to by the engine.
|
123
169
|
|
124
170
|
Args:
|
125
171
|
engine (Engine): SQLAlchemy engine instance.
|
172
|
+
multi_schema (bool): Generate descriptions for all schemas in the database.
|
126
173
|
|
127
174
|
Returns:
|
128
175
|
Dict[str, Dict[str, Any]]: A dictionary representation of table and column
|
@@ -133,4 +180,6 @@ def extract_schema_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
|
|
133
180
|
"postgresql": extract_postgresql_descriptions,
|
134
181
|
"mysql": extract_mysql_descriptions,
|
135
182
|
}
|
136
|
-
return extractors.get(engine.dialect.name, extract_default_descriptions)(
|
183
|
+
return extractors.get(engine.dialect.name, extract_default_descriptions)(
|
184
|
+
engine, multi_schema=multi_schema
|
185
|
+
)
|