langroid 0.33.4__py3-none-any.whl → 0.33.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +106 -0
- langroid/agent/__init__.py +41 -0
- langroid/agent/base.py +1983 -0
- langroid/agent/batch.py +398 -0
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +598 -0
- langroid/agent/chat_agent.py +1899 -0
- langroid/agent/chat_document.py +454 -0
- langroid/agent/openai_assistant.py +882 -0
- langroid/agent/special/__init__.py +59 -0
- langroid/agent/special/arangodb/__init__.py +0 -0
- langroid/agent/special/arangodb/arangodb_agent.py +656 -0
- langroid/agent/special/arangodb/system_messages.py +186 -0
- langroid/agent/special/arangodb/tools.py +107 -0
- langroid/agent/special/arangodb/utils.py +36 -0
- langroid/agent/special/doc_chat_agent.py +1466 -0
- langroid/agent/special/lance_doc_chat_agent.py +262 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +198 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +82 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +260 -0
- langroid/agent/special/lance_tools.py +61 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +433 -0
- langroid/agent/special/neo4j/system_messages.py +120 -0
- langroid/agent/special/neo4j/tools.py +32 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +56 -0
- langroid/agent/special/sql/__init__.py +17 -0
- langroid/agent/special/sql/sql_chat_agent.py +654 -0
- langroid/agent/special/sql/utils/__init__.py +21 -0
- langroid/agent/special/sql/utils/description_extractors.py +190 -0
- langroid/agent/special/sql/utils/populate_metadata.py +85 -0
- langroid/agent/special/sql/utils/system_message.py +35 -0
- langroid/agent/special/sql/utils/tools.py +64 -0
- langroid/agent/special/table_chat_agent.py +263 -0
- langroid/agent/task.py +2095 -0
- langroid/agent/tool_message.py +393 -0
- langroid/agent/tools/__init__.py +38 -0
- langroid/agent/tools/duckduckgo_search_tool.py +50 -0
- langroid/agent/tools/file_tools.py +234 -0
- langroid/agent/tools/google_search_tool.py +39 -0
- langroid/agent/tools/metaphor_search_tool.py +68 -0
- langroid/agent/tools/orchestration.py +303 -0
- langroid/agent/tools/recipient_tool.py +235 -0
- langroid/agent/tools/retrieval_tool.py +32 -0
- langroid/agent/tools/rewind_tool.py +137 -0
- langroid/agent/tools/segment_extract_tool.py +41 -0
- langroid/agent/xml_tool_message.py +382 -0
- langroid/cachedb/__init__.py +17 -0
- langroid/cachedb/base.py +58 -0
- langroid/cachedb/momento_cachedb.py +108 -0
- langroid/cachedb/redis_cachedb.py +153 -0
- langroid/embedding_models/__init__.py +39 -0
- langroid/embedding_models/base.py +74 -0
- langroid/embedding_models/models.py +461 -0
- langroid/embedding_models/protoc/__init__.py +0 -0
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/exceptions.py +71 -0
- langroid/language_models/__init__.py +53 -0
- langroid/language_models/azure_openai.py +153 -0
- langroid/language_models/base.py +678 -0
- langroid/language_models/config.py +18 -0
- langroid/language_models/mock_lm.py +124 -0
- langroid/language_models/openai_gpt.py +1964 -0
- langroid/language_models/prompt_formatter/__init__.py +16 -0
- langroid/language_models/prompt_formatter/base.py +40 -0
- langroid/language_models/prompt_formatter/hf_formatter.py +132 -0
- langroid/language_models/prompt_formatter/llama2_formatter.py +75 -0
- langroid/language_models/utils.py +151 -0
- langroid/mytypes.py +84 -0
- langroid/parsing/__init__.py +52 -0
- langroid/parsing/agent_chats.py +38 -0
- langroid/parsing/code_parser.py +121 -0
- langroid/parsing/document_parser.py +718 -0
- langroid/parsing/para_sentence_split.py +62 -0
- langroid/parsing/parse_json.py +155 -0
- langroid/parsing/parser.py +313 -0
- langroid/parsing/repo_loader.py +790 -0
- langroid/parsing/routing.py +36 -0
- langroid/parsing/search.py +275 -0
- langroid/parsing/spider.py +102 -0
- langroid/parsing/table_loader.py +94 -0
- langroid/parsing/url_loader.py +111 -0
- langroid/parsing/urls.py +273 -0
- langroid/parsing/utils.py +373 -0
- langroid/parsing/web_search.py +156 -0
- langroid/prompts/__init__.py +9 -0
- langroid/prompts/dialog.py +17 -0
- langroid/prompts/prompts_config.py +5 -0
- langroid/prompts/templates.py +141 -0
- langroid/pydantic_v1/__init__.py +10 -0
- langroid/pydantic_v1/main.py +4 -0
- langroid/utils/__init__.py +19 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +98 -0
- langroid/utils/constants.py +30 -0
- langroid/utils/git_utils.py +252 -0
- langroid/utils/globals.py +49 -0
- langroid/utils/logging.py +135 -0
- langroid/utils/object_registry.py +66 -0
- langroid/utils/output/__init__.py +20 -0
- langroid/utils/output/citations.py +41 -0
- langroid/utils/output/printing.py +99 -0
- langroid/utils/output/status.py +40 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +602 -0
- langroid/utils/system.py +286 -0
- langroid/utils/types.py +93 -0
- langroid/vector_store/__init__.py +50 -0
- langroid/vector_store/base.py +359 -0
- langroid/vector_store/chromadb.py +214 -0
- langroid/vector_store/lancedb.py +406 -0
- langroid/vector_store/meilisearch.py +299 -0
- langroid/vector_store/momento.py +278 -0
- langroid/vector_store/qdrantdb.py +468 -0
- {langroid-0.33.4.dist-info → langroid-0.33.7.dist-info}/METADATA +95 -94
- langroid-0.33.7.dist-info/RECORD +127 -0
- {langroid-0.33.4.dist-info → langroid-0.33.7.dist-info}/WHEEL +1 -1
- langroid-0.33.4.dist-info/RECORD +0 -7
- langroid-0.33.4.dist-info/entry_points.txt +0 -4
- pyproject.toml +0 -356
- {langroid-0.33.4.dist-info → langroid-0.33.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,190 @@
|
|
1
|
+
from typing import Any, Dict, List, Optional
|
2
|
+
|
3
|
+
from langroid.exceptions import LangroidImportError
|
4
|
+
|
5
|
+
try:
|
6
|
+
from sqlalchemy import inspect, text
|
7
|
+
from sqlalchemy.engine import Engine
|
8
|
+
except ImportError as e:
|
9
|
+
raise LangroidImportError(extra="sql", error=str(e))
|
10
|
+
|
11
|
+
|
12
|
+
def extract_postgresql_descriptions(
    engine: Engine,
    multi_schema: bool = False,
) -> Dict[str, Dict[str, Any]]:
    """
    Extracts descriptions for tables and columns from a PostgreSQL database.

    This method retrieves the descriptions (comments) of tables and their
    columns from a PostgreSQL database using the provided SQLAlchemy engine.

    Table comments come from ``Inspector.get_table_comment``. Column comments
    come from the ``comment`` entry returned by ``Inspector.get_columns``.
    As a result, no SQL is assembled by string interpolation: names that need
    quoting (mixed case, quotes, ...) are handled by SQLAlchemy. Column
    comments are also matched by column name rather than by ordinal position,
    which can have gaps after columns are dropped.

    Args:
        engine (Engine): SQLAlchemy engine connected to a PostgreSQL database.
        multi_schema (bool): Generate descriptions for all schemas in the database.

    Returns:
        Dict[str, Dict[str, Any]]: A dictionary mapping table names to a
            dictionary containing the table description and a dictionary of
            column descriptions. Table names are prefixed with "<schema>."
            when `multi_schema` is True. Missing comments become "".
    """
    inspector = inspect(engine)
    result: Dict[str, Dict[str, Any]] = {}

    def gen_schema_descriptions(schema: Optional[str] = None) -> None:
        table_names: List[str] = inspector.get_table_names(schema=schema)
        for table in table_names:
            table_name = table if schema is None else f"{schema}.{table}"

            # get_table_comment returns {"text": <comment or None>}
            table_comment = (
                inspector.get_table_comment(table, schema=schema).get("text") or ""
            )

            columns = {
                col["name"]: col.get("comment") or ""
                for col in inspector.get_columns(table, schema=schema)
            }

            result[table_name] = {"description": table_comment, "columns": columns}

    if multi_schema:
        for schema in inspector.get_schema_names():
            gen_schema_descriptions(schema)
    else:
        gen_schema_descriptions()

    return result
|
73
|
+
|
74
|
+
|
75
|
+
def extract_mysql_descriptions(
    engine: Engine,
    multi_schema: bool = False,
) -> Dict[str, Dict[str, Any]]:
    """Extracts descriptions for tables and columns from a MySQL database.

    This method retrieves the descriptions of tables and their columns
    from a MySQL database using the provided SQLAlchemy engine.

    Table comments are read from ``information_schema.tables``. That lookup
    needs the schema the table lives in and the bare (unqualified) table
    name. Without an explicit schema, the database named in the engine URL
    is used.

    Args:
        engine (Engine): SQLAlchemy engine connected to a MySQL database.
        multi_schema (bool): Generate descriptions for all schemas in the database.

    Returns:
        Dict[str, Dict[str, Any]]: A dictionary mapping table names to a
            dictionary containing the table description and a dictionary of
            column descriptions. Table names are prefixed with "<schema>."
            when `multi_schema` is True. Missing comments become "".
    """
    inspector = inspect(engine)
    result: Dict[str, Dict[str, Any]] = {}

    query = text(
        "SELECT table_comment FROM information_schema.tables WHERE"
        " table_schema = :schema AND table_name = :table"
    )

    def gen_schema_descriptions(schema: Optional[str] = None) -> None:
        table_names: List[str] = inspector.get_table_names(schema=schema)
        # information_schema stores bare table names per schema. Querying with
        # a "schema.table" key (as the result dict uses) would never match.
        table_schema = schema if schema is not None else engine.url.database

        with engine.connect() as conn:
            for table in table_names:
                table_name = table if schema is None else f"{schema}.{table}"

                table_comment = (
                    conn.execute(
                        query, {"schema": table_schema, "table": table}
                    ).scalar()
                    or ""
                )

                # "comment" may be present with value None, so coerce with `or`
                columns = {
                    col["name"]: col.get("comment") or ""
                    for col in inspector.get_columns(table, schema=schema)
                }

                result[table_name] = {
                    "description": table_comment,
                    "columns": columns,
                }

    if multi_schema:
        for schema in inspector.get_schema_names():
            gen_schema_descriptions(schema)
    else:
        gen_schema_descriptions()

    return result
|
128
|
+
|
129
|
+
|
130
|
+
def extract_default_descriptions(
    engine: Engine, multi_schema: bool = False
) -> Dict[str, Dict[str, Any]]:
    """Extracts default descriptions for tables and columns from a database.

    This method retrieves the table and column names from the given database
    and associates empty descriptions with them.

    Args:
        engine (Engine): SQLAlchemy engine connected to a database.
        multi_schema (bool): Generate descriptions for all schemas in the database.

    Returns:
        Dict[str, Dict[str, Any]]: A dictionary mapping table names to a
            dictionary containing an empty table description and a dictionary
            of empty column descriptions. Like the PostgreSQL/MySQL
            extractors, table names are prefixed with "<schema>." when
            `multi_schema` is True.
    """
    inspector = inspect(engine)
    result: Dict[str, Dict[str, Any]] = {}

    def gen_schema_descriptions(schema: Optional[str] = None) -> None:
        table_names: List[str] = inspector.get_table_names(schema=schema)

        for table in table_names:
            # Qualify with the schema so same-named tables in different
            # schemas do not overwrite each other.
            table_name = table if schema is None else f"{schema}.{table}"
            # Columns must be looked up in the same schema the table came from.
            columns = {
                col["name"]: "" for col in inspector.get_columns(table, schema=schema)
            }
            result[table_name] = {"description": "", "columns": columns}

    if multi_schema:
        for schema in inspector.get_schema_names():
            gen_schema_descriptions(schema)
    else:
        gen_schema_descriptions()

    return result
|
167
|
+
|
168
|
+
|
169
|
+
def extract_schema_descriptions(
    engine: Engine, multi_schema: bool = False
) -> Dict[str, Dict[str, Any]]:
    """
    Extracts the schema descriptions from the database connected to by the engine.

    The extractor is chosen from the engine's dialect name. "postgresql" and
    "mysql" have dedicated extractors that read stored comments. Every other
    dialect falls back to `extract_default_descriptions`, which yields empty
    descriptions.

    Args:
        engine (Engine): SQLAlchemy engine instance.
        multi_schema (bool): Generate descriptions for all schemas in the database.

    Returns:
        Dict[str, Dict[str, Any]]: A dictionary representation of table and column
            descriptions.
    """
    dialect_name = engine.dialect.name
    if dialect_name == "postgresql":
        extractor = extract_postgresql_descriptions
    elif dialect_name == "mysql":
        extractor = extract_mysql_descriptions
    else:
        extractor = extract_default_descriptions
    return extractor(engine, multi_schema=multi_schema)
|
@@ -0,0 +1,85 @@
|
|
1
|
+
from typing import Dict, List, Union
|
2
|
+
|
3
|
+
from langroid.exceptions import LangroidImportError
|
4
|
+
|
5
|
+
try:
|
6
|
+
from sqlalchemy import MetaData
|
7
|
+
except ImportError as e:
|
8
|
+
raise LangroidImportError(extra="sql", error=str(e))
|
9
|
+
|
10
|
+
|
11
|
+
def populate_metadata_with_schema_tools(
    metadata: MetaData | List[MetaData],
    info: Dict[str, Dict[str, Union[str, Dict[str, str]]]],
) -> Dict[str, Dict[str, Union[str, Dict[str, str]]]]:
    """
    Extracts information from an SQLAlchemy database's metadata and combines it
    with another dictionary with context descriptions.

    For every table in `metadata`, the result holds:
    - the table description taken from `info`. This is "" when `info` has no
      entry (or no description) for the table.
    - a "columns" mapping of column name to the column's SQL type rendered
      as a string.

    Args:
        metadata (MetaData | List[MetaData]): SQLAlchemy metadata object(s) of
            the database. With a list, tables from every item are merged.
        info (Dict[str, Dict[str, Any]]): A dictionary with table and column
            descriptions, keyed by table name.

    Returns:
        Dict[str, Dict[str, Any]]: A dictionary with table and context information.
    """
    db_info: Dict[str, Dict[str, Union[str, Dict[str, str]]]] = {}

    def add_tables(md: MetaData) -> None:
        # Create metadata dictionary entries with column datatypes
        for table_name, table in md.tables.items():
            # Tables absent from `info` get an empty description instead of
            # raising KeyError (`populate_metadata` also tolerates them).
            table_info = info.get(table_name, {})
            db_info[table_name] = {
                "description": table_info.get("description") or "",
                "columns": {},
            }

            for column in table.columns:
                # Populate columns with datatype
                db_info[table_name]["columns"][str(column.name)] = (  # type: ignore
                    str(column.type)
                )

    if isinstance(metadata, list):
        for md in metadata:
            add_tables(md)
    else:
        add_tables(metadata)

    return db_info
|
51
|
+
|
52
|
+
|
53
|
+
def populate_metadata(
    metadata: MetaData | List[MetaData],
    info: Dict[str, Dict[str, Union[str, Dict[str, str]]]],
) -> Dict[str, Dict[str, Union[str, Dict[str, str]]]]:
    """
    Build the table/column context for a database from its metadata plus
    additional descriptions.

    Each column entry initially holds the column's SQL type string, as
    produced by `populate_metadata_with_schema_tools`. When `info` has an
    entry for both the table and the column, the entry becomes
    "<type>; <description>".

    Args:
        metadata (MetaData): Metadata object from SQLAlchemy.
        info (Dict): Additional information for database tables and columns.

    Returns:
        Dict: A dictionary containing populated metadata information.
    """
    db_info: Dict[str, Dict[str, Union[str, Dict[str, str]]]] = (
        populate_metadata_with_schema_tools(metadata=metadata, info=info)
    )

    for table_name, table_entry in db_info.items():
        if table_name not in info:
            # No extra context was supplied for this table.
            continue
        column_map = table_entry["columns"]
        for column_name in column_map:  # type: ignore
            extra_columns = info[table_name]["columns"]
            if column_name in extra_columns:  # type: ignore
                # Append the description after the datatype.
                column_map[column_name] = (  # type: ignore
                    column_map[column_name]  # type: ignore
                    + "; "
                    + extra_columns[column_name]  # type: ignore
                )

    return db_info
|
@@ -0,0 +1,35 @@
|
|
1
|
+
# System prompt used when the full schema is placed in the prompt.
# The {dialect} and {schema_dict} placeholders are meant to be filled with
# str.format by the caller (not visible in this module).
DEFAULT_SYS_MSG = """You are a savvy data scientist/database administrator,
with expertise in answering questions by querying a {dialect} database.
You do not have access to the database 'db' directly, so you will need to use the
`run_query` tool/function-call to answer questions.

The below JSON schema maps the SQL database structure. It outlines tables, each
with a description and columns. Each table is identified by a key,
and holds a description and a dictionary of columns,
with column names as keys and their descriptions as values.
{schema_dict}

ONLY the tables and column names specified above should be used in
the generated queries.
You must be smart about using the right tables and columns based on the
English description. If you are thinking of using a table or column that
does not exist, you are probably on the wrong track, so you should try
your best to answer based on an existing table or column.
DO NOT assume any tables or columns other than those above."""

# System prompt used when the LLM discovers the schema through tools.
# Only the {dialect} placeholder is expected to be filled via str.format.
SCHEMA_TOOLS_SYS_MSG = """You are a savvy data scientist/database administrator,
with expertise in answering questions by interacting with a SQL database.

You will have to follow these steps to complete your job:
1) Use the `get_table_names` tool/function-call to get a list of all possibly
relevant table names.
2) Use the `get_table_schema` tool/function-call to get the schema of all
possibly relevant tables to identify possibly relevant columns. Only
call this method on potentially relevant tables.
3) Use the `get_column_descriptions` tool/function-call to get more information
about any relevant columns.
4) Write a {dialect} query and use the `run_query` tool to execute the SQL query
on the database to obtain the results.

Do not make assumptions about the database schema before using the tools.
Use the tool/functions to learn more about the database schema."""
|
@@ -0,0 +1,64 @@
|
|
1
|
+
from typing import List, Tuple
|
2
|
+
|
3
|
+
from langroid.agent.tool_message import ToolMessage
|
4
|
+
|
5
|
+
|
6
|
+
class RunQueryTool(ToolMessage):
    # Tool through which the LLM asks for a SQL `query` to be run on 'db'.
    # (Documented with comments rather than a class docstring so that no
    # extra text can leak into the generated tool schema.)
    request: str = "run_query"
    purpose: str = """
To run <query> on the database 'db' and
return the results to answer a question.
"""
    query: str

    @classmethod
    def examples(cls) -> List["ToolMessage" | Tuple[str, "ToolMessage"]]:
        """Few-shot examples: a bare tool call, then a (question, call) pair."""
        comedy_lookup = cls(
            query="SELECT * FROM movies WHERE genre = 'comedy'",
        )
        rating_question = "Find all movies with a rating of 5"
        rating_lookup = cls(
            query="SELECT * FROM movies WHERE rating = 5",
        )
        return [comedy_lookup, (rating_question, rating_lookup)]
|
27
|
+
|
28
|
+
|
29
|
+
class GetTableNamesTool(ToolMessage):
    # Argument-less tool: the LLM just asks for the list of table names in 'db'.
    request: str = "get_table_names"
    # Description shown to the LLM.
    purpose: str = """
To retrieve the names of all <tables> in the database 'db'.
"""
|
34
|
+
|
35
|
+
|
36
|
+
class GetTableSchemaTool(ToolMessage):
    # Tool through which the LLM requests the schema of the listed `tables`.
    request: str = "get_table_schema"
    purpose: str = """
To retrieve the schema of all provided <tables> in the database 'db'.
"""
    tables: List[str]

    @classmethod
    def example(cls) -> "GetTableSchemaTool":
        """Return a sample instance asking for three tables' schemas."""
        sample_tables = ["employees", "departments", "sales"]
        return cls(tables=sample_tables)
|
48
|
+
|
49
|
+
|
50
|
+
class GetColumnDescriptionsTool(ToolMessage):
    # Tool through which the LLM asks for descriptions of some `columns`
    # (a single comma-separated string) of one `table`.
    request: str = "get_column_descriptions"
    purpose: str = """
To retrieve the description of one or more <columns> from the respective
<table> in the database 'db'.
"""
    table: str
    columns: str

    @classmethod
    def example(cls) -> "GetColumnDescriptionsTool":
        """Return a sample instance requesting two columns of `employees`."""
        return cls(table="employees", columns="name, department_id")
|
@@ -0,0 +1,263 @@
|
|
1
|
+
"""
|
2
|
+
Agent that supports asking queries about a tabular dataset, internally
|
3
|
+
represented as a Pandas dataframe. The `TableChatAgent` is configured with a
|
4
|
+
dataset, which can be a Pandas df, file or URL. The delimiter/separator
|
5
|
+
is auto-detected. In response to a user query, the Agent's LLM generates a Pandas
|
6
|
+
expression (involving a dataframe `df`) to answer the query.
|
7
|
+
The expression is passed via the `pandas_eval` tool/function-call,
|
8
|
+
which is handled by the Agent's `pandas_eval` method. This method evaluates
|
9
|
+
the expression and returns the result as a string.
|
10
|
+
"""
|
11
|
+
|
12
|
+
import io
|
13
|
+
import logging
|
14
|
+
import sys
|
15
|
+
from typing import List, Optional, Tuple, no_type_check
|
16
|
+
|
17
|
+
import numpy as np
|
18
|
+
import pandas as pd
|
19
|
+
from rich.console import Console
|
20
|
+
|
21
|
+
import langroid as lr
|
22
|
+
from langroid.agent import ChatDocument
|
23
|
+
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
|
24
|
+
from langroid.agent.tool_message import ToolMessage
|
25
|
+
from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig
|
26
|
+
from langroid.parsing.table_loader import read_tabular_data
|
27
|
+
from langroid.prompts.prompts_config import PromptsConfig
|
28
|
+
from langroid.utils.constants import DONE, PASS
|
29
|
+
from langroid.vector_store.base import VectorStoreConfig
|
30
|
+
|
31
|
+
# Module-level logger and a shared rich console.
logger = logging.getLogger(name=__name__)

console = Console()

# Default system prompt for TableChatAgent. It is an f-string: {DONE} is
# substituted at import time, while the doubled braces leave a literal
# "{summary}" placeholder that is later filled with str.format.
DEFAULT_TABLE_CHAT_SYSTEM_MESSAGE = f"""
You are a savvy data scientist, with expertise in analyzing tabular datasets,
using Python and the Pandas library for dataframe manipulation.
Since you do not have access to the dataframe 'df', you
will need to use the `pandas_eval` tool/function-call to answer my questions.
Here is a summary of the dataframe:
{{summary}}
Do not assume any columns other than those shown.
In the expression you submit to the `pandas_eval` tool/function,
you are allowed to use the variable 'df' to refer to the dataframe.

Sometimes you may not be able to answer the question in a single call to `pandas_eval`,
so you can use a series of calls to `pandas_eval` to build up the answer.
For example you may first want to know something about the possible values in a column.

If you receive a null or other unexpected result, see if you have made an assumption
in your code, and try another way, or use `pandas_eval` to explore the dataframe
before submitting your final code.

Once you have the answer to the question, possibly after a few steps,
say {DONE} and PRESENT THE ANSWER TO ME; do not just say {DONE}.
If you receive an error message,
try using the `pandas_eval` tool/function again with the corrected code.

VERY IMPORTANT: When using the `pandas_eval` tool/function, DO NOT EXPLAIN ANYTHING,
SIMPLY USE THE TOOL, with the CODE.
Start by asking me what I want to know about the data.
"""
|
63
|
+
|
64
|
+
|
65
|
+
@no_type_check
def dataframe_summary(df: pd.DataFrame, max_listed_values: int = 10) -> str:
    """
    Generate a structured summary for a pandas DataFrame containing numerical
    and categorical values.

    The summary has four sections:
    - the quoted column names;
    - `describe()` statistics of the numeric columns, formatted to 2 decimals;
    - for each categorical (object / category / string) column, either its
      distinct values (when there are fewer than `max_listed_values`) or just
      their count;
    - the number of missing values per column.

    Args:
        df (pd.DataFrame): The input DataFrame to summarize.
        max_listed_values (int): Categorical columns with fewer distinct
            values than this have the values listed explicitly.

    Returns:
        str: A nicely structured and formatted summary string.
    """

    # Column names display
    col_names_str = (
        "COLUMN NAMES:\n" + " ".join([f"'{col}'" for col in df.columns]) + "\n\n"
    )

    # Numerical data summary.
    # Restrict describe() to numeric columns explicitly. With no numeric
    # columns, describe() would summarize object columns, whose string "top"
    # values cannot be formatted with ".2f". By default it also includes
    # datetime columns, whose Timestamps cannot be formatted that way either.
    # On a frame without columns, describe() raises.
    num_df = df.select_dtypes(include="number")
    if num_df.shape[1] > 0:
        # Format column-wise with Series.map: DataFrame.map needs pandas>=2.1.
        num_summary = num_df.describe().apply(lambda s: s.map("{:.2f}".format))
        num_body = num_summary.to_string()
    else:
        num_body = "(no numerical columns)"
    num_str = "Numerical Column Summary:\n" + num_body + "\n\n"

    # Categorical data summary
    cat_columns = df.select_dtypes(include=[np.object_, "category", "string"]).columns
    cat_summary_list = []

    for col in cat_columns:
        unique_values = df[col].unique()
        if len(unique_values) < max_listed_values:
            cat_summary_list.append(f"'{col}': {', '.join(map(str, unique_values))}")
        else:
            cat_summary_list.append(f"'{col}': {df[col].nunique()} unique values")

    cat_str = "Categorical Column Summary:\n" + "\n".join(cat_summary_list) + "\n\n"

    # Missing values summary
    nan_summary = df.isnull().sum().rename("missing_values").to_frame()
    nan_str = "Missing Values Column Summary:\n" + nan_summary.to_string() + "\n"

    # Combine the summaries into one structured string
    return col_names_str + num_str + cat_str + nan_str
|
108
|
+
|
109
|
+
|
110
|
+
class TableChatAgentConfig(ChatAgentConfig):
    """Configuration for `TableChatAgent`."""

    # Prompt containing a "{summary}" placeholder for the dataframe summary.
    system_message: str = DEFAULT_TABLE_CHAT_SYSTEM_MESSAGE
    user_message: Optional[str] = None
    cache: bool = True  # cache results
    debug: bool = False
    stream: bool = True  # allow streaming where needed
    # Source of the table: a file path, a URL, or an in-memory DataFrame.
    data: str | pd.DataFrame
    # Field separator for file/URL data; None leaves it to the loader.
    separator: Optional[str] = None
    vecdb: Optional[VectorStoreConfig] = None
    llm: OpenAIGPTConfig = OpenAIGPTConfig(
        type="openai",
        chat_model=OpenAIChatModel.GPT4,
        completion_model=OpenAIChatModel.GPT4,
    )
    prompts: PromptsConfig = PromptsConfig(
        max_tokens=1000,
    )
|
127
|
+
|
128
|
+
|
129
|
+
class PandasEvalTool(ToolMessage):
    """Tool/function to evaluate a pandas expression involving a dataframe `df`"""

    request: str = "pandas_eval"
    purpose: str = """
To eval a pandas <expression> on the dataframe 'df' and
return the results to answer a question.
IMPORTANT: the <expression> field should be a valid pandas expression.
"""
    expression: str

    @classmethod
    def examples(cls) -> List["ToolMessage" | Tuple[str, "ToolMessage"]]:
        """Few-shot examples: a peek at the data and a filtered aggregate."""
        sample_expressions = [
            "df.head()",
            "df[(df['gender'] == 'Male')]['income'].mean()",
        ]
        return [cls(expression=expr) for expr in sample_expressions]

    @classmethod
    def instructions(cls) -> str:
        """Extra guidance appended for the LLM about using this tool."""
        guidance = """
Use the `pandas_eval` tool/function to evaluate a pandas expression
involving the dataframe 'df' to answer the user's question.
"""
        return guidance
|
153
|
+
|
154
|
+
|
155
|
+
class TableChatAgent(ChatAgent):
    """
    Agent for answering questions about a tabular dataset held as a pandas
    DataFrame (`self.df`). The LLM answers by emitting `pandas_eval`
    tool-calls, which the `pandas_eval` method evaluates against the dataframe.
    """

    # Set when the LLM has used `pandas_eval`. It is reset when the user sends
    # a new non-empty message, or when the fallback below emits DONE.
    sent_expression: bool = False

    def __init__(self, config: TableChatAgentConfig):
        """
        Load the configured data, normalize its column names, and build the
        system message from a summary of the dataframe.

        Args:
            config (TableChatAgentConfig): Agent configuration. Its
                `system_message` is formatted in place with the dataframe
                summary (the `{summary}` placeholder). A DataFrame passed as
                `config.data` has its column labels rewritten in place.
        """
        if isinstance(config.data, pd.DataFrame):
            df = config.data
        else:
            df = read_tabular_data(config.data, config.separator)

        df.columns = self._normalize_column_names(df.columns)

        self.df = df
        summary = dataframe_summary(df)
        config.system_message = config.system_message.format(summary=summary)

        super().__init__(config)
        self.config: TableChatAgentConfig = config

        logger.info(
            f"""TableChatAgent initialized with dataframe of shape {self.df.shape}
and columns:
{self.df.columns}
"""
        )
        # enable the agent to use and handle the PandasEvalTool
        self.enable_message(PandasEvalTool)

    @staticmethod
    def _normalize_column_names(columns: pd.Index) -> List[object]:
        """Strip string column labels and replace runs of spaces with "_".

        Non-string labels (e.g. integer positions) are kept unchanged. Using
        the `.str` accessor instead raises on a non-string Index, and it
        silently turns non-string labels in a mixed Index into NaN.
        """
        import re

        return [
            re.sub(" +", "_", col.strip()) if isinstance(col, str) else col
            for col in columns
        ]

    def user_response(
        self,
        msg: Optional[str | ChatDocument] = None,
    ) -> Optional[ChatDocument]:
        """Get the user's response; a non-empty one starts a fresh question."""
        response = super().user_response(msg)
        if response is not None and response.content != "":
            self.sent_expression = False
        return response

    def pandas_eval(self, msg: PandasEvalTool) -> str:
        """
        Handle a PandasEvalTool message by evaluating the `expression` field
        and returning the result.

        Anything printed during evaluation is captured and prepended to the
        result. An exception is converted into an "ERROR: ..." string rather
        than raised.

        Args:
            msg (PandasEvalTool): The tool-message to handle.

        Returns:
            str: The result of running the code along with any print output,
                or "No result" if both are empty.
        """
        from contextlib import redirect_stdout

        self.sent_expression = True
        exprn = msg.expression
        local_vars = {"df": self.df}
        # Create a string-based I/O stream
        code_out = io.StringIO()

        # NOTE(security): `exprn` is LLM-generated, i.e. untrusted input.
        # pd.eval accepts only a restricted expression syntax, but treat its
        # use here with care.
        # redirect_stdout restores whatever sys.stdout was before (not
        # sys.__stdout__), even if evaluation is interrupted.
        with redirect_stdout(code_out):
            try:
                eval_result = pd.eval(exprn, local_dict=local_vars)
            except Exception as e:
                eval_result = f"ERROR: {type(e)}: {e}"

        if eval_result is None:
            eval_result = ""

        # If `df` was rebound in local_vars, keep the new object.
        self.df = local_vars["df"]

        # Get the resulting string from the I/O stream
        print_result = code_out.getvalue() or ""
        sep = "\n" if print_result else ""
        # Combine the print and eval results
        result = f"{print_result}{sep}{eval_result}"
        if result == "":
            result = "No result"
        return result

    def handle_message_fallback(
        self, msg: str | ChatDocument
    ) -> str | ChatDocument | None:
        """Handle various LLM deviations"""
        if isinstance(msg, ChatDocument) and msg.metadata.sender == lr.Entity.LLM:
            if msg.content.strip() == DONE and self.sent_expression:
                # LLM sent an expression (i.e. used the `pandas_eval` tool)
                # but upon receiving the results, simply said DONE without
                # narrating the result as instructed.
                return """
You forgot to PRESENT the answer to the user's query
based on the results from `pandas_eval` tool.
"""
            if self.sent_expression:
                # LLM forgot to say DONE
                self.sent_expression = False
                return DONE + " " + PASS
            else:
                # LLM forgot to use the `pandas_eval` tool
                return """
You forgot to use the `pandas_eval` tool/function
to find the answer.
Try again using the `pandas_eval` tool/function.
"""
        return None
|