langroid 0.1.139__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. langroid/__init__.py +70 -0
  2. langroid/agent/__init__.py +22 -0
  3. langroid/agent/base.py +120 -33
  4. langroid/agent/batch.py +134 -35
  5. langroid/agent/callbacks/__init__.py +0 -0
  6. langroid/agent/callbacks/chainlit.py +608 -0
  7. langroid/agent/chat_agent.py +164 -100
  8. langroid/agent/chat_document.py +19 -2
  9. langroid/agent/openai_assistant.py +20 -10
  10. langroid/agent/special/__init__.py +33 -10
  11. langroid/agent/special/doc_chat_agent.py +521 -108
  12. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  13. langroid/agent/special/lance_rag/__init__.py +9 -0
  14. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  15. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  16. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  17. langroid/agent/special/lance_tools.py +44 -0
  18. langroid/agent/special/neo4j/__init__.py +0 -0
  19. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  20. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  21. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  22. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  23. langroid/agent/special/relevance_extractor_agent.py +23 -7
  24. langroid/agent/special/retriever_agent.py +29 -174
  25. langroid/agent/special/sql/__init__.py +7 -0
  26. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  27. langroid/agent/special/sql/utils/__init__.py +11 -0
  28. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  29. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  30. langroid/agent/special/table_chat_agent.py +43 -9
  31. langroid/agent/task.py +423 -114
  32. langroid/agent/tool_message.py +67 -10
  33. langroid/agent/tools/__init__.py +8 -0
  34. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  35. langroid/agent/tools/google_search_tool.py +11 -0
  36. langroid/agent/tools/metaphor_search_tool.py +67 -0
  37. langroid/agent/tools/recipient_tool.py +6 -24
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/cachedb/__init__.py +6 -0
  40. langroid/embedding_models/__init__.py +24 -0
  41. langroid/embedding_models/base.py +9 -1
  42. langroid/embedding_models/models.py +117 -17
  43. langroid/embedding_models/protoc/embeddings.proto +19 -0
  44. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  45. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  46. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  47. langroid/embedding_models/remote_embeds.py +153 -0
  48. langroid/language_models/__init__.py +22 -0
  49. langroid/language_models/azure_openai.py +47 -4
  50. langroid/language_models/base.py +26 -10
  51. langroid/language_models/config.py +5 -0
  52. langroid/language_models/openai_gpt.py +407 -121
  53. langroid/language_models/prompt_formatter/__init__.py +9 -0
  54. langroid/language_models/prompt_formatter/base.py +4 -6
  55. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  56. langroid/language_models/utils.py +10 -9
  57. langroid/mytypes.py +10 -4
  58. langroid/parsing/__init__.py +33 -1
  59. langroid/parsing/document_parser.py +259 -63
  60. langroid/parsing/image_text.py +32 -0
  61. langroid/parsing/parse_json.py +143 -0
  62. langroid/parsing/parser.py +20 -7
  63. langroid/parsing/repo_loader.py +108 -46
  64. langroid/parsing/search.py +8 -0
  65. langroid/parsing/table_loader.py +44 -0
  66. langroid/parsing/url_loader.py +59 -13
  67. langroid/parsing/urls.py +18 -9
  68. langroid/parsing/utils.py +130 -9
  69. langroid/parsing/web_search.py +73 -0
  70. langroid/prompts/__init__.py +7 -0
  71. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  72. langroid/prompts/prompts_config.py +1 -1
  73. langroid/utils/__init__.py +10 -0
  74. langroid/utils/algorithms/__init__.py +3 -0
  75. langroid/utils/configuration.py +0 -1
  76. langroid/utils/constants.py +4 -0
  77. langroid/utils/logging.py +2 -5
  78. langroid/utils/output/__init__.py +15 -2
  79. langroid/utils/output/status.py +33 -0
  80. langroid/utils/pandas_utils.py +30 -0
  81. langroid/utils/pydantic_utils.py +446 -4
  82. langroid/utils/system.py +36 -1
  83. langroid/vector_store/__init__.py +34 -2
  84. langroid/vector_store/base.py +33 -2
  85. langroid/vector_store/chromadb.py +42 -13
  86. langroid/vector_store/lancedb.py +226 -60
  87. langroid/vector_store/meilisearch.py +7 -6
  88. langroid/vector_store/momento.py +3 -2
  89. langroid/vector_store/qdrantdb.py +82 -11
  90. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/METADATA +190 -129
  91. langroid-0.1.219.dist-info/RECORD +127 -0
  92. langroid/agent/special/recipient_validator_agent.py +0 -157
  93. langroid/parsing/json.py +0 -64
  94. langroid/utils/web/selenium_login.py +0 -36
  95. langroid-0.1.139.dist-info/RECORD +0 -103
  96. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
  97. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/WHEEL +0 -0
@@ -3,6 +3,7 @@ Agent to retrieve relevant segments from a body of text,
3
3
  that are relevant to a query.
4
4
 
5
5
  """
6
+
6
7
  import logging
7
8
  from typing import Optional, no_type_check
8
9
 
@@ -11,16 +12,18 @@ from rich.console import Console
11
12
  from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
12
13
  from langroid.agent.chat_document import ChatDocument
13
14
  from langroid.agent.tools.segment_extract_tool import SegmentExtractTool
15
+ from langroid.language_models.base import LLMConfig
14
16
  from langroid.language_models.openai_gpt import OpenAIGPTConfig
17
+ from langroid.mytypes import Entity
15
18
  from langroid.parsing.utils import extract_numbered_segments, number_segments
16
- from langroid.utils.constants import NO_ANSWER
19
+ from langroid.utils.constants import DONE, NO_ANSWER
17
20
 
18
21
  console = Console()
19
22
  logger = logging.getLogger(__name__)
20
23
 
21
24
 
22
25
  class RelevanceExtractorAgentConfig(ChatAgentConfig):
23
- llm: OpenAIGPTConfig = OpenAIGPTConfig()
26
+ llm: LLMConfig | None = OpenAIGPTConfig()
24
27
  segment_length: int = 1 # number of sentences per segment
25
28
  query: str = "" # query for relevance extraction
26
29
  system_message = """
@@ -28,6 +31,7 @@ class RelevanceExtractorAgentConfig(ChatAgentConfig):
28
31
  <#1#>, <#2#>, <#3#>, etc.,
29
32
  followed by a QUERY. Extract ONLY the segment-numbers from
30
33
  the PASSAGE that are RELEVANT to the QUERY.
34
+ Present the extracted segment-numbers using the `extract_segments` tool/function.
31
35
  """
32
36
 
33
37
 
@@ -101,11 +105,23 @@ class RelevanceExtractorAgent(ChatAgent):
101
105
  """Method to handle a segmentExtractTool message from LLM"""
102
106
  spec = msg.segment_list
103
107
  if len(self.message_history) == 0:
104
- return NO_ANSWER
105
- if spec is None or spec.strip() == "":
106
- return NO_ANSWER
108
+ return DONE + " " + NO_ANSWER
109
+ if spec is None or spec.strip() in ["", NO_ANSWER]:
110
+ return DONE + " " + NO_ANSWER
107
111
  assert self.numbered_passage is not None, "No numbered passage"
108
112
  # assume this has numbered segments
109
- extracts = extract_numbered_segments(self.numbered_passage, spec)
113
+ try:
114
+ extracts = extract_numbered_segments(self.numbered_passage, spec)
115
+ except Exception:
116
+ return DONE + " " + NO_ANSWER
110
117
  # this response ends the task by saying DONE
111
- return "DONE " + extracts
118
+ return DONE + " " + extracts
119
+
120
+ def handle_message_fallback(
121
+ self, msg: str | ChatDocument
122
+ ) -> str | ChatDocument | None:
123
+ """Handle case where LLM forgets to use SegmentExtractTool"""
124
+ if isinstance(msg, ChatDocument) and msg.metadata.sender == Entity.LLM:
125
+ return DONE + " " + NO_ANSWER
126
+ else:
127
+ return None
@@ -1,201 +1,56 @@
1
1
  """
2
- Agent to retrieve relevant verbatim whole docs/records from a vector store,
2
+ Deprecated: use DocChatAgent instead, with DocChatAgentConfig.retrieve_only=True,
3
+ and if you want to retrieve FULL relevant doc-contents rather than just extracts,
4
+ then set DocChatAgentConfig.extraction_granularity=-1
5
+
6
+ This is an agent to retrieve relevant extracts from a vector store,
3
7
  where the LLM is used to filter for "true" relevance after retrieval from the
4
8
  vector store.
9
+ This is essentially the same as DocChatAgent, except that instead of
10
+ generating final summary answer based on relevant extracts, it just returns
11
+ those extracts.
5
12
  See test_retriever_agent.py for example usage.
6
13
  """
14
+
7
15
  import logging
8
- from abc import ABC, abstractmethod
9
- from typing import List, Optional, Sequence
16
+ from typing import Sequence
10
17
 
11
- from rich import print
12
18
  from rich.console import Console
13
19
 
14
- from langroid.agent.chat_document import ChatDocMetaData, ChatDocument
15
20
  from langroid.agent.special.doc_chat_agent import DocChatAgent, DocChatAgentConfig
16
- from langroid.embedding_models.models import OpenAIEmbeddingsConfig
17
- from langroid.language_models.base import StreamingIfAllowed
18
- from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig
19
- from langroid.mytypes import DocMetaData, Document, Entity
20
- from langroid.parsing.parser import ParsingConfig, Splitter
21
- from langroid.prompts.prompts_config import PromptsConfig
22
- from langroid.utils.constants import NO_ANSWER
23
- from langroid.vector_store.base import VectorStoreConfig
24
- from langroid.vector_store.qdrantdb import QdrantDBConfig
21
+ from langroid.mytypes import DocMetaData, Document
25
22
 
26
23
  console = Console()
27
24
  logger = logging.getLogger(__name__)
28
25
 
26
+ # for backwards compatibility:
27
+ RecordMetadata = DocMetaData
28
+ RecordDoc = Document
29
+ RetrieverAgentConfig = DocChatAgentConfig
29
30
 
30
- class RecordMetadata(DocMetaData):
31
- id: None | str = None
32
-
33
-
34
- class RecordDoc(Document):
35
- metadata: RecordMetadata
36
-
37
-
38
- class RetrieverAgentConfig(DocChatAgentConfig):
39
- n_matches: int = 3
40
- debug: bool = False
41
- max_context_tokens = 500
42
- conversation_mode = True
43
- cache: bool = True # cache results
44
- gpt4: bool = True # use GPT-4
45
- stream: bool = True # allow streaming where needed
46
- max_tokens: int = 10000
47
- vecdb: VectorStoreConfig = QdrantDBConfig(
48
- collection_name=None,
49
- storage_path=".qdrant/data/",
50
- embedding=OpenAIEmbeddingsConfig(
51
- model_type="openai",
52
- model_name="text-embedding-ada-002",
53
- dims=1536,
54
- ),
55
- )
56
-
57
- llm: OpenAIGPTConfig = OpenAIGPTConfig(
58
- type="openai",
59
- chat_model=OpenAIChatModel.GPT4,
60
- )
61
- parsing: ParsingConfig = ParsingConfig(
62
- splitter=Splitter.TOKENS,
63
- chunk_size=100,
64
- n_similar_docs=5,
65
- )
66
-
67
- prompts: PromptsConfig = PromptsConfig(
68
- max_tokens=1000,
69
- )
70
31
 
71
-
72
- class RetrieverAgent(DocChatAgent, ABC):
32
+ class RetrieverAgent(DocChatAgent):
73
33
  """
74
- Agent for retrieving whole records/docs matching a query
34
+ Agent for just retrieving chunks/docs/extracts matching a query
75
35
  """
76
36
 
77
- def __init__(self, config: RetrieverAgentConfig):
37
+ def __init__(self, config: DocChatAgentConfig):
78
38
  super().__init__(config)
79
- self.config: RetrieverAgentConfig = config
39
+ self.config: DocChatAgentConfig = config
40
+ logger.warning(
41
+ """
42
+ `RetrieverAgent` is deprecated. Use `DocChatAgent` instead, with
43
+ `DocChatAgentConfig.retrieve_only=True`, and if you want to retrieve
44
+ FULL relevant doc-contents rather than just extracts, then set
45
+ `DocChatAgentConfig.extraction_granularity=-1`
46
+ """
47
+ )
80
48
 
81
- @abstractmethod
82
- def get_records(self) -> Sequence[RecordDoc]:
83
- pass
49
+ def get_records(self) -> Sequence[Document]:
50
+ raise NotImplementedError
84
51
 
85
52
  def ingest(self) -> None:
86
53
  records = self.get_records()
87
54
  if self.vecdb is None:
88
55
  raise ValueError("No vector store specified")
89
56
  self.vecdb.add_documents(records)
90
-
91
- def llm_response(
92
- self,
93
- query: None | str | ChatDocument = None,
94
- ) -> Optional[ChatDocument]:
95
- if not self.llm_can_respond(query):
96
- return None
97
- if query is None:
98
- return super().llm_response(None) # type: ignore
99
- if isinstance(query, ChatDocument):
100
- query_str = query.content
101
- else:
102
- query_str = query
103
- docs = self.get_relevant_extracts(query_str)
104
- if len(docs) == 0:
105
- return None
106
- content = "\n\n".join([d.content for d in docs])
107
- print(f"[green]{content}")
108
- meta = dict(
109
- sender=Entity.LLM,
110
- )
111
- meta.update(docs[0].metadata)
112
-
113
- return ChatDocument(
114
- content=content,
115
- metadata=ChatDocMetaData(**meta),
116
- )
117
-
118
- def get_relevant_extracts(self, query: str) -> List[Document]:
119
- """
120
- Given a query, get the records/docs whose contents are most relevant to the
121
- query. First get nearest docs from vector store, then select the best
122
- matches according to the LLM.
123
- Args:
124
- query (str): query string
125
-
126
- Returns:
127
- List[Document]: list of Document objects
128
- """
129
- response = Document(
130
- content=NO_ANSWER,
131
- metadata=DocMetaData(
132
- source="None",
133
- ),
134
- )
135
- nearest_docs = self.get_relevant_chunks(query)
136
- if len(nearest_docs) == 0:
137
- return [response]
138
- if self.llm is None:
139
- logger.warning("No LLM specified")
140
- return nearest_docs
141
- with console.status("LLM selecting relevant docs from retrieved ones..."):
142
- with StreamingIfAllowed(self.llm, False):
143
- doc_list = self.llm_select_relevant_docs(query, nearest_docs)
144
-
145
- return doc_list
146
-
147
- def llm_select_relevant_docs(
148
- self, query: str, docs: List[Document]
149
- ) -> List[Document]:
150
- """
151
- Given a query and a list of docs, select the docs whose contents match best,
152
- according to the LLM. Use the doc IDs to select the docs from the vector
153
- store.
154
- Args:
155
- query: query string
156
- docs: list of Document objects
157
- Returns:
158
- list of Document objects
159
- """
160
- doc_contents = "\n\n".join(
161
- [f"DOC: ID={d.id()}, CONTENT: {d.content}" for d in docs]
162
- )
163
- prompt = f"""
164
- Given the following QUERY:
165
- {query}
166
- and the following DOCS with IDs and contents
167
- {doc_contents}
168
-
169
- Find at most {self.config.n_matches} DOCs that are most relevant to the QUERY.
170
- Return your answer as a sequence of DOC IDS ONLY, for example:
171
- "id1 id2 id3..."
172
- If there are no relevant docs, simply say {NO_ANSWER}.
173
- Even if there is only one relevant doc, return it as a single ID.
174
- Do not give any explanations or justifications.
175
- """
176
- default_response = Document(
177
- content=NO_ANSWER,
178
- metadata=DocMetaData(
179
- source="None",
180
- ),
181
- )
182
-
183
- if self.llm is None:
184
- logger.warning("No LLM specified")
185
- return [default_response]
186
- response = self.llm.generate(
187
- prompt, max_tokens=self.config.llm.max_output_tokens
188
- )
189
- if response.message == NO_ANSWER:
190
- return [default_response]
191
- ids = response.message.split()
192
- if len(ids) == 0:
193
- return [default_response]
194
- if self.vecdb is None:
195
- logger.warning("No vector store specified")
196
- return [default_response]
197
- docs = self.vecdb.get_documents_by_ids(ids)
198
- return [
199
- Document(content=d.content, metadata=DocMetaData(source="LLM"))
200
- for d in docs
201
- ]
@@ -2,3 +2,10 @@ from .sql_chat_agent import SQLChatAgentConfig, SQLChatAgent
2
2
 
3
3
  from . import sql_chat_agent
4
4
  from . import utils
5
+
6
+ __all__ = [
7
+ "SQLChatAgentConfig",
8
+ "SQLChatAgent",
9
+ "sql_chat_agent",
10
+ "utils",
11
+ ]
@@ -6,12 +6,13 @@ Functionality includes:
6
6
  - adding table and column context
7
7
  - asking a question about a SQL schema
8
8
  """
9
+
9
10
  import logging
10
- from typing import Any, Dict, Optional, Sequence, Union
11
+ from typing import Any, Dict, List, Optional, Sequence, Union
11
12
 
12
13
  from rich import print
13
14
  from rich.console import Console
14
- from sqlalchemy import MetaData, Row, create_engine, text
15
+ from sqlalchemy import MetaData, Row, create_engine, inspect, text
15
16
  from sqlalchemy.engine import Engine
16
17
  from sqlalchemy.exc import SQLAlchemyError
17
18
  from sqlalchemy.orm import Session, sessionmaker
@@ -35,9 +36,7 @@ from langroid.agent.special.sql.utils.tools import (
35
36
  GetTableSchemaTool,
36
37
  RunQueryTool,
37
38
  )
38
- from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig
39
39
  from langroid.mytypes import Entity
40
- from langroid.prompts.prompts_config import PromptsConfig
41
40
  from langroid.vector_store.base import VectorStoreConfig
42
41
 
43
42
  logger = logging.getLogger(__name__)
@@ -67,7 +66,6 @@ SQL_ERROR_MSG = "There was an error in your SQL Query"
67
66
  class SQLChatAgentConfig(ChatAgentConfig):
68
67
  system_message: str = DEFAULT_SQL_CHAT_SYSTEM_MESSAGE
69
68
  user_message: None | str = None
70
- max_context_tokens: int = 1000
71
69
  cache: bool = True # cache results
72
70
  debug: bool = False
73
71
  stream: bool = True # allow streaming where needed
@@ -76,6 +74,7 @@ class SQLChatAgentConfig(ChatAgentConfig):
76
74
  vecdb: None | VectorStoreConfig = None
77
75
  context_descriptions: Dict[str, Dict[str, Union[str, Dict[str, str]]]] = {}
78
76
  use_schema_tools: bool = False
77
+ multi_schema: bool = False
79
78
 
80
79
  """
81
80
  Optional, but strongly recommended, context descriptions for tables, columns,
@@ -90,6 +89,9 @@ class SQLChatAgentConfig(ChatAgentConfig):
90
89
  is another table name and the value is a description of the relationship to
91
90
  that table.
92
91
 
92
+ If multi_schema support is enabled, the tables names in the description
93
+ should be of the form 'schema_name.table_name'.
94
+
93
95
  For example:
94
96
  {
95
97
  'table1': {
@@ -109,15 +111,6 @@ class SQLChatAgentConfig(ChatAgentConfig):
109
111
  }
110
112
  """
111
113
 
112
- llm: OpenAIGPTConfig = OpenAIGPTConfig(
113
- type="openai",
114
- chat_model=OpenAIChatModel.GPT4,
115
- completion_model=OpenAIChatModel.GPT4,
116
- )
117
- prompts: PromptsConfig = PromptsConfig(
118
- max_tokens=1000,
119
- )
120
-
121
114
 
122
115
  class SQLChatAgent(ChatAgent):
123
116
  """
@@ -155,19 +148,44 @@ class SQLChatAgent(ChatAgent):
155
148
  """Initialize the database metadata."""
156
149
  if self.engine is None:
157
150
  raise ValueError("Database engine is None")
151
+ self.metadata: MetaData | List[MetaData] = []
158
152
 
159
- self.metadata = MetaData()
160
- self.metadata.reflect(self.engine)
161
- logger.info(
162
- "SQLChatAgent initialized with database: %s and tables: %s",
163
- self.engine,
164
- self.metadata.tables,
165
- )
153
+ if self.config.multi_schema:
154
+ logger.info(
155
+ "Initializing SQLChatAgent with database: %s",
156
+ self.engine,
157
+ )
158
+
159
+ self.metadata = []
160
+ inspector = inspect(self.engine)
161
+
162
+ for schema in inspector.get_schema_names():
163
+ metadata = MetaData(schema=schema)
164
+ metadata.reflect(self.engine)
165
+ self.metadata.append(metadata)
166
+
167
+ logger.info(
168
+ "Initializing SQLChatAgent with database: %s, schema: %s, "
169
+ "and tables: %s",
170
+ self.engine,
171
+ schema,
172
+ metadata.tables,
173
+ )
174
+ else:
175
+ self.metadata = MetaData()
176
+ self.metadata.reflect(self.engine)
177
+ logger.info(
178
+ "SQLChatAgent initialized with database: %s and tables: %s",
179
+ self.engine,
180
+ self.metadata.tables,
181
+ )
166
182
 
167
183
  def _init_table_metadata(self) -> None:
168
184
  """Initialize metadata for the tables present in the database."""
169
185
  if not self.config.context_descriptions and isinstance(self.engine, Engine):
170
- self.config.context_descriptions = extract_schema_descriptions(self.engine)
186
+ self.config.context_descriptions = extract_schema_descriptions(
187
+ self.engine, self.config.multi_schema
188
+ )
171
189
 
172
190
  if self.config.use_schema_tools:
173
191
  self.table_metadata = populate_metadata_with_schema_tools(
@@ -228,8 +246,10 @@ class SQLChatAgent(ChatAgent):
228
246
  if isinstance(msg, ChatDocument) and msg.function_call is not None:
229
247
  sender_name = msg.function_call.name
230
248
 
249
+ content = results.content if isinstance(results, ChatDocument) else results
250
+
231
251
  return ChatDocument(
232
- content=results,
252
+ content=content,
233
253
  metadata=ChatDocMetaData(
234
254
  source=Entity.AGENT,
235
255
  sender=Entity.AGENT,
@@ -329,6 +349,10 @@ class SQLChatAgent(ChatAgent):
329
349
  Returns:
330
350
  str: The names of all tables in the database.
331
351
  """
352
+ if isinstance(self.metadata, list):
353
+ table_names = [", ".join(md.tables.keys()) for md in self.metadata]
354
+ return ", ".join(table_names)
355
+
332
356
  return ", ".join(self.metadata.tables.keys())
333
357
 
334
358
  def get_table_schema(self, msg: GetTableSchemaTool) -> str:
@@ -9,3 +9,14 @@ from . import description_extractors
9
9
  from . import populate_metadata
10
10
  from . import system_message
11
11
  from . import tools
12
+
13
+ __all__ = [
14
+ "RunQueryTool",
15
+ "GetTableNamesTool",
16
+ "GetTableSchemaTool",
17
+ "GetColumnDescriptionsTool",
18
+ "description_extractors",
19
+ "populate_metadata",
20
+ "system_message",
21
+ "tools",
22
+ ]
@@ -1,10 +1,13 @@
1
- from typing import Any, Dict, List
1
+ from typing import Any, Dict, List, Optional
2
2
 
3
3
  from sqlalchemy import inspect, text
4
4
  from sqlalchemy.engine import Engine
5
5
 
6
6
 
7
- def extract_postgresql_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
7
+ def extract_postgresql_descriptions(
8
+ engine: Engine,
9
+ multi_schema: bool = False,
10
+ ) -> Dict[str, Dict[str, Any]]:
8
11
  """
9
12
  Extracts descriptions for tables and columns from a PostgreSQL database.
10
13
 
@@ -13,6 +16,7 @@ def extract_postgresql_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]
13
16
 
14
17
  Args:
15
18
  engine (Engine): SQLAlchemy engine connected to a PostgreSQL database.
19
+ multi_schema (bool): Generate descriptions for all schemas in the database.
16
20
 
17
21
  Returns:
18
22
  Dict[str, Dict[str, Any]]: A dictionary mapping table names to a
@@ -20,36 +24,53 @@ def extract_postgresql_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]
20
24
  column descriptions.
21
25
  """
22
26
  inspector = inspect(engine)
23
- table_names: List[str] = inspector.get_table_names()
24
-
25
27
  result: Dict[str, Dict[str, Any]] = {}
26
28
 
27
- with engine.connect() as conn:
28
- for table in table_names:
29
- table_comment = (
30
- conn.execute(
31
- text(f"SELECT obj_description('{table}'::regclass)")
32
- ).scalar()
33
- or ""
34
- )
29
+ def gen_schema_descriptions(schema: Optional[str] = None) -> None:
30
+ table_names: List[str] = inspector.get_table_names(schema=schema)
31
+ with engine.connect() as conn:
32
+ for table in table_names:
33
+ if schema is None:
34
+ table_name = table
35
+ else:
36
+ table_name = f"{schema}.{table}"
35
37
 
36
- columns = {}
37
- col_data = inspector.get_columns(table)
38
- for idx, col in enumerate(col_data, start=1):
39
- col_comment = (
38
+ table_comment = (
40
39
  conn.execute(
41
- text(f"SELECT col_description('{table}'::regclass, {idx})")
40
+ text(f"SELECT obj_description('{table_name}'::regclass)")
42
41
  ).scalar()
43
42
  or ""
44
43
  )
45
- columns[col["name"]] = col_comment
46
44
 
47
- result[table] = {"description": table_comment, "columns": columns}
45
+ columns = {}
46
+ col_data = inspector.get_columns(table, schema=schema)
47
+ for idx, col in enumerate(col_data, start=1):
48
+ col_comment = (
49
+ conn.execute(
50
+ text(
51
+ f"SELECT col_description('{table_name}'::regclass, "
52
+ f"{idx})"
53
+ )
54
+ ).scalar()
55
+ or ""
56
+ )
57
+ columns[col["name"]] = col_comment
58
+
59
+ result[table_name] = {"description": table_comment, "columns": columns}
60
+
61
+ if multi_schema:
62
+ for schema in inspector.get_schema_names():
63
+ gen_schema_descriptions(schema)
64
+ else:
65
+ gen_schema_descriptions()
48
66
 
49
67
  return result
50
68
 
51
69
 
52
- def extract_mysql_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
70
+ def extract_mysql_descriptions(
71
+ engine: Engine,
72
+ multi_schema: bool = False,
73
+ ) -> Dict[str, Dict[str, Any]]:
53
74
  """Extracts descriptions for tables and columns from a MySQL database.
54
75
 
55
76
  This method retrieves the descriptions of tables and their columns
@@ -57,6 +78,7 @@ def extract_mysql_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
57
78
 
58
79
  Args:
59
80
  engine (Engine): SQLAlchemy engine connected to a MySQL database.
81
+ multi_schema (bool): Generate descriptions for all schemas in the database.
60
82
 
61
83
  Returns:
62
84
  Dict[str, Dict[str, Any]]: A dictionary mapping table names to a
@@ -64,31 +86,45 @@ def extract_mysql_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
64
86
  column descriptions.
65
87
  """
66
88
  inspector = inspect(engine)
67
- table_names: List[str] = inspector.get_table_names()
68
-
69
89
  result: Dict[str, Dict[str, Any]] = {}
70
90
 
71
- with engine.connect() as conn:
72
- for table in table_names:
73
- query = text(
74
- "SELECT table_comment FROM information_schema.tables WHERE"
75
- " table_schema = :schema AND table_name = :table"
76
- )
77
- table_result = conn.execute(
78
- query, {"schema": engine.url.database, "table": table}
79
- )
80
- table_comment = table_result.scalar() or ""
91
+ def gen_schema_descriptions(schema: Optional[str] = None) -> None:
92
+ table_names: List[str] = inspector.get_table_names(schema=schema)
81
93
 
82
- columns = {}
83
- for col in inspector.get_columns(table):
84
- columns[col["name"]] = col.get("comment", "")
94
+ with engine.connect() as conn:
95
+ for table in table_names:
96
+ if schema is None:
97
+ table_name = table
98
+ else:
99
+ table_name = f"{schema}.{table}"
100
+
101
+ query = text(
102
+ "SELECT table_comment FROM information_schema.tables WHERE"
103
+ " table_schema = :schema AND table_name = :table"
104
+ )
105
+ table_result = conn.execute(
106
+ query, {"schema": engine.url.database, "table": table_name}
107
+ )
108
+ table_comment = table_result.scalar() or ""
109
+
110
+ columns = {}
111
+ for col in inspector.get_columns(table, schema=schema):
112
+ columns[col["name"]] = col.get("comment", "")
113
+
114
+ result[table_name] = {"description": table_comment, "columns": columns}
85
115
 
86
- result[table] = {"description": table_comment, "columns": columns}
116
+ if multi_schema:
117
+ for schema in inspector.get_schema_names():
118
+ gen_schema_descriptions(schema)
119
+ else:
120
+ gen_schema_descriptions()
87
121
 
88
122
  return result
89
123
 
90
124
 
91
- def extract_default_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
125
+ def extract_default_descriptions(
126
+ engine: Engine, multi_schema: bool = False
127
+ ) -> Dict[str, Dict[str, Any]]:
92
128
  """Extracts default descriptions for tables and columns from a database.
93
129
 
94
130
  This method retrieves the table and column names from the given database
@@ -96,6 +132,7 @@ def extract_default_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
96
132
 
97
133
  Args:
98
134
  engine (Engine): SQLAlchemy engine connected to a database.
135
+ multi_schema (bool): Generate descriptions for all schemas in the database.
99
136
 
100
137
  Returns:
101
138
  Dict[str, Dict[str, Any]]: A dictionary mapping table names to a
@@ -103,26 +140,36 @@ def extract_default_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
103
140
  empty column descriptions.
104
141
  """
105
142
  inspector = inspect(engine)
106
- table_names: List[str] = inspector.get_table_names()
107
-
108
143
  result: Dict[str, Dict[str, Any]] = {}
109
144
 
110
- for table in table_names:
111
- columns = {}
112
- for col in inspector.get_columns(table):
113
- columns[col["name"]] = ""
145
+ def gen_schema_descriptions(schema: Optional[str] = None) -> None:
146
+ table_names: List[str] = inspector.get_table_names(schema=schema)
147
+
148
+ for table in table_names:
149
+ columns = {}
150
+ for col in inspector.get_columns(table):
151
+ columns[col["name"]] = ""
152
+
153
+ result[table] = {"description": "", "columns": columns}
114
154
 
115
- result[table] = {"description": "", "columns": columns}
155
+ if multi_schema:
156
+ for schema in inspector.get_schema_names():
157
+ gen_schema_descriptions(schema)
158
+ else:
159
+ gen_schema_descriptions()
116
160
 
117
161
  return result
118
162
 
119
163
 
120
- def extract_schema_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
164
+ def extract_schema_descriptions(
165
+ engine: Engine, multi_schema: bool = False
166
+ ) -> Dict[str, Dict[str, Any]]:
121
167
  """
122
168
  Extracts the schema descriptions from the database connected to by the engine.
123
169
 
124
170
  Args:
125
171
  engine (Engine): SQLAlchemy engine instance.
172
+ multi_schema (bool): Generate descriptions for all schemas in the database.
126
173
 
127
174
  Returns:
128
175
  Dict[str, Dict[str, Any]]: A dictionary representation of table and column
@@ -133,4 +180,6 @@ def extract_schema_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
133
180
  "postgresql": extract_postgresql_descriptions,
134
181
  "mysql": extract_mysql_descriptions,
135
182
  }
136
- return extractors.get(engine.dialect.name, extract_default_descriptions)(engine)
183
+ return extractors.get(engine.dialect.name, extract_default_descriptions)(
184
+ engine, multi_schema=multi_schema
185
+ )