langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +95 -0
- langroid/agent/__init__.py +40 -0
- langroid/agent/base.py +222 -91
- langroid/agent/batch.py +264 -0
- langroid/agent/callbacks/chainlit.py +608 -0
- langroid/agent/chat_agent.py +247 -101
- langroid/agent/chat_document.py +41 -4
- langroid/agent/openai_assistant.py +842 -0
- langroid/agent/special/__init__.py +50 -0
- langroid/agent/special/doc_chat_agent.py +837 -141
- langroid/agent/special/lance_doc_chat_agent.py +258 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +136 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
- langroid/agent/special/lance_tools.py +44 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
- langroid/agent/special/neo4j/utils/__init__.py +0 -0
- langroid/agent/special/neo4j/utils/system_message.py +46 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +32 -198
- langroid/agent/special/sql/__init__.py +11 -0
- langroid/agent/special/sql/sql_chat_agent.py +47 -23
- langroid/agent/special/sql/utils/__init__.py +22 -0
- langroid/agent/special/sql/utils/description_extractors.py +95 -46
- langroid/agent/special/sql/utils/populate_metadata.py +28 -21
- langroid/agent/special/table_chat_agent.py +43 -9
- langroid/agent/task.py +475 -122
- langroid/agent/tool_message.py +75 -13
- langroid/agent/tools/__init__.py +13 -0
- langroid/agent/tools/duckduckgo_search_tool.py +66 -0
- langroid/agent/tools/google_search_tool.py +11 -0
- langroid/agent/tools/metaphor_search_tool.py +67 -0
- langroid/agent/tools/recipient_tool.py +16 -29
- langroid/agent/tools/run_python_code.py +60 -0
- langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
- langroid/agent/tools/segment_extract_tool.py +36 -0
- langroid/cachedb/__init__.py +9 -0
- langroid/cachedb/base.py +22 -2
- langroid/cachedb/momento_cachedb.py +26 -2
- langroid/cachedb/redis_cachedb.py +78 -11
- langroid/embedding_models/__init__.py +34 -0
- langroid/embedding_models/base.py +21 -2
- langroid/embedding_models/models.py +120 -18
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/language_models/__init__.py +45 -0
- langroid/language_models/azure_openai.py +80 -27
- langroid/language_models/base.py +117 -12
- langroid/language_models/config.py +5 -0
- langroid/language_models/openai_assistants.py +3 -0
- langroid/language_models/openai_gpt.py +558 -174
- langroid/language_models/prompt_formatter/__init__.py +15 -0
- langroid/language_models/prompt_formatter/base.py +4 -6
- langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
- langroid/language_models/utils.py +18 -21
- langroid/mytypes.py +25 -8
- langroid/parsing/__init__.py +46 -0
- langroid/parsing/document_parser.py +260 -63
- langroid/parsing/image_text.py +32 -0
- langroid/parsing/parse_json.py +143 -0
- langroid/parsing/parser.py +122 -59
- langroid/parsing/repo_loader.py +114 -52
- langroid/parsing/search.py +68 -63
- langroid/parsing/spider.py +3 -2
- langroid/parsing/table_loader.py +44 -0
- langroid/parsing/url_loader.py +59 -11
- langroid/parsing/urls.py +85 -37
- langroid/parsing/utils.py +298 -4
- langroid/parsing/web_search.py +73 -0
- langroid/prompts/__init__.py +11 -0
- langroid/prompts/chat-gpt4-system-prompt.md +68 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/utils/__init__.py +17 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +36 -5
- langroid/utils/constants.py +4 -0
- langroid/utils/globals.py +2 -2
- langroid/utils/logging.py +2 -5
- langroid/utils/output/__init__.py +21 -0
- langroid/utils/output/printing.py +47 -1
- langroid/utils/output/status.py +33 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +616 -2
- langroid/utils/system.py +98 -0
- langroid/vector_store/__init__.py +40 -0
- langroid/vector_store/base.py +203 -6
- langroid/vector_store/chromadb.py +59 -32
- langroid/vector_store/lancedb.py +463 -0
- langroid/vector_store/meilisearch.py +10 -7
- langroid/vector_store/momento.py +262 -0
- langroid/vector_store/qdrantdb.py +104 -22
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
- langroid-0.1.219.dist-info/RECORD +127 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
- langroid/agent/special/recipient_validator_agent.py +0 -157
- langroid/parsing/json.py +0 -64
- langroid/utils/web/selenium_login.py +0 -36
- langroid-0.1.85.dist-info/RECORD +0 -94
- /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
@@ -1,10 +1,13 @@
|
|
1
|
-
from typing import Any, Dict, List
|
1
|
+
from typing import Any, Dict, List, Optional
|
2
2
|
|
3
3
|
from sqlalchemy import inspect, text
|
4
4
|
from sqlalchemy.engine import Engine
|
5
5
|
|
6
6
|
|
7
|
-
def extract_postgresql_descriptions(
|
7
|
+
def extract_postgresql_descriptions(
|
8
|
+
engine: Engine,
|
9
|
+
multi_schema: bool = False,
|
10
|
+
) -> Dict[str, Dict[str, Any]]:
|
8
11
|
"""
|
9
12
|
Extracts descriptions for tables and columns from a PostgreSQL database.
|
10
13
|
|
@@ -13,6 +16,7 @@ def extract_postgresql_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]
|
|
13
16
|
|
14
17
|
Args:
|
15
18
|
engine (Engine): SQLAlchemy engine connected to a PostgreSQL database.
|
19
|
+
multi_schema (bool): Generate descriptions for all schemas in the database.
|
16
20
|
|
17
21
|
Returns:
|
18
22
|
Dict[str, Dict[str, Any]]: A dictionary mapping table names to a
|
@@ -20,36 +24,53 @@ def extract_postgresql_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]
|
|
20
24
|
column descriptions.
|
21
25
|
"""
|
22
26
|
inspector = inspect(engine)
|
23
|
-
table_names: List[str] = inspector.get_table_names()
|
24
|
-
|
25
27
|
result: Dict[str, Dict[str, Any]] = {}
|
26
28
|
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
29
|
+
def gen_schema_descriptions(schema: Optional[str] = None) -> None:
|
30
|
+
table_names: List[str] = inspector.get_table_names(schema=schema)
|
31
|
+
with engine.connect() as conn:
|
32
|
+
for table in table_names:
|
33
|
+
if schema is None:
|
34
|
+
table_name = table
|
35
|
+
else:
|
36
|
+
table_name = f"{schema}.{table}"
|
35
37
|
|
36
|
-
|
37
|
-
col_data = inspector.get_columns(table)
|
38
|
-
for idx, col in enumerate(col_data, start=1):
|
39
|
-
col_comment = (
|
38
|
+
table_comment = (
|
40
39
|
conn.execute(
|
41
|
-
text(f"SELECT
|
40
|
+
text(f"SELECT obj_description('{table_name}'::regclass)")
|
42
41
|
).scalar()
|
43
42
|
or ""
|
44
43
|
)
|
45
|
-
columns[col["name"]] = col_comment
|
46
44
|
|
47
|
-
|
45
|
+
columns = {}
|
46
|
+
col_data = inspector.get_columns(table, schema=schema)
|
47
|
+
for idx, col in enumerate(col_data, start=1):
|
48
|
+
col_comment = (
|
49
|
+
conn.execute(
|
50
|
+
text(
|
51
|
+
f"SELECT col_description('{table_name}'::regclass, "
|
52
|
+
f"{idx})"
|
53
|
+
)
|
54
|
+
).scalar()
|
55
|
+
or ""
|
56
|
+
)
|
57
|
+
columns[col["name"]] = col_comment
|
58
|
+
|
59
|
+
result[table_name] = {"description": table_comment, "columns": columns}
|
60
|
+
|
61
|
+
if multi_schema:
|
62
|
+
for schema in inspector.get_schema_names():
|
63
|
+
gen_schema_descriptions(schema)
|
64
|
+
else:
|
65
|
+
gen_schema_descriptions()
|
48
66
|
|
49
67
|
return result
|
50
68
|
|
51
69
|
|
52
|
-
def extract_mysql_descriptions(
|
70
|
+
def extract_mysql_descriptions(
|
71
|
+
engine: Engine,
|
72
|
+
multi_schema: bool = False,
|
73
|
+
) -> Dict[str, Dict[str, Any]]:
|
53
74
|
"""Extracts descriptions for tables and columns from a MySQL database.
|
54
75
|
|
55
76
|
This method retrieves the descriptions of tables and their columns
|
@@ -57,6 +78,7 @@ def extract_mysql_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
|
|
57
78
|
|
58
79
|
Args:
|
59
80
|
engine (Engine): SQLAlchemy engine connected to a MySQL database.
|
81
|
+
multi_schema (bool): Generate descriptions for all schemas in the database.
|
60
82
|
|
61
83
|
Returns:
|
62
84
|
Dict[str, Dict[str, Any]]: A dictionary mapping table names to a
|
@@ -64,31 +86,45 @@ def extract_mysql_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
|
|
64
86
|
column descriptions.
|
65
87
|
"""
|
66
88
|
inspector = inspect(engine)
|
67
|
-
table_names: List[str] = inspector.get_table_names()
|
68
|
-
|
69
89
|
result: Dict[str, Dict[str, Any]] = {}
|
70
90
|
|
71
|
-
|
72
|
-
|
73
|
-
query = text(
|
74
|
-
"SELECT table_comment FROM information_schema.tables WHERE"
|
75
|
-
" table_schema = :schema AND table_name = :table"
|
76
|
-
)
|
77
|
-
table_result = conn.execute(
|
78
|
-
query, {"schema": engine.url.database, "table": table}
|
79
|
-
)
|
80
|
-
table_comment = table_result.scalar() or ""
|
91
|
+
def gen_schema_descriptions(schema: Optional[str] = None) -> None:
|
92
|
+
table_names: List[str] = inspector.get_table_names(schema=schema)
|
81
93
|
|
82
|
-
|
83
|
-
for
|
84
|
-
|
94
|
+
with engine.connect() as conn:
|
95
|
+
for table in table_names:
|
96
|
+
if schema is None:
|
97
|
+
table_name = table
|
98
|
+
else:
|
99
|
+
table_name = f"{schema}.{table}"
|
100
|
+
|
101
|
+
query = text(
|
102
|
+
"SELECT table_comment FROM information_schema.tables WHERE"
|
103
|
+
" table_schema = :schema AND table_name = :table"
|
104
|
+
)
|
105
|
+
table_result = conn.execute(
|
106
|
+
query, {"schema": engine.url.database, "table": table_name}
|
107
|
+
)
|
108
|
+
table_comment = table_result.scalar() or ""
|
109
|
+
|
110
|
+
columns = {}
|
111
|
+
for col in inspector.get_columns(table, schema=schema):
|
112
|
+
columns[col["name"]] = col.get("comment", "")
|
113
|
+
|
114
|
+
result[table_name] = {"description": table_comment, "columns": columns}
|
85
115
|
|
86
|
-
|
116
|
+
if multi_schema:
|
117
|
+
for schema in inspector.get_schema_names():
|
118
|
+
gen_schema_descriptions(schema)
|
119
|
+
else:
|
120
|
+
gen_schema_descriptions()
|
87
121
|
|
88
122
|
return result
|
89
123
|
|
90
124
|
|
91
|
-
def extract_default_descriptions(
|
125
|
+
def extract_default_descriptions(
|
126
|
+
engine: Engine, multi_schema: bool = False
|
127
|
+
) -> Dict[str, Dict[str, Any]]:
|
92
128
|
"""Extracts default descriptions for tables and columns from a database.
|
93
129
|
|
94
130
|
This method retrieves the table and column names from the given database
|
@@ -96,6 +132,7 @@ def extract_default_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
|
|
96
132
|
|
97
133
|
Args:
|
98
134
|
engine (Engine): SQLAlchemy engine connected to a database.
|
135
|
+
multi_schema (bool): Generate descriptions for all schemas in the database.
|
99
136
|
|
100
137
|
Returns:
|
101
138
|
Dict[str, Dict[str, Any]]: A dictionary mapping table names to a
|
@@ -103,26 +140,36 @@ def extract_default_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
|
|
103
140
|
empty column descriptions.
|
104
141
|
"""
|
105
142
|
inspector = inspect(engine)
|
106
|
-
table_names: List[str] = inspector.get_table_names()
|
107
|
-
|
108
143
|
result: Dict[str, Dict[str, Any]] = {}
|
109
144
|
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
145
|
+
def gen_schema_descriptions(schema: Optional[str] = None) -> None:
|
146
|
+
table_names: List[str] = inspector.get_table_names(schema=schema)
|
147
|
+
|
148
|
+
for table in table_names:
|
149
|
+
columns = {}
|
150
|
+
for col in inspector.get_columns(table):
|
151
|
+
columns[col["name"]] = ""
|
152
|
+
|
153
|
+
result[table] = {"description": "", "columns": columns}
|
114
154
|
|
115
|
-
|
155
|
+
if multi_schema:
|
156
|
+
for schema in inspector.get_schema_names():
|
157
|
+
gen_schema_descriptions(schema)
|
158
|
+
else:
|
159
|
+
gen_schema_descriptions()
|
116
160
|
|
117
161
|
return result
|
118
162
|
|
119
163
|
|
120
|
-
def extract_schema_descriptions(
|
164
|
+
def extract_schema_descriptions(
|
165
|
+
engine: Engine, multi_schema: bool = False
|
166
|
+
) -> Dict[str, Dict[str, Any]]:
|
121
167
|
"""
|
122
168
|
Extracts the schema descriptions from the database connected to by the engine.
|
123
169
|
|
124
170
|
Args:
|
125
171
|
engine (Engine): SQLAlchemy engine instance.
|
172
|
+
multi_schema (bool): Generate descriptions for all schemas in the database.
|
126
173
|
|
127
174
|
Returns:
|
128
175
|
Dict[str, Dict[str, Any]]: A dictionary representation of table and column
|
@@ -133,4 +180,6 @@ def extract_schema_descriptions(engine: Engine) -> Dict[str, Dict[str, Any]]:
|
|
133
180
|
"postgresql": extract_postgresql_descriptions,
|
134
181
|
"mysql": extract_mysql_descriptions,
|
135
182
|
}
|
136
|
-
return extractors.get(engine.dialect.name, extract_default_descriptions)(
|
183
|
+
return extractors.get(engine.dialect.name, extract_default_descriptions)(
|
184
|
+
engine, multi_schema=multi_schema
|
185
|
+
)
|
@@ -1,10 +1,11 @@
|
|
1
|
-
from typing import Dict, Union
|
1
|
+
from typing import Dict, List, Union
|
2
2
|
|
3
3
|
from sqlalchemy import MetaData
|
4
4
|
|
5
5
|
|
6
6
|
def populate_metadata_with_schema_tools(
|
7
|
-
metadata: MetaData
|
7
|
+
metadata: MetaData | List[MetaData],
|
8
|
+
info: Dict[str, Dict[str, Union[str, Dict[str, str]]]],
|
8
9
|
) -> Dict[str, Dict[str, Union[str, Dict[str, str]]]]:
|
9
10
|
"""
|
10
11
|
Extracts information from an SQLAlchemy database's metadata and combines it
|
@@ -18,28 +19,35 @@ def populate_metadata_with_schema_tools(
|
|
18
19
|
Returns:
|
19
20
|
Dict[str, Dict[str, Any]]: A dictionary with table and context information.
|
20
21
|
"""
|
21
|
-
|
22
22
|
db_info: Dict[str, Dict[str, Union[str, Dict[str, str]]]] = {}
|
23
23
|
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
str(column.type
|
36
|
-
|
24
|
+
def populate_metadata(md: MetaData) -> None:
|
25
|
+
# Create empty metadata dictionary with column datatypes
|
26
|
+
for table_name, table in md.tables.items():
|
27
|
+
# Populate tables with empty descriptions
|
28
|
+
db_info[table_name] = {
|
29
|
+
"description": info[table_name]["description"] or "",
|
30
|
+
"columns": {},
|
31
|
+
}
|
32
|
+
|
33
|
+
for column in table.columns:
|
34
|
+
# Populate columns with datatype
|
35
|
+
db_info[table_name]["columns"][str(column.name)] = ( # type: ignore
|
36
|
+
str(column.type)
|
37
|
+
)
|
38
|
+
|
39
|
+
if isinstance(metadata, list):
|
40
|
+
for md in metadata:
|
41
|
+
populate_metadata(md)
|
42
|
+
else:
|
43
|
+
populate_metadata(metadata)
|
37
44
|
|
38
45
|
return db_info
|
39
46
|
|
40
47
|
|
41
48
|
def populate_metadata(
|
42
|
-
metadata: MetaData
|
49
|
+
metadata: MetaData | List[MetaData],
|
50
|
+
info: Dict[str, Dict[str, Union[str, Dict[str, str]]]],
|
43
51
|
) -> Dict[str, Dict[str, Union[str, Dict[str, str]]]]:
|
44
52
|
"""
|
45
53
|
Populate metadata based on the provided database metadata and additional info.
|
@@ -51,11 +59,10 @@ def populate_metadata(
|
|
51
59
|
Returns:
|
52
60
|
Dict: A dictionary containing populated metadata information.
|
53
61
|
"""
|
54
|
-
|
55
62
|
# Fetch basic metadata info using available tools
|
56
|
-
db_info: Dict[
|
57
|
-
|
58
|
-
|
63
|
+
db_info: Dict[str, Dict[str, Union[str, Dict[str, str]]]] = (
|
64
|
+
populate_metadata_with_schema_tools(metadata=metadata, info=info)
|
65
|
+
)
|
59
66
|
|
60
67
|
# Iterate over tables to update column metadata
|
61
68
|
for table_name in db_info.keys():
|
@@ -7,33 +7,37 @@ code to answer the query. The code is passed via the `run_code` tool/function-ca
|
|
7
7
|
which is handled by the Agent's `run_code` method. This method executes/evaluates
|
8
8
|
the code and returns the result as a string.
|
9
9
|
"""
|
10
|
+
|
10
11
|
import io
|
11
12
|
import logging
|
12
13
|
import sys
|
13
|
-
from typing import List, no_type_check
|
14
|
+
from typing import List, Optional, no_type_check
|
14
15
|
|
15
16
|
import numpy as np
|
16
17
|
import pandas as pd
|
17
18
|
from rich.console import Console
|
18
19
|
|
20
|
+
import langroid as lr
|
21
|
+
from langroid.agent import ChatDocument
|
19
22
|
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
|
20
23
|
from langroid.agent.tool_message import ToolMessage
|
21
24
|
from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig
|
22
25
|
from langroid.parsing.table_loader import read_tabular_data
|
23
26
|
from langroid.prompts.prompts_config import PromptsConfig
|
27
|
+
from langroid.utils.constants import DONE
|
24
28
|
from langroid.vector_store.base import VectorStoreConfig
|
25
29
|
|
26
30
|
logger = logging.getLogger(__name__)
|
27
31
|
|
28
32
|
console = Console()
|
29
33
|
|
30
|
-
DEFAULT_TABLE_CHAT_SYSTEM_MESSAGE = """
|
34
|
+
DEFAULT_TABLE_CHAT_SYSTEM_MESSAGE = f"""
|
31
35
|
You are a savvy data scientist, with expertise in analyzing tabular datasets,
|
32
36
|
using Python and the Pandas library for dataframe manipulation.
|
33
37
|
Since you do not have access to the dataframe 'df', you
|
34
|
-
will need to use the `run_code` tool/function-call to answer
|
38
|
+
will need to use the `run_code` tool/function-call to answer my questions.
|
35
39
|
Here is a summary of the dataframe:
|
36
|
-
{summary}
|
40
|
+
{{summary}}
|
37
41
|
Do not assume any columns other than those shown.
|
38
42
|
In the code you submit to the `run_code` tool/function,
|
39
43
|
do not forget to include any necessary imports, such as `import pandas as pd`.
|
@@ -45,10 +49,12 @@ If you receive a null or other unexpected result, see if you have made an assump
|
|
45
49
|
in your code, and try another way, or use `run_code` to explore the dataframe
|
46
50
|
before submitting your final code.
|
47
51
|
|
48
|
-
Once you have the answer to the question,
|
49
|
-
If you receive an error message,
|
50
|
-
again with the corrected code.
|
52
|
+
Once you have the answer to the question, possibly after a few steps,
|
53
|
+
say {DONE} and show me the answer. If you receive an error message,
|
54
|
+
try using the `run_code` tool/function again with the corrected code.
|
51
55
|
|
56
|
+
VERY IMPORTANT: When using the `run_code` tool/function, DO NOT EXPLAIN ANYTHING,
|
57
|
+
SIMPLY USE THE TOOL, with the CODE.
|
52
58
|
Start by asking me what I want to know about the data.
|
53
59
|
"""
|
54
60
|
|
@@ -72,7 +78,7 @@ def dataframe_summary(df: pd.DataFrame) -> str:
|
|
72
78
|
)
|
73
79
|
|
74
80
|
# Numerical data summary
|
75
|
-
num_summary = df.describe().
|
81
|
+
num_summary = df.describe().map(lambda x: "{:.2f}".format(x))
|
76
82
|
num_str = "Numerical Column Summary:\n" + num_summary.to_string() + "\n\n"
|
77
83
|
|
78
84
|
# Categorical data summary
|
@@ -101,7 +107,6 @@ def dataframe_summary(df: pd.DataFrame) -> str:
|
|
101
107
|
class TableChatAgentConfig(ChatAgentConfig):
|
102
108
|
system_message: str = DEFAULT_TABLE_CHAT_SYSTEM_MESSAGE
|
103
109
|
user_message: None | str = None
|
104
|
-
max_context_tokens: int = 1000
|
105
110
|
cache: bool = True # cache results
|
106
111
|
debug: bool = False
|
107
112
|
stream: bool = True # allow streaming where needed
|
@@ -125,6 +130,7 @@ class RunCodeTool(ToolMessage):
|
|
125
130
|
purpose: str = """
|
126
131
|
To run <code> on the dataframe 'df' and
|
127
132
|
return the results to answer a question.
|
133
|
+
IMPORTANT: ALL the code should be in the <code> field.
|
128
134
|
"""
|
129
135
|
code: str
|
130
136
|
|
@@ -141,6 +147,8 @@ class TableChatAgent(ChatAgent):
|
|
141
147
|
Agent for chatting with a collection of documents.
|
142
148
|
"""
|
143
149
|
|
150
|
+
sent_code: bool = False
|
151
|
+
|
144
152
|
def __init__(self, config: TableChatAgentConfig):
|
145
153
|
if isinstance(config.data, pd.DataFrame):
|
146
154
|
df = config.data
|
@@ -165,6 +173,15 @@ class TableChatAgent(ChatAgent):
|
|
165
173
|
# enable the agent to use and handle the RunCodeTool
|
166
174
|
self.enable_message(RunCodeTool)
|
167
175
|
|
176
|
+
def user_response(
|
177
|
+
self,
|
178
|
+
msg: Optional[str | ChatDocument] = None,
|
179
|
+
) -> Optional[ChatDocument]:
|
180
|
+
response = super().user_response(msg)
|
181
|
+
if response is not None and response.content != "":
|
182
|
+
self.sent_code = False
|
183
|
+
return response
|
184
|
+
|
168
185
|
def run_code(self, msg: RunCodeTool) -> str:
|
169
186
|
"""
|
170
187
|
Handle a RunCodeTool message by running the code and returning the result.
|
@@ -174,6 +191,7 @@ class TableChatAgent(ChatAgent):
|
|
174
191
|
Returns:
|
175
192
|
str: The result of running the code along with any print output.
|
176
193
|
"""
|
194
|
+
self.sent_code = True
|
177
195
|
code = msg.code
|
178
196
|
# Create a dictionary that maps 'df' to the actual DataFrame
|
179
197
|
local_vars = {"df": self.df}
|
@@ -220,3 +238,19 @@ class TableChatAgent(ChatAgent):
|
|
220
238
|
result = "No result"
|
221
239
|
# Return the result
|
222
240
|
return result
|
241
|
+
|
242
|
+
def handle_message_fallback(
|
243
|
+
self, msg: str | ChatDocument
|
244
|
+
) -> str | ChatDocument | None:
|
245
|
+
"""Handle scenario where LLM forgets to say DONE or forgets to use run_code"""
|
246
|
+
if isinstance(msg, ChatDocument) and msg.metadata.sender == lr.Entity.LLM:
|
247
|
+
if self.sent_code:
|
248
|
+
return DONE
|
249
|
+
else:
|
250
|
+
return """
|
251
|
+
You forgot to use the `run_code` tool/function to find the answer.
|
252
|
+
Try again using the `run_code` tool/function.
|
253
|
+
Remember that ALL your code, including imports,
|
254
|
+
should be in the `code` field.
|
255
|
+
"""
|
256
|
+
return None
|