langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +95 -0
- langroid/agent/__init__.py +40 -0
- langroid/agent/base.py +222 -91
- langroid/agent/batch.py +264 -0
- langroid/agent/callbacks/chainlit.py +608 -0
- langroid/agent/chat_agent.py +247 -101
- langroid/agent/chat_document.py +41 -4
- langroid/agent/openai_assistant.py +842 -0
- langroid/agent/special/__init__.py +50 -0
- langroid/agent/special/doc_chat_agent.py +837 -141
- langroid/agent/special/lance_doc_chat_agent.py +258 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +136 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
- langroid/agent/special/lance_tools.py +44 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
- langroid/agent/special/neo4j/utils/__init__.py +0 -0
- langroid/agent/special/neo4j/utils/system_message.py +46 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +32 -198
- langroid/agent/special/sql/__init__.py +11 -0
- langroid/agent/special/sql/sql_chat_agent.py +47 -23
- langroid/agent/special/sql/utils/__init__.py +22 -0
- langroid/agent/special/sql/utils/description_extractors.py +95 -46
- langroid/agent/special/sql/utils/populate_metadata.py +28 -21
- langroid/agent/special/table_chat_agent.py +43 -9
- langroid/agent/task.py +475 -122
- langroid/agent/tool_message.py +75 -13
- langroid/agent/tools/__init__.py +13 -0
- langroid/agent/tools/duckduckgo_search_tool.py +66 -0
- langroid/agent/tools/google_search_tool.py +11 -0
- langroid/agent/tools/metaphor_search_tool.py +67 -0
- langroid/agent/tools/recipient_tool.py +16 -29
- langroid/agent/tools/run_python_code.py +60 -0
- langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
- langroid/agent/tools/segment_extract_tool.py +36 -0
- langroid/cachedb/__init__.py +9 -0
- langroid/cachedb/base.py +22 -2
- langroid/cachedb/momento_cachedb.py +26 -2
- langroid/cachedb/redis_cachedb.py +78 -11
- langroid/embedding_models/__init__.py +34 -0
- langroid/embedding_models/base.py +21 -2
- langroid/embedding_models/models.py +120 -18
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/language_models/__init__.py +45 -0
- langroid/language_models/azure_openai.py +80 -27
- langroid/language_models/base.py +117 -12
- langroid/language_models/config.py +5 -0
- langroid/language_models/openai_assistants.py +3 -0
- langroid/language_models/openai_gpt.py +558 -174
- langroid/language_models/prompt_formatter/__init__.py +15 -0
- langroid/language_models/prompt_formatter/base.py +4 -6
- langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
- langroid/language_models/utils.py +18 -21
- langroid/mytypes.py +25 -8
- langroid/parsing/__init__.py +46 -0
- langroid/parsing/document_parser.py +260 -63
- langroid/parsing/image_text.py +32 -0
- langroid/parsing/parse_json.py +143 -0
- langroid/parsing/parser.py +122 -59
- langroid/parsing/repo_loader.py +114 -52
- langroid/parsing/search.py +68 -63
- langroid/parsing/spider.py +3 -2
- langroid/parsing/table_loader.py +44 -0
- langroid/parsing/url_loader.py +59 -11
- langroid/parsing/urls.py +85 -37
- langroid/parsing/utils.py +298 -4
- langroid/parsing/web_search.py +73 -0
- langroid/prompts/__init__.py +11 -0
- langroid/prompts/chat-gpt4-system-prompt.md +68 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/utils/__init__.py +17 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +36 -5
- langroid/utils/constants.py +4 -0
- langroid/utils/globals.py +2 -2
- langroid/utils/logging.py +2 -5
- langroid/utils/output/__init__.py +21 -0
- langroid/utils/output/printing.py +47 -1
- langroid/utils/output/status.py +33 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +616 -2
- langroid/utils/system.py +98 -0
- langroid/vector_store/__init__.py +40 -0
- langroid/vector_store/base.py +203 -6
- langroid/vector_store/chromadb.py +59 -32
- langroid/vector_store/lancedb.py +463 -0
- langroid/vector_store/meilisearch.py +10 -7
- langroid/vector_store/momento.py +262 -0
- langroid/vector_store/qdrantdb.py +104 -22
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
- langroid-0.1.219.dist-info/RECORD +127 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
- langroid/agent/special/recipient_validator_agent.py +0 -157
- langroid/parsing/json.py +0 -64
- langroid/utils/web/selenium_login.py +0 -36
- langroid-0.1.85.dist-info/RECORD +0 -94
- /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
@@ -0,0 +1,127 @@
|
|
1
|
+
"""
|
2
|
+
Agent to retrieve segments from a body of text,
|
3
|
+
that are relevant to a query.
|
4
|
+
|
5
|
+
"""
|
6
|
+
|
7
|
+
import logging
|
8
|
+
from typing import Optional, no_type_check
|
9
|
+
|
10
|
+
from rich.console import Console
|
11
|
+
|
12
|
+
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
|
13
|
+
from langroid.agent.chat_document import ChatDocument
|
14
|
+
from langroid.agent.tools.segment_extract_tool import SegmentExtractTool
|
15
|
+
from langroid.language_models.base import LLMConfig
|
16
|
+
from langroid.language_models.openai_gpt import OpenAIGPTConfig
|
17
|
+
from langroid.mytypes import Entity
|
18
|
+
from langroid.parsing.utils import extract_numbered_segments, number_segments
|
19
|
+
from langroid.utils.constants import DONE, NO_ANSWER
|
20
|
+
|
21
|
+
console = Console()
|
22
|
+
logger = logging.getLogger(__name__)
|
23
|
+
|
24
|
+
|
25
|
+
class RelevanceExtractorAgentConfig(ChatAgentConfig):
    """Configuration for `RelevanceExtractorAgent`.

    Adds the segment granularity and the query to the base chat-agent
    settings, and overrides the system message so that the LLM is told to
    answer with segment numbers via the `extract_segments` tool/function.
    """

    # LLM configuration; defaults to a default-constructed OpenAIGPTConfig.
    llm: LLMConfig | None = OpenAIGPTConfig()
    segment_length: int = 1  # number of sentences per segment
    query: str = ""  # query for relevance extraction
    # Annotated explicitly as `str`, like the other config subclasses that
    # override this field, so the override stays a properly typed field.
    # NOTE(review): the literal below is reproduced as rendered in the diff,
    # which strips leading whitespace -- confirm no indentation is required.
    system_message: str = """
The user will give you a PASSAGE containing segments numbered as
<#1#>, <#2#>, <#3#>, etc.,
followed by a QUERY. Extract ONLY the segment-numbers from
the PASSAGE that are RELEVANT to the QUERY.
Present the extracted segment-numbers using the `extract_segments` tool/function.
"""
|
36
|
+
|
37
|
+
|
38
|
+
class RelevanceExtractorAgent(ChatAgent):
    """
    Agent for extracting segments from text that are relevant to a given query.

    Flow, as implemented below:
    - `llm_response` / `llm_response_async` number the incoming passage with
      `number_segments`, remember the numbered passage on the instance, and
      send a PASSAGE + QUERY prompt to the LLM.
    - `extract_segments` handles the LLM's `SegmentExtractTool` message by
      passing its `segment_list` and the remembered numbered passage to
      `extract_numbered_segments`.
    - `handle_message_fallback` covers an LLM reply that did not use the tool.
    """

    def __init__(self, config: RelevanceExtractorAgentConfig):
        super().__init__(config)
        self.config: RelevanceExtractorAgentConfig = config
        self.enable_message(SegmentExtractTool)
        # Numbered form of the most recently received passage; populated by
        # llm_response / llm_response_async and read by extract_segments.
        self.numbered_passage: Optional[str] = None

    @no_type_check
    def _compose_prompt(self, message: Optional[str | ChatDocument]) -> str:
        """Number the passage in `message` and build the extraction prompt.

        Shared by the sync and async response methods so both send the same
        prompt. Stores the numbered passage in `self.numbered_passage`.

        Args:
            message: the passage, as a string or as a ChatDocument whose
                `content` holds the passage.

        Returns:
            The prompt containing the numbered passage and the configured query.

        Raises:
            AssertionError: if the configured query or `message` is None.
        """
        assert self.config.query is not None, "No query specified"
        assert message is not None, "No message specified"
        message_str = message.content if isinstance(message, ChatDocument) else message
        # number the segments in the passage
        self.numbered_passage = number_segments(message_str, self.config.segment_length)
        # compose prompt
        # NOTE(review): text reproduced as rendered in the diff, which strips
        # leading whitespace -- confirm no indentation is required here.
        return (
            "\n"
            "PASSAGE:\n"
            f"{self.numbered_passage}\n"
            "\n"
            f"QUERY: {self.config.query}\n"
        )

    @no_type_check
    def llm_response(
        self, message: Optional[str | ChatDocument] = None
    ) -> Optional[ChatDocument]:
        """Compose a prompt asking to extract relevant segments from a passage.
        Steps:
        - number the segments in the passage
        - compose prompt
        - send to LLM
        """
        prompt = self._compose_prompt(message)
        # send to LLM
        return super().llm_response(prompt)

    @no_type_check
    async def llm_response_async(
        self, message: Optional[str | ChatDocument] = None
    ) -> Optional[ChatDocument]:
        """
        Compose a prompt asking to extract relevant segments from a passage.
        Steps:
        - number the segments in the passage
        - compose prompt
        - send to LLM
        The LLM is expected to generate a structured msg according to the
        SegmentExtractTool schema, i.e. it should contain a `segment_list` field
        whose value is a list of segment numbers or ranges, like "10,12,14-17".
        """
        prompt = self._compose_prompt(message)
        # send to LLM
        return await super().llm_response_async(prompt)

    def extract_segments(self, msg: SegmentExtractTool) -> str:
        """Handle a SegmentExtractTool message from the LLM.

        Args:
            msg: tool message whose `segment_list` specifies the segments.

        Returns:
            DONE followed by the extracted text, or DONE followed by
            NO_ANSWER when there is no message history, the spec is
            empty/NO_ANSWER, or the extraction raises.
        """
        spec = msg.segment_list
        no_answer = DONE + " " + NO_ANSWER
        if not self.message_history:
            return no_answer
        if spec is None or spec.strip() in ["", NO_ANSWER]:
            return no_answer
        assert self.numbered_passage is not None, "No numbered passage"
        # assume this has numbered segments
        try:
            extracts = extract_numbered_segments(self.numbered_passage, spec)
        except Exception as exc:
            # Best-effort: still answer NO_ANSWER, but leave a trace of why
            # instead of discarding the error silently.
            logger.warning("Failed to extract segments for spec %r: %s", spec, exc)
            return no_answer
        # this response ends the task by saying DONE
        return DONE + " " + extracts

    def handle_message_fallback(
        self, msg: str | ChatDocument
    ) -> str | ChatDocument | None:
        """Handle case where LLM forgets to use SegmentExtractTool.

        Returns DONE + NO_ANSWER when `msg` is a ChatDocument sent by the LLM;
        returns None for any other message.
        """
        if isinstance(msg, ChatDocument) and msg.metadata.sender == Entity.LLM:
            return DONE + " " + NO_ANSWER
        return None
|
@@ -1,222 +1,56 @@
|
|
1
1
|
"""
|
2
|
-
|
3
|
-
|
2
|
+
Deprecated: use DocChatAgent instead, with DocChatAgentConfig.retrieve_only=True,
|
3
|
+
and if you want to retrieve FULL relevant doc-contents rather than just extracts,
|
4
|
+
then set DocChatAgentConfig.extraction_granularity=-1
|
5
|
+
|
6
|
+
This is an agent to retrieve relevant extracts from a vector store,
|
7
|
+
where the LLM is used to filter for "true" relevance after retrieval from the
|
8
|
+
vector store.
|
9
|
+
This is essentially the same as DocChatAgent, except that instead of
|
10
|
+
generating a final summary answer based on relevant extracts, it just returns
|
11
|
+
those extracts.
|
12
|
+
See test_retriever_agent.py for example usage.
|
4
13
|
"""
|
14
|
+
|
5
15
|
import logging
|
6
|
-
from
|
7
|
-
from typing import List, Optional, Sequence
|
16
|
+
from typing import Sequence
|
8
17
|
|
9
|
-
from rich import print
|
10
18
|
from rich.console import Console
|
11
19
|
|
12
|
-
from langroid.agent.chat_document import ChatDocMetaData, ChatDocument
|
13
20
|
from langroid.agent.special.doc_chat_agent import DocChatAgent, DocChatAgentConfig
|
14
|
-
from langroid.
|
15
|
-
from langroid.language_models.base import StreamingIfAllowed
|
16
|
-
from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig
|
17
|
-
from langroid.mytypes import DocMetaData, Document, Entity
|
18
|
-
from langroid.parsing.parser import ParsingConfig, Splitter
|
19
|
-
from langroid.prompts.prompts_config import PromptsConfig
|
20
|
-
from langroid.utils.constants import NO_ANSWER
|
21
|
-
from langroid.vector_store.base import VectorStoreConfig
|
22
|
-
from langroid.vector_store.qdrantdb import QdrantDBConfig
|
21
|
+
from langroid.mytypes import DocMetaData, Document
|
23
22
|
|
24
23
|
console = Console()
|
25
24
|
logger = logging.getLogger(__name__)
|
26
25
|
|
26
|
+
# for backwards compatibility:
|
27
|
+
RecordMetadata = DocMetaData
|
28
|
+
RecordDoc = Document
|
29
|
+
RetrieverAgentConfig = DocChatAgentConfig
|
27
30
|
|
28
|
-
class RecordMetadata(DocMetaData):
|
29
|
-
id: None | int | str = None
|
30
|
-
|
31
|
-
|
32
|
-
class RecordDoc(Document):
|
33
|
-
metadata: RecordMetadata
|
34
|
-
|
35
|
-
|
36
|
-
class RetrieverAgentConfig(DocChatAgentConfig):
|
37
|
-
n_matches: int = 3
|
38
|
-
debug: bool = False
|
39
|
-
max_context_tokens = 500
|
40
|
-
conversation_mode = True
|
41
|
-
cache: bool = True # cache results
|
42
|
-
gpt4: bool = True # use GPT-4
|
43
|
-
stream: bool = True # allow streaming where needed
|
44
|
-
max_tokens: int = 10000
|
45
|
-
vecdb: VectorStoreConfig = QdrantDBConfig(
|
46
|
-
collection_name=None,
|
47
|
-
storage_path=".qdrant/data/",
|
48
|
-
embedding=OpenAIEmbeddingsConfig(
|
49
|
-
model_type="openai",
|
50
|
-
model_name="text-embedding-ada-002",
|
51
|
-
dims=1536,
|
52
|
-
),
|
53
|
-
)
|
54
31
|
|
55
|
-
|
56
|
-
type="openai",
|
57
|
-
chat_model=OpenAIChatModel.GPT4,
|
58
|
-
)
|
59
|
-
parsing: ParsingConfig = ParsingConfig(
|
60
|
-
splitter=Splitter.TOKENS,
|
61
|
-
chunk_size=100,
|
62
|
-
n_similar_docs=5,
|
63
|
-
)
|
64
|
-
|
65
|
-
prompts: PromptsConfig = PromptsConfig(
|
66
|
-
max_tokens=1000,
|
67
|
-
)
|
68
|
-
|
69
|
-
|
70
|
-
class RetrieverAgent(DocChatAgent, ABC):
|
32
|
+
class RetrieverAgent(DocChatAgent):
|
71
33
|
"""
|
72
|
-
Agent for retrieving
|
34
|
+
Agent for just retrieving chunks/docs/extracts matching a query
|
73
35
|
"""
|
74
36
|
|
75
|
-
def __init__(self, config:
|
37
|
+
def __init__(self, config: DocChatAgentConfig):
|
76
38
|
super().__init__(config)
|
77
|
-
self.config:
|
39
|
+
self.config: DocChatAgentConfig = config
|
40
|
+
logger.warning(
|
41
|
+
"""
|
42
|
+
`RetrieverAgent` is deprecated. Use `DocChatAgent` instead, with
|
43
|
+
`DocChatAgentConfig.retrieve_only=True`, and if you want to retrieve
|
44
|
+
FULL relevant doc-contents rather than just extracts, then set
|
45
|
+
`DocChatAgentConfig.extraction_granularity=-1`
|
46
|
+
"""
|
47
|
+
)
|
78
48
|
|
79
|
-
|
80
|
-
|
81
|
-
pass
|
49
|
+
def get_records(self) -> Sequence[Document]:
|
50
|
+
raise NotImplementedError
|
82
51
|
|
83
52
|
def ingest(self) -> None:
|
84
53
|
records = self.get_records()
|
85
54
|
if self.vecdb is None:
|
86
55
|
raise ValueError("No vector store specified")
|
87
56
|
self.vecdb.add_documents(records)
|
88
|
-
|
89
|
-
def llm_response(
|
90
|
-
self,
|
91
|
-
query: None | str | ChatDocument = None,
|
92
|
-
) -> Optional[ChatDocument]:
|
93
|
-
if not self.llm_can_respond(query):
|
94
|
-
return None
|
95
|
-
if query is None:
|
96
|
-
return super().llm_response(None) # type: ignore
|
97
|
-
if isinstance(query, ChatDocument):
|
98
|
-
query_str = query.content
|
99
|
-
else:
|
100
|
-
query_str = query
|
101
|
-
docs = self.get_relevant_extracts(query_str)
|
102
|
-
if len(docs) == 0:
|
103
|
-
return None
|
104
|
-
content = "\n\n".join([d.content for d in docs])
|
105
|
-
print(f"[green]{content}")
|
106
|
-
meta = dict(
|
107
|
-
sender=Entity.LLM,
|
108
|
-
)
|
109
|
-
meta.update(docs[0].metadata)
|
110
|
-
|
111
|
-
return ChatDocument(
|
112
|
-
content=content,
|
113
|
-
metadata=ChatDocMetaData(**meta),
|
114
|
-
)
|
115
|
-
|
116
|
-
def get_nearest_docs(self, query: str) -> List[Document]:
|
117
|
-
"""
|
118
|
-
Given a query, get the records/docs whose contents are closest to the
|
119
|
-
query, in terms of vector similarity.
|
120
|
-
Args:
|
121
|
-
query: query string
|
122
|
-
Returns:
|
123
|
-
list of Document objects
|
124
|
-
"""
|
125
|
-
if self.vecdb is None:
|
126
|
-
logger.warning("No vector store specified")
|
127
|
-
return []
|
128
|
-
with console.status("[cyan]Searching VecDB for similar docs/records..."):
|
129
|
-
docs_and_scores = self.vecdb.similar_texts_with_scores(
|
130
|
-
query,
|
131
|
-
k=self.config.parsing.n_similar_docs,
|
132
|
-
)
|
133
|
-
docs: List[Document] = [
|
134
|
-
Document(content=d.content, metadata=d.metadata)
|
135
|
-
for (d, _) in docs_and_scores
|
136
|
-
]
|
137
|
-
return docs
|
138
|
-
|
139
|
-
def get_relevant_extracts(self, query: str) -> List[Document]:
|
140
|
-
"""
|
141
|
-
Given a query, get the records/docs whose contents are most relevant to the
|
142
|
-
query. First get nearest docs from vector store, then select the best
|
143
|
-
matches according to the LLM.
|
144
|
-
Args:
|
145
|
-
query (str): query string
|
146
|
-
|
147
|
-
Returns:
|
148
|
-
List[Document]: list of Document objects
|
149
|
-
"""
|
150
|
-
response = Document(
|
151
|
-
content=NO_ANSWER,
|
152
|
-
metadata=DocMetaData(
|
153
|
-
source="None",
|
154
|
-
),
|
155
|
-
)
|
156
|
-
nearest_docs = self.get_nearest_docs(query)
|
157
|
-
if len(nearest_docs) == 0:
|
158
|
-
return [response]
|
159
|
-
if self.llm is None:
|
160
|
-
logger.warning("No LLM specified")
|
161
|
-
return nearest_docs
|
162
|
-
with console.status("LLM selecting relevant docs from retrieved ones..."):
|
163
|
-
with StreamingIfAllowed(self.llm, False):
|
164
|
-
doc_list = self.llm_select_relevant_docs(query, nearest_docs)
|
165
|
-
|
166
|
-
return doc_list
|
167
|
-
|
168
|
-
def llm_select_relevant_docs(
|
169
|
-
self, query: str, docs: List[Document]
|
170
|
-
) -> List[Document]:
|
171
|
-
"""
|
172
|
-
Given a query and a list of docs, select the docs whose contents match best,
|
173
|
-
according to the LLM. Use the doc IDs to select the docs from the vector
|
174
|
-
store.
|
175
|
-
Args:
|
176
|
-
query: query string
|
177
|
-
docs: list of Document objects
|
178
|
-
Returns:
|
179
|
-
list of Document objects
|
180
|
-
"""
|
181
|
-
doc_contents = "\n\n".join(
|
182
|
-
[f"DOC: ID={d.id()}, CONTENT: {d.content}" for d in docs]
|
183
|
-
)
|
184
|
-
prompt = f"""
|
185
|
-
Given the following QUERY:
|
186
|
-
{query}
|
187
|
-
and the following DOCS with IDs and contents
|
188
|
-
{doc_contents}
|
189
|
-
|
190
|
-
Find at most {self.config.n_matches} DOCs that are most relevant to the QUERY.
|
191
|
-
Return your answer as a sequence of DOC IDS ONLY, for example:
|
192
|
-
"id1 id2 id3..."
|
193
|
-
If there are no relevant docs, simply say {NO_ANSWER}.
|
194
|
-
Even if there is only one relevant doc, return it as a single ID.
|
195
|
-
Do not give any explanations or justifications.
|
196
|
-
"""
|
197
|
-
default_response = Document(
|
198
|
-
content=NO_ANSWER,
|
199
|
-
metadata=DocMetaData(
|
200
|
-
source="None",
|
201
|
-
),
|
202
|
-
)
|
203
|
-
|
204
|
-
if self.llm is None:
|
205
|
-
logger.warning("No LLM specified")
|
206
|
-
return [default_response]
|
207
|
-
response = self.llm.generate(
|
208
|
-
prompt, max_tokens=self.config.llm.max_output_tokens
|
209
|
-
)
|
210
|
-
if response.message == NO_ANSWER:
|
211
|
-
return [default_response]
|
212
|
-
ids = response.message.split()
|
213
|
-
if len(ids) == 0:
|
214
|
-
return [default_response]
|
215
|
-
if self.vecdb is None:
|
216
|
-
logger.warning("No vector store specified")
|
217
|
-
return [default_response]
|
218
|
-
docs = self.vecdb.get_documents_by_ids(ids)
|
219
|
-
return [
|
220
|
-
Document(content=d.content, metadata=DocMetaData(source="LLM"))
|
221
|
-
for d in docs
|
222
|
-
]
|
@@ -6,12 +6,13 @@ Functionality includes:
|
|
6
6
|
- adding table and column context
|
7
7
|
- asking a question about a SQL schema
|
8
8
|
"""
|
9
|
+
|
9
10
|
import logging
|
10
|
-
from typing import Any, Dict, Optional, Sequence, Union
|
11
|
+
from typing import Any, Dict, List, Optional, Sequence, Union
|
11
12
|
|
12
13
|
from rich import print
|
13
14
|
from rich.console import Console
|
14
|
-
from sqlalchemy import MetaData, Row, create_engine, text
|
15
|
+
from sqlalchemy import MetaData, Row, create_engine, inspect, text
|
15
16
|
from sqlalchemy.engine import Engine
|
16
17
|
from sqlalchemy.exc import SQLAlchemyError
|
17
18
|
from sqlalchemy.orm import Session, sessionmaker
|
@@ -35,9 +36,7 @@ from langroid.agent.special.sql.utils.tools import (
|
|
35
36
|
GetTableSchemaTool,
|
36
37
|
RunQueryTool,
|
37
38
|
)
|
38
|
-
from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig
|
39
39
|
from langroid.mytypes import Entity
|
40
|
-
from langroid.prompts.prompts_config import PromptsConfig
|
41
40
|
from langroid.vector_store.base import VectorStoreConfig
|
42
41
|
|
43
42
|
logger = logging.getLogger(__name__)
|
@@ -67,7 +66,6 @@ SQL_ERROR_MSG = "There was an error in your SQL Query"
|
|
67
66
|
class SQLChatAgentConfig(ChatAgentConfig):
|
68
67
|
system_message: str = DEFAULT_SQL_CHAT_SYSTEM_MESSAGE
|
69
68
|
user_message: None | str = None
|
70
|
-
max_context_tokens: int = 1000
|
71
69
|
cache: bool = True # cache results
|
72
70
|
debug: bool = False
|
73
71
|
stream: bool = True # allow streaming where needed
|
@@ -76,6 +74,7 @@ class SQLChatAgentConfig(ChatAgentConfig):
|
|
76
74
|
vecdb: None | VectorStoreConfig = None
|
77
75
|
context_descriptions: Dict[str, Dict[str, Union[str, Dict[str, str]]]] = {}
|
78
76
|
use_schema_tools: bool = False
|
77
|
+
multi_schema: bool = False
|
79
78
|
|
80
79
|
"""
|
81
80
|
Optional, but strongly recommended, context descriptions for tables, columns,
|
@@ -90,6 +89,9 @@ class SQLChatAgentConfig(ChatAgentConfig):
|
|
90
89
|
is another table name and the value is a description of the relationship to
|
91
90
|
that table.
|
92
91
|
|
92
|
+
If multi_schema support is enabled, the table names in the description
|
93
|
+
should be of the form 'schema_name.table_name'.
|
94
|
+
|
93
95
|
For example:
|
94
96
|
{
|
95
97
|
'table1': {
|
@@ -109,15 +111,6 @@ class SQLChatAgentConfig(ChatAgentConfig):
|
|
109
111
|
}
|
110
112
|
"""
|
111
113
|
|
112
|
-
llm: OpenAIGPTConfig = OpenAIGPTConfig(
|
113
|
-
type="openai",
|
114
|
-
chat_model=OpenAIChatModel.GPT4,
|
115
|
-
completion_model=OpenAIChatModel.GPT4,
|
116
|
-
)
|
117
|
-
prompts: PromptsConfig = PromptsConfig(
|
118
|
-
max_tokens=1000,
|
119
|
-
)
|
120
|
-
|
121
114
|
|
122
115
|
class SQLChatAgent(ChatAgent):
|
123
116
|
"""
|
@@ -155,19 +148,44 @@ class SQLChatAgent(ChatAgent):
|
|
155
148
|
"""Initialize the database metadata."""
|
156
149
|
if self.engine is None:
|
157
150
|
raise ValueError("Database engine is None")
|
151
|
+
self.metadata: MetaData | List[MetaData] = []
|
158
152
|
|
159
|
-
self.
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
153
|
+
if self.config.multi_schema:
|
154
|
+
logger.info(
|
155
|
+
"Initializing SQLChatAgent with database: %s",
|
156
|
+
self.engine,
|
157
|
+
)
|
158
|
+
|
159
|
+
self.metadata = []
|
160
|
+
inspector = inspect(self.engine)
|
161
|
+
|
162
|
+
for schema in inspector.get_schema_names():
|
163
|
+
metadata = MetaData(schema=schema)
|
164
|
+
metadata.reflect(self.engine)
|
165
|
+
self.metadata.append(metadata)
|
166
|
+
|
167
|
+
logger.info(
|
168
|
+
"Initializing SQLChatAgent with database: %s, schema: %s, "
|
169
|
+
"and tables: %s",
|
170
|
+
self.engine,
|
171
|
+
schema,
|
172
|
+
metadata.tables,
|
173
|
+
)
|
174
|
+
else:
|
175
|
+
self.metadata = MetaData()
|
176
|
+
self.metadata.reflect(self.engine)
|
177
|
+
logger.info(
|
178
|
+
"SQLChatAgent initialized with database: %s and tables: %s",
|
179
|
+
self.engine,
|
180
|
+
self.metadata.tables,
|
181
|
+
)
|
166
182
|
|
167
183
|
def _init_table_metadata(self) -> None:
|
168
184
|
"""Initialize metadata for the tables present in the database."""
|
169
185
|
if not self.config.context_descriptions and isinstance(self.engine, Engine):
|
170
|
-
self.config.context_descriptions = extract_schema_descriptions(
|
186
|
+
self.config.context_descriptions = extract_schema_descriptions(
|
187
|
+
self.engine, self.config.multi_schema
|
188
|
+
)
|
171
189
|
|
172
190
|
if self.config.use_schema_tools:
|
173
191
|
self.table_metadata = populate_metadata_with_schema_tools(
|
@@ -228,8 +246,10 @@ class SQLChatAgent(ChatAgent):
|
|
228
246
|
if isinstance(msg, ChatDocument) and msg.function_call is not None:
|
229
247
|
sender_name = msg.function_call.name
|
230
248
|
|
249
|
+
content = results.content if isinstance(results, ChatDocument) else results
|
250
|
+
|
231
251
|
return ChatDocument(
|
232
|
-
content=
|
252
|
+
content=content,
|
233
253
|
metadata=ChatDocMetaData(
|
234
254
|
source=Entity.AGENT,
|
235
255
|
sender=Entity.AGENT,
|
@@ -329,6 +349,10 @@ class SQLChatAgent(ChatAgent):
|
|
329
349
|
Returns:
|
330
350
|
str: The names of all tables in the database.
|
331
351
|
"""
|
352
|
+
if isinstance(self.metadata, list):
|
353
|
+
table_names = [", ".join(md.tables.keys()) for md in self.metadata]
|
354
|
+
return ", ".join(table_names)
|
355
|
+
|
332
356
|
return ", ".join(self.metadata.tables.keys())
|
333
357
|
|
334
358
|
def get_table_schema(self, msg: GetTableSchemaTool) -> str:
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from .tools import (
|
2
|
+
RunQueryTool,
|
3
|
+
GetTableNamesTool,
|
4
|
+
GetTableSchemaTool,
|
5
|
+
GetColumnDescriptionsTool,
|
6
|
+
)
|
7
|
+
|
8
|
+
from . import description_extractors
|
9
|
+
from . import populate_metadata
|
10
|
+
from . import system_message
|
11
|
+
from . import tools
|
12
|
+
|
13
|
+
__all__ = [
|
14
|
+
"RunQueryTool",
|
15
|
+
"GetTableNamesTool",
|
16
|
+
"GetTableSchemaTool",
|
17
|
+
"GetColumnDescriptionsTool",
|
18
|
+
"description_extractors",
|
19
|
+
"populate_metadata",
|
20
|
+
"system_message",
|
21
|
+
"tools",
|
22
|
+
]
|