langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. langroid/__init__.py +95 -0
  2. langroid/agent/__init__.py +40 -0
  3. langroid/agent/base.py +222 -91
  4. langroid/agent/batch.py +264 -0
  5. langroid/agent/callbacks/chainlit.py +608 -0
  6. langroid/agent/chat_agent.py +247 -101
  7. langroid/agent/chat_document.py +41 -4
  8. langroid/agent/openai_assistant.py +842 -0
  9. langroid/agent/special/__init__.py +50 -0
  10. langroid/agent/special/doc_chat_agent.py +837 -141
  11. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  12. langroid/agent/special/lance_rag/__init__.py +9 -0
  13. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  14. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  15. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  16. langroid/agent/special/lance_tools.py +44 -0
  17. langroid/agent/special/neo4j/__init__.py +0 -0
  18. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  19. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  20. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  21. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  22. langroid/agent/special/relevance_extractor_agent.py +127 -0
  23. langroid/agent/special/retriever_agent.py +32 -198
  24. langroid/agent/special/sql/__init__.py +11 -0
  25. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  26. langroid/agent/special/sql/utils/__init__.py +22 -0
  27. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  28. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  29. langroid/agent/special/table_chat_agent.py +43 -9
  30. langroid/agent/task.py +475 -122
  31. langroid/agent/tool_message.py +75 -13
  32. langroid/agent/tools/__init__.py +13 -0
  33. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  34. langroid/agent/tools/google_search_tool.py +11 -0
  35. langroid/agent/tools/metaphor_search_tool.py +67 -0
  36. langroid/agent/tools/recipient_tool.py +16 -29
  37. langroid/agent/tools/run_python_code.py +60 -0
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/agent/tools/segment_extract_tool.py +36 -0
  40. langroid/cachedb/__init__.py +9 -0
  41. langroid/cachedb/base.py +22 -2
  42. langroid/cachedb/momento_cachedb.py +26 -2
  43. langroid/cachedb/redis_cachedb.py +78 -11
  44. langroid/embedding_models/__init__.py +34 -0
  45. langroid/embedding_models/base.py +21 -2
  46. langroid/embedding_models/models.py +120 -18
  47. langroid/embedding_models/protoc/embeddings.proto +19 -0
  48. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  49. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  50. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  51. langroid/embedding_models/remote_embeds.py +153 -0
  52. langroid/language_models/__init__.py +45 -0
  53. langroid/language_models/azure_openai.py +80 -27
  54. langroid/language_models/base.py +117 -12
  55. langroid/language_models/config.py +5 -0
  56. langroid/language_models/openai_assistants.py +3 -0
  57. langroid/language_models/openai_gpt.py +558 -174
  58. langroid/language_models/prompt_formatter/__init__.py +15 -0
  59. langroid/language_models/prompt_formatter/base.py +4 -6
  60. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  61. langroid/language_models/utils.py +18 -21
  62. langroid/mytypes.py +25 -8
  63. langroid/parsing/__init__.py +46 -0
  64. langroid/parsing/document_parser.py +260 -63
  65. langroid/parsing/image_text.py +32 -0
  66. langroid/parsing/parse_json.py +143 -0
  67. langroid/parsing/parser.py +122 -59
  68. langroid/parsing/repo_loader.py +114 -52
  69. langroid/parsing/search.py +68 -63
  70. langroid/parsing/spider.py +3 -2
  71. langroid/parsing/table_loader.py +44 -0
  72. langroid/parsing/url_loader.py +59 -11
  73. langroid/parsing/urls.py +85 -37
  74. langroid/parsing/utils.py +298 -4
  75. langroid/parsing/web_search.py +73 -0
  76. langroid/prompts/__init__.py +11 -0
  77. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  78. langroid/prompts/prompts_config.py +1 -1
  79. langroid/utils/__init__.py +17 -0
  80. langroid/utils/algorithms/__init__.py +3 -0
  81. langroid/utils/algorithms/graph.py +103 -0
  82. langroid/utils/configuration.py +36 -5
  83. langroid/utils/constants.py +4 -0
  84. langroid/utils/globals.py +2 -2
  85. langroid/utils/logging.py +2 -5
  86. langroid/utils/output/__init__.py +21 -0
  87. langroid/utils/output/printing.py +47 -1
  88. langroid/utils/output/status.py +33 -0
  89. langroid/utils/pandas_utils.py +30 -0
  90. langroid/utils/pydantic_utils.py +616 -2
  91. langroid/utils/system.py +98 -0
  92. langroid/vector_store/__init__.py +40 -0
  93. langroid/vector_store/base.py +203 -6
  94. langroid/vector_store/chromadb.py +59 -32
  95. langroid/vector_store/lancedb.py +463 -0
  96. langroid/vector_store/meilisearch.py +10 -7
  97. langroid/vector_store/momento.py +262 -0
  98. langroid/vector_store/qdrantdb.py +104 -22
  99. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
  100. langroid-0.1.219.dist-info/RECORD +127 -0
  101. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
  102. langroid/agent/special/recipient_validator_agent.py +0 -157
  103. langroid/parsing/json.py +0 -64
  104. langroid/utils/web/selenium_login.py +0 -36
  105. langroid-0.1.85.dist-info/RECORD +0 -94
  106. /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
  107. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
langroid/agent/special/relevance_extractor_agent.py (new file)
@@ -0,0 +1,127 @@
+"""
+Agent to retrieve relevant segments from a body of text,
+that are relevant to a query.
+
+"""
+
+import logging
+from typing import Optional, no_type_check
+
+from rich.console import Console
+
+from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
+from langroid.agent.chat_document import ChatDocument
+from langroid.agent.tools.segment_extract_tool import SegmentExtractTool
+from langroid.language_models.base import LLMConfig
+from langroid.language_models.openai_gpt import OpenAIGPTConfig
+from langroid.mytypes import Entity
+from langroid.parsing.utils import extract_numbered_segments, number_segments
+from langroid.utils.constants import DONE, NO_ANSWER
+
+console = Console()
+logger = logging.getLogger(__name__)
+
+
+class RelevanceExtractorAgentConfig(ChatAgentConfig):
+    llm: LLMConfig | None = OpenAIGPTConfig()
+    segment_length: int = 1  # number of sentences per segment
+    query: str = ""  # query for relevance extraction
+    system_message = """
+    The user will give you a PASSAGE containing segments numbered as
+    <#1#>, <#2#>, <#3#>, etc.,
+    followed by a QUERY. Extract ONLY the segment-numbers from
+    the PASSAGE that are RELEVANT to the QUERY.
+    Present the extracted segment-numbers using the `extract_segments` tool/function.
+    """
+
+
+class RelevanceExtractorAgent(ChatAgent):
+    """
+    Agent for extracting segments from text, that are relevant to a given query.
+    """
+
+    def __init__(self, config: RelevanceExtractorAgentConfig):
+        super().__init__(config)
+        self.config: RelevanceExtractorAgentConfig = config
+        self.enable_message(SegmentExtractTool)
+        self.numbered_passage: Optional[str] = None
+
+    @no_type_check
+    def llm_response(
+        self, message: Optional[str | ChatDocument] = None
+    ) -> Optional[ChatDocument]:
+        """Compose a prompt asking to extract relevant segments from a passage.
+        Steps:
+        - number the segments in the passage
+        - compose prompt
+        - send to LLM
+        """
+        assert self.config.query is not None, "No query specified"
+        assert message is not None, "No message specified"
+        message_str = message.content if isinstance(message, ChatDocument) else message
+        # number the segments in the passage
+        self.numbered_passage = number_segments(message_str, self.config.segment_length)
+        # compose prompt
+        prompt = f"""
+        PASSAGE:
+        {self.numbered_passage}
+
+        QUERY: {self.config.query}
+        """
+        # send to LLM
+        return super().llm_response(prompt)
+
+    @no_type_check
+    async def llm_response_async(
+        self, message: Optional[str | ChatDocument] = None
+    ) -> Optional[ChatDocument]:
+        """
+        Compose a prompt asking to extract relevant segments from a passage.
+        Steps:
+        - number the segments in the passage
+        - compose prompt
+        - send to LLM
+        The LLM is expected to generate a structured msg according to the
+        SegmentExtractTool schema, i.e. it should contain a `segment_list` field
+        whose value is a list of segment numbers or ranges, like "10,12,14-17".
+        """
+
+        assert self.config.query is not None, "No query specified"
+        assert message is not None, "No message specified"
+        message_str = message.content if isinstance(message, ChatDocument) else message
+        # number the segments in the passage
+        self.numbered_passage = number_segments(message_str, self.config.segment_length)
+        # compose prompt
+        prompt = f"""
+        PASSAGE:
+        {self.numbered_passage}
+
+        QUERY: {self.config.query}
+        """
+        # send to LLM
+        return await super().llm_response_async(prompt)
+
+    def extract_segments(self, msg: SegmentExtractTool) -> str:
+        """Method to handle a SegmentExtractTool message from LLM"""
+        spec = msg.segment_list
+        if len(self.message_history) == 0:
+            return DONE + " " + NO_ANSWER
+        if spec is None or spec.strip() in ["", NO_ANSWER]:
+            return DONE + " " + NO_ANSWER
+        assert self.numbered_passage is not None, "No numbered passage"
+        # assume this has numbered segments
+        try:
+            extracts = extract_numbered_segments(self.numbered_passage, spec)
+        except Exception:
+            return DONE + " " + NO_ANSWER
+        # this response ends the task by saying DONE
+        return DONE + " " + extracts
+
+    def handle_message_fallback(
+        self, msg: str | ChatDocument
+    ) -> str | ChatDocument | None:
+        """Handle case where LLM forgets to use SegmentExtractTool"""
+        if isinstance(msg, ChatDocument) and msg.metadata.sender == Entity.LLM:
+            return DONE + " " + NO_ANSWER
+        else:
+            return None
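For orientation, here is a minimal sketch (not taken from the package) of how the new RelevanceExtractorAgent added above might be driven; the passage and query strings are made up, and an OpenAI API key is assumed to be configured for the default OpenAIGPTConfig:

from langroid.agent.special.relevance_extractor_agent import (
    RelevanceExtractorAgent,
    RelevanceExtractorAgentConfig,
)

# Query used to judge relevance, and one-sentence segments,
# as defined by the config fields shown in the diff above.
config = RelevanceExtractorAgentConfig(
    query="What is the capital of France?",
    segment_length=1,
)
agent = RelevanceExtractorAgent(config)

# llm_response() numbers the passage's sentences as <#1#>, <#2#>, ... and asks
# the LLM to return the relevant segment numbers via the SegmentExtractTool;
# when the agent is run inside a Task, extract_segments() then turns that
# tool call into "DONE <relevant text>".
response = agent.llm_response(
    "Paris is the capital of France. The Nile flows through Egypt."
)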
langroid/agent/special/retriever_agent.py
@@ -1,222 +1,56 @@
 """
-Agent to retrieve relevant verbatim whole docs/records from a vector store.
-See test_retriever_agent.py for example usage:
+Deprecated: use DocChatAgent instead, with DocChatAgentConfig.retrieve_only=True,
+and if you want to retrieve FULL relevant doc-contents rather than just extracts,
+then set DocChatAgentConfig.extraction_granularity=-1
+
+This is an agent to retrieve relevant extracts from a vector store,
+where the LLM is used to filter for "true" relevance after retrieval from the
+vector store.
+This is essentially the same as DocChatAgent, except that instead of
+generating final summary answer based on relevant extracts, it just returns
+those extracts.
+See test_retriever_agent.py for example usage.
 """
+
 import logging
-from abc import ABC, abstractmethod
-from typing import List, Optional, Sequence
+from typing import Sequence

-from rich import print
 from rich.console import Console

-from langroid.agent.chat_document import ChatDocMetaData, ChatDocument
 from langroid.agent.special.doc_chat_agent import DocChatAgent, DocChatAgentConfig
-from langroid.embedding_models.models import OpenAIEmbeddingsConfig
-from langroid.language_models.base import StreamingIfAllowed
-from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig
-from langroid.mytypes import DocMetaData, Document, Entity
-from langroid.parsing.parser import ParsingConfig, Splitter
-from langroid.prompts.prompts_config import PromptsConfig
-from langroid.utils.constants import NO_ANSWER
-from langroid.vector_store.base import VectorStoreConfig
-from langroid.vector_store.qdrantdb import QdrantDBConfig
+from langroid.mytypes import DocMetaData, Document

 console = Console()
 logger = logging.getLogger(__name__)

+# for backwards compatibility:
+RecordMetadata = DocMetaData
+RecordDoc = Document
+RetrieverAgentConfig = DocChatAgentConfig

-class RecordMetadata(DocMetaData):
-    id: None | int | str = None
-
-
-class RecordDoc(Document):
-    metadata: RecordMetadata
-
-
-class RetrieverAgentConfig(DocChatAgentConfig):
-    n_matches: int = 3
-    debug: bool = False
-    max_context_tokens = 500
-    conversation_mode = True
-    cache: bool = True  # cache results
-    gpt4: bool = True  # use GPT-4
-    stream: bool = True  # allow streaming where needed
-    max_tokens: int = 10000
-    vecdb: VectorStoreConfig = QdrantDBConfig(
-        collection_name=None,
-        storage_path=".qdrant/data/",
-        embedding=OpenAIEmbeddingsConfig(
-            model_type="openai",
-            model_name="text-embedding-ada-002",
-            dims=1536,
-        ),
-    )

-    llm: OpenAIGPTConfig = OpenAIGPTConfig(
-        type="openai",
-        chat_model=OpenAIChatModel.GPT4,
-    )
-    parsing: ParsingConfig = ParsingConfig(
-        splitter=Splitter.TOKENS,
-        chunk_size=100,
-        n_similar_docs=5,
-    )
-
-    prompts: PromptsConfig = PromptsConfig(
-        max_tokens=1000,
-    )
-
-
-class RetrieverAgent(DocChatAgent, ABC):
+class RetrieverAgent(DocChatAgent):
     """
-    Agent for retrieving whole records/docs matching a query
+    Agent for just retrieving chunks/docs/extracts matching a query
     """

-    def __init__(self, config: RetrieverAgentConfig):
+    def __init__(self, config: DocChatAgentConfig):
         super().__init__(config)
-        self.config: RetrieverAgentConfig = config
+        self.config: DocChatAgentConfig = config
+        logger.warning(
+            """
+            `RetrieverAgent` is deprecated. Use `DocChatAgent` instead, with
+            `DocChatAgentConfig.retrieve_only=True`, and if you want to retrieve
+            FULL relevant doc-contents rather than just extracts, then set
+            `DocChatAgentConfig.extraction_granularity=-1`
+            """
+        )

-    @abstractmethod
-    def get_records(self) -> Sequence[RecordDoc]:
-        pass
+    def get_records(self) -> Sequence[Document]:
+        raise NotImplementedError

     def ingest(self) -> None:
         records = self.get_records()
         if self.vecdb is None:
             raise ValueError("No vector store specified")
         self.vecdb.add_documents(records)
-
-    def llm_response(
-        self,
-        query: None | str | ChatDocument = None,
-    ) -> Optional[ChatDocument]:
-        if not self.llm_can_respond(query):
-            return None
-        if query is None:
-            return super().llm_response(None)  # type: ignore
-        if isinstance(query, ChatDocument):
-            query_str = query.content
-        else:
-            query_str = query
-        docs = self.get_relevant_extracts(query_str)
-        if len(docs) == 0:
-            return None
-        content = "\n\n".join([d.content for d in docs])
-        print(f"[green]{content}")
-        meta = dict(
-            sender=Entity.LLM,
-        )
-        meta.update(docs[0].metadata)
-
-        return ChatDocument(
-            content=content,
-            metadata=ChatDocMetaData(**meta),
-        )
-
-    def get_nearest_docs(self, query: str) -> List[Document]:
-        """
-        Given a query, get the records/docs whose contents are closest to the
-        query, in terms of vector similarity.
-        Args:
-            query: query string
-        Returns:
-            list of Document objects
-        """
-        if self.vecdb is None:
-            logger.warning("No vector store specified")
-            return []
-        with console.status("[cyan]Searching VecDB for similar docs/records..."):
-            docs_and_scores = self.vecdb.similar_texts_with_scores(
-                query,
-                k=self.config.parsing.n_similar_docs,
-            )
-        docs: List[Document] = [
-            Document(content=d.content, metadata=d.metadata)
-            for (d, _) in docs_and_scores
-        ]
-        return docs
-
-    def get_relevant_extracts(self, query: str) -> List[Document]:
-        """
-        Given a query, get the records/docs whose contents are most relevant to the
-        query. First get nearest docs from vector store, then select the best
-        matches according to the LLM.
-        Args:
-            query (str): query string
-
-        Returns:
-            List[Document]: list of Document objects
-        """
-        response = Document(
-            content=NO_ANSWER,
-            metadata=DocMetaData(
-                source="None",
-            ),
-        )
-        nearest_docs = self.get_nearest_docs(query)
-        if len(nearest_docs) == 0:
-            return [response]
-        if self.llm is None:
-            logger.warning("No LLM specified")
-            return nearest_docs
-        with console.status("LLM selecting relevant docs from retrieved ones..."):
-            with StreamingIfAllowed(self.llm, False):
-                doc_list = self.llm_select_relevant_docs(query, nearest_docs)
-
-        return doc_list
-
-    def llm_select_relevant_docs(
-        self, query: str, docs: List[Document]
-    ) -> List[Document]:
-        """
-        Given a query and a list of docs, select the docs whose contents match best,
-        according to the LLM. Use the doc IDs to select the docs from the vector
-        store.
-        Args:
-            query: query string
-            docs: list of Document objects
-        Returns:
-            list of Document objects
-        """
-        doc_contents = "\n\n".join(
-            [f"DOC: ID={d.id()}, CONTENT: {d.content}" for d in docs]
-        )
-        prompt = f"""
-        Given the following QUERY:
-        {query}
-        and the following DOCS with IDs and contents
-        {doc_contents}
-
-        Find at most {self.config.n_matches} DOCs that are most relevant to the QUERY.
-        Return your answer as a sequence of DOC IDS ONLY, for example:
-        "id1 id2 id3..."
-        If there are no relevant docs, simply say {NO_ANSWER}.
-        Even if there is only one relevant doc, return it as a single ID.
-        Do not give any explanations or justifications.
-        """
-        default_response = Document(
-            content=NO_ANSWER,
-            metadata=DocMetaData(
-                source="None",
-            ),
-        )
-
-        if self.llm is None:
-            logger.warning("No LLM specified")
-            return [default_response]
-        response = self.llm.generate(
-            prompt, max_tokens=self.config.llm.max_output_tokens
-        )
-        if response.message == NO_ANSWER:
-            return [default_response]
-        ids = response.message.split()
-        if len(ids) == 0:
-            return [default_response]
-        if self.vecdb is None:
-            logger.warning("No vector store specified")
-            return [default_response]
-        docs = self.vecdb.get_documents_by_ids(ids)
-        return [
-            Document(content=d.content, metadata=DocMetaData(source="LLM"))
-            for d in docs
-        ]
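The deprecation notice in this hunk points to the replacement configuration; a hedged sketch of what that looks like, using only the two fields the notice names and leaving all other defaults untouched:

from langroid.agent.special.doc_chat_agent import DocChatAgent, DocChatAgentConfig

config = DocChatAgentConfig(
    # return the relevant extracts themselves instead of a summarized answer
    retrieve_only=True,
    # per the notice: -1 means return FULL doc contents rather than just extracts
    extraction_granularity=-1,
)
agent = DocChatAgent(config)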
langroid/agent/special/sql/__init__.py (new file)
@@ -0,0 +1,11 @@
+from .sql_chat_agent import SQLChatAgentConfig, SQLChatAgent
+
+from . import sql_chat_agent
+from . import utils
+
+__all__ = [
+    "SQLChatAgentConfig",
+    "SQLChatAgent",
+    "sql_chat_agent",
+    "utils",
+]
langroid/agent/special/sql/sql_chat_agent.py
@@ -6,12 +6,13 @@ Functionality includes:
 - adding table and column context
 - asking a question about a SQL schema
 """
+
 import logging
-from typing import Any, Dict, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Union

 from rich import print
 from rich.console import Console
-from sqlalchemy import MetaData, Row, create_engine, text
+from sqlalchemy import MetaData, Row, create_engine, inspect, text
 from sqlalchemy.engine import Engine
 from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.orm import Session, sessionmaker
@@ -35,9 +36,7 @@ from langroid.agent.special.sql.utils.tools import (
     GetTableSchemaTool,
     RunQueryTool,
 )
-from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig
 from langroid.mytypes import Entity
-from langroid.prompts.prompts_config import PromptsConfig
 from langroid.vector_store.base import VectorStoreConfig

 logger = logging.getLogger(__name__)
@@ -67,7 +66,6 @@ SQL_ERROR_MSG = "There was an error in your SQL Query"
 class SQLChatAgentConfig(ChatAgentConfig):
     system_message: str = DEFAULT_SQL_CHAT_SYSTEM_MESSAGE
     user_message: None | str = None
-    max_context_tokens: int = 1000
     cache: bool = True  # cache results
     debug: bool = False
     stream: bool = True  # allow streaming where needed
@@ -76,6 +74,7 @@ class SQLChatAgentConfig(ChatAgentConfig):
     vecdb: None | VectorStoreConfig = None
     context_descriptions: Dict[str, Dict[str, Union[str, Dict[str, str]]]] = {}
     use_schema_tools: bool = False
+    multi_schema: bool = False

     """
     Optional, but strongly recommended, context descriptions for tables, columns,
@@ -90,6 +89,9 @@ class SQLChatAgentConfig(ChatAgentConfig):
     is another table name and the value is a description of the relationship to
     that table.

+    If multi_schema support is enabled, the tables names in the description
+    should be of the form 'schema_name.table_name'.
+
     For example:
     {
         'table1': {
@@ -109,15 +111,6 @@ class SQLChatAgentConfig(ChatAgentConfig):
     }
     """

-    llm: OpenAIGPTConfig = OpenAIGPTConfig(
-        type="openai",
-        chat_model=OpenAIChatModel.GPT4,
-        completion_model=OpenAIChatModel.GPT4,
-    )
-    prompts: PromptsConfig = PromptsConfig(
-        max_tokens=1000,
-    )
-

 class SQLChatAgent(ChatAgent):
     """
@@ -155,19 +148,44 @@ class SQLChatAgent(ChatAgent):
         """Initialize the database metadata."""
         if self.engine is None:
             raise ValueError("Database engine is None")
+        self.metadata: MetaData | List[MetaData] = []

-        self.metadata = MetaData()
-        self.metadata.reflect(self.engine)
-        logger.info(
-            "SQLChatAgent initialized with database: %s and tables: %s",
-            self.engine,
-            self.metadata.tables,
-        )
+        if self.config.multi_schema:
+            logger.info(
+                "Initializing SQLChatAgent with database: %s",
+                self.engine,
+            )
+
+            self.metadata = []
+            inspector = inspect(self.engine)
+
+            for schema in inspector.get_schema_names():
+                metadata = MetaData(schema=schema)
+                metadata.reflect(self.engine)
+                self.metadata.append(metadata)
+
+                logger.info(
+                    "Initializing SQLChatAgent with database: %s, schema: %s, "
+                    "and tables: %s",
+                    self.engine,
+                    schema,
+                    metadata.tables,
+                )
+        else:
+            self.metadata = MetaData()
+            self.metadata.reflect(self.engine)
+            logger.info(
+                "SQLChatAgent initialized with database: %s and tables: %s",
+                self.engine,
+                self.metadata.tables,
+            )

     def _init_table_metadata(self) -> None:
         """Initialize metadata for the tables present in the database."""
         if not self.config.context_descriptions and isinstance(self.engine, Engine):
-            self.config.context_descriptions = extract_schema_descriptions(self.engine)
+            self.config.context_descriptions = extract_schema_descriptions(
+                self.engine, self.config.multi_schema
+            )

         if self.config.use_schema_tools:
             self.table_metadata = populate_metadata_with_schema_tools(
@@ -228,8 +246,10 @@ class SQLChatAgent(ChatAgent):
         if isinstance(msg, ChatDocument) and msg.function_call is not None:
             sender_name = msg.function_call.name

+        content = results.content if isinstance(results, ChatDocument) else results
+
         return ChatDocument(
-            content=results,
+            content=content,
             metadata=ChatDocMetaData(
                 source=Entity.AGENT,
                 sender=Entity.AGENT,
@@ -329,6 +349,10 @@ class SQLChatAgent(ChatAgent):
         Returns:
             str: The names of all tables in the database.
         """
+        if isinstance(self.metadata, list):
+            table_names = [", ".join(md.tables.keys()) for md in self.metadata]
+            return ", ".join(table_names)
+
         return ", ".join(self.metadata.tables.keys())

     def get_table_schema(self, msg: GetTableSchemaTool) -> str:
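To use the multi-schema support introduced in these hunks, the new flag is set on the config; a brief sketch (the database_uri field name and connection string are illustrative assumptions, not shown in this diff):

from langroid.agent.special.sql.sql_chat_agent import SQLChatAgent, SQLChatAgentConfig

config = SQLChatAgentConfig(
    database_uri="postgresql+psycopg2://user:pass@localhost/mydb",  # assumed field name and placeholder URI
    # reflect every schema found via sqlalchemy.inspect(engine), not just the default one
    multi_schema=True,
    # with multi_schema=True, any context_descriptions keys should be
    # qualified as "schema_name.table_name" (per the docstring change above)
)
agent = SQLChatAgent(config)

With a list of reflected MetaData objects, get_table_names() then joins table names across all schemas, as shown in the last hunk above.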
langroid/agent/special/sql/utils/__init__.py (new file)
@@ -0,0 +1,22 @@
+from .tools import (
+    RunQueryTool,
+    GetTableNamesTool,
+    GetTableSchemaTool,
+    GetColumnDescriptionsTool,
+)
+
+from . import description_extractors
+from . import populate_metadata
+from . import system_message
+from . import tools
+
+__all__ = [
+    "RunQueryTool",
+    "GetTableNamesTool",
+    "GetTableSchemaTool",
+    "GetColumnDescriptionsTool",
+    "description_extractors",
+    "populate_metadata",
+    "system_message",
+    "tools",
+]