plot-agent 0.3.1__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {plot_agent-0.3.1 → plot_agent-0.4.0}/PKG-INFO +58 -3
- {plot_agent-0.3.1 → plot_agent-0.4.0}/README.md +57 -2
- plot_agent-0.4.0/plot_agent/agent.py +407 -0
- {plot_agent-0.3.1 → plot_agent-0.4.0}/plot_agent/execution.py +17 -5
- {plot_agent-0.3.1 → plot_agent-0.4.0}/plot_agent.egg-info/PKG-INFO +58 -3
- {plot_agent-0.3.1 → plot_agent-0.4.0}/pyproject.toml +1 -1
- plot_agent-0.3.1/plot_agent/agent.py +0 -269
- {plot_agent-0.3.1 → plot_agent-0.4.0}/LICENSE +0 -0
- {plot_agent-0.3.1 → plot_agent-0.4.0}/plot_agent/__init__.py +0 -0
- {plot_agent-0.3.1 → plot_agent-0.4.0}/plot_agent/models.py +0 -0
- {plot_agent-0.3.1 → plot_agent-0.4.0}/plot_agent/prompt.py +0 -0
- {plot_agent-0.3.1 → plot_agent-0.4.0}/plot_agent.egg-info/SOURCES.txt +0 -0
- {plot_agent-0.3.1 → plot_agent-0.4.0}/plot_agent.egg-info/dependency_links.txt +0 -0
- {plot_agent-0.3.1 → plot_agent-0.4.0}/plot_agent.egg-info/top_level.txt +0 -0
- {plot_agent-0.3.1 → plot_agent-0.4.0}/setup.cfg +0 -0

--- plot_agent-0.3.1/PKG-INFO
+++ plot_agent-0.4.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: plot-agent
-Version: 0.3.1
+Version: 0.4.0
 Summary: An AI-powered data visualization assistant using Plotly
 Author-email: andrewm4894 <andrewm4894@gmail.com>
 Project-URL: Homepage, https://github.com/andrewm4894/plot-agent
@@ -19,6 +19,8 @@ Dynamic: license-file
 
 An AI-powered data visualization assistant that helps users create Plotly visualizations in Python.
 
+Built on LangGraph with tool-calling to reliably execute generated Plotly code in a sandbox and keep the current `fig` in sync.
+
 ## Installation
 
 You can install the package using pip:
@@ -37,7 +39,7 @@ Here's a simple minimal example of how to use Plot Agent:
 import pandas as pd
 from plot_agent.agent import PlotAgent
 
-# ensure OPENAI_API_KEY is set
+# ensure OPENAI_API_KEY is set (env or .env); optional debug via PLOT_AGENT_DEBUG=1
 
 # Create a sample dataframe
 df = pd.DataFrame({
@@ -92,19 +94,72 @@ fig.update_layout(
 )
 ```
 
+## How it works
+
+```mermaid
+flowchart TD
+    A[User message] --> B{LangGraph ReAct Agent}
+    subgraph Tools
+        T1[execute_plotly_code<br/>- runs code in sandbox<br/>- returns success/fig/error]
+        T2[does_fig_exist]
+        T3[view_generated_code]
+    end
+    B -- tool call --> T1
+    T1 -- result --> B
+    B -- optional --> T2
+    B -- optional --> T3
+    B --> C[AI response]
+    C --> D{Agent wrapper}
+    D -- persist messages --> B
+    D -- extract code blocks --> E[Sandbox execution]
+    E --> F[fig]
+    F --> G[get_figure]
+```
+
+- The LangGraph agent plans and decides when to call tools.
+- The wrapper persists full graph messages between turns and executes any returned code blocks to keep `fig` updated.
+- A safe execution environment runs code with an allowlist and a main-thread-only timeout.
+
 ## Features
 
 - AI-powered visualization generation
 - Support for various Plotly chart types
 - Automatic data preprocessing
 - Interactive visualization capabilities
--
+- LangGraph-based tool calling and control flow
+- Debug logging via `PlotAgent(debug=True)` or `PLOT_AGENT_DEBUG=1`
 
 ## Requirements
 
 - Python 3.8 or higher
 - Dependencies are automatically installed with the package
 
+## Development
+
+- Run unit tests:
+
+```bash
+make test
+```
+
+- Execute all example notebooks:
+
+```bash
+make run-examples
+```
+
+- Execute with debug logs enabled:
+
+```bash
+make run-examples-debug
+```
+
+- Quick CLI repro that prints evolving code each step:
+
+```bash
+make run-example-script
+```
+
 ## License
 
 This project is licensed under the MIT License - see the LICENSE file for details.
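
For orientation alongside the README changes above, here is a minimal end-to-end sketch of the documented flow (set a dataframe, ask for a plot, retrieve `fig`). The dataframe, column names, and request text are illustrative placeholders, not taken from the package:

```python
import pandas as pd

from plot_agent.agent import PlotAgent

# Illustrative data; any non-empty dataframe works.
df = pd.DataFrame({"x": [1, 2, 3, 4], "y": [10, 20, 15, 30]})

# Requires OPENAI_API_KEY in the environment or a .env file.
agent = PlotAgent()
agent.set_df(df)

# The agent generates Plotly code, runs it in the sandboxed environment,
# and keeps the resulting `fig` in sync.
response = agent.process_message("Create a line chart of y over x")
print(response)

fig = agent.get_figure()
if fig is not None:
    fig.show()
```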

--- plot_agent-0.3.1/README.md
+++ plot_agent-0.4.0/README.md
@@ -5,6 +5,8 @@
 
 An AI-powered data visualization assistant that helps users create Plotly visualizations in Python.
 
+Built on LangGraph with tool-calling to reliably execute generated Plotly code in a sandbox and keep the current `fig` in sync.
+
 ## Installation
 
 You can install the package using pip:
@@ -23,7 +25,7 @@ Here's a simple minimal example of how to use Plot Agent:
 import pandas as pd
 from plot_agent.agent import PlotAgent
 
-# ensure OPENAI_API_KEY is set
+# ensure OPENAI_API_KEY is set (env or .env); optional debug via PLOT_AGENT_DEBUG=1
 
 # Create a sample dataframe
 df = pd.DataFrame({
@@ -78,19 +80,72 @@ fig.update_layout(
 )
 ```
 
+## How it works
+
+```mermaid
+flowchart TD
+    A[User message] --> B{LangGraph ReAct Agent}
+    subgraph Tools
+        T1[execute_plotly_code<br/>- runs code in sandbox<br/>- returns success/fig/error]
+        T2[does_fig_exist]
+        T3[view_generated_code]
+    end
+    B -- tool call --> T1
+    T1 -- result --> B
+    B -- optional --> T2
+    B -- optional --> T3
+    B --> C[AI response]
+    C --> D{Agent wrapper}
+    D -- persist messages --> B
+    D -- extract code blocks --> E[Sandbox execution]
+    E --> F[fig]
+    F --> G[get_figure]
+```
+
+- The LangGraph agent plans and decides when to call tools.
+- The wrapper persists full graph messages between turns and executes any returned code blocks to keep `fig` updated.
+- A safe execution environment runs code with an allowlist and a main-thread-only timeout.
+
 ## Features
 
 - AI-powered visualization generation
 - Support for various Plotly chart types
 - Automatic data preprocessing
 - Interactive visualization capabilities
--
+- LangGraph-based tool calling and control flow
+- Debug logging via `PlotAgent(debug=True)` or `PLOT_AGENT_DEBUG=1`
 
 ## Requirements
 
 - Python 3.8 or higher
 - Dependencies are automatically installed with the package
 
+## Development
+
+- Run unit tests:
+
+```bash
+make test
+```
+
+- Execute all example notebooks:
+
+```bash
+make run-examples
+```
+
+- Execute with debug logs enabled:
+
+```bash
+make run-examples-debug
+```
+
+- Quick CLI repro that prints evolving code each step:
+
+```bash
+make run-example-script
+```
+
 ## License
 
 This project is licensed under the MIT License - see the LICENSE file for details.
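
The README addition mentions two ways to turn on debug logging; both are shown in the short sketch below. The dataframe and request are illustrative, and `OPENAI_API_KEY` still needs to be available in the environment or a `.env` file:

```python
import os

import pandas as pd

from plot_agent.agent import PlotAgent

# Option 1: environment switch, read by the constructor.
os.environ["PLOT_AGENT_DEBUG"] = "1"

# Option 2: constructor flag; either path enables DEBUG output on the "plot_agent" logger.
agent = PlotAgent(debug=True)

agent.set_df(pd.DataFrame({"x": [1, 2, 3], "y": [3, 1, 2]}))  # illustrative data
agent.process_message("scatter plot of y vs x")
```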

--- /dev/null
+++ plot_agent-0.4.0/plot_agent/agent.py
@@ -0,0 +1,407 @@
+"""
+This module contains the PlotAgent class, which is used to generate Plotly code based on a user's plot description.
+"""
+
+import pandas as pd
+from io import StringIO
+import os
+import re
+import logging
+from typing import Optional
+from dotenv import load_dotenv
+
+from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.tools import Tool, StructuredTool
+from langgraph.prebuilt import create_react_agent
+from langchain_openai import ChatOpenAI
+
+from plot_agent.prompt import DEFAULT_SYSTEM_PROMPT
+from plot_agent.models import (
+    GeneratedCodeInput,
+    DoesFigExistInput,
+    ViewGeneratedCodeInput,
+)
+from plot_agent.execution import PlotAgentExecutionEnvironment
+
+
+class PlotAgent:
+    """
+    A class that uses an LLM to generate Plotly code based on a user's plot description.
+    """
+
+    def __init__(
+        self,
+        model: str = "gpt-4o-mini",
+        system_prompt: Optional[str] = None,
+        verbose: bool = True,
+        max_iterations: int = 10,
+        early_stopping_method: str = "force",
+        handle_parsing_errors: bool = True,
+        llm_temperature: float = 0.0,
+        llm_timeout: int = 60,
+        llm_max_retries: int = 1,
+        debug: bool = False,
+    ):
+        """
+        Initialize the PlotAgent.
+
+        Args:
+            model (str): The model to use for the LLM.
+            system_prompt (Optional[str]): The system prompt to use for the LLM.
+            verbose (bool): Whether to print verbose output from the agent.
+            max_iterations (int): Maximum number of iterations for the agent to take.
+            early_stopping_method (str): Method to use for early stopping.
+            handle_parsing_errors (bool): Whether to handle parsing errors gracefully.
+        """
+        # Load .env if present, then require a valid API key
+        load_dotenv()
+        openai_api_key = os.getenv("OPENAI_API_KEY")
+        if not openai_api_key:
+            raise RuntimeError(
+                "OPENAI_API_KEY is not set. Provide it via environment or a .env file."
+            )
+        self.debug = debug or os.getenv("PLOT_AGENT_DEBUG") == "1"
+
+        # Configure logger
+        self._logger = logging.getLogger("plot_agent")
+        if self.debug:
+            self._logger.setLevel(logging.DEBUG)
+            if not self._logger.handlers:
+                handler = logging.StreamHandler()
+                handler.setFormatter(
+                    logging.Formatter(
+                        "%(asctime)s [%(levelname)s] %(name)s: %(message)s",
+                        datefmt="%H:%M:%S",
+                    )
+                )
+                self._logger.addHandler(handler)
+
+        self.llm = ChatOpenAI(
+            model=model,
+            temperature=llm_temperature,
+            timeout=llm_timeout,
+            max_retries=llm_max_retries,
+        )
+        self.df = None
+        self.df_info = None
+        self.df_head = None
+        self.sql_query = None
+        self.execution_env = None
+        self.chat_history = []
+        # Internal graph-native message history, including tool messages
+        self._graph_messages = []
+        self.agent_executor = None
+        self.generated_code = None
+        self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
+        self.verbose = verbose
+        self.max_iterations = max_iterations
+        self.early_stopping_method = early_stopping_method
+        self.handle_parsing_errors = handle_parsing_errors
+
+    def set_df(self, df: pd.DataFrame, sql_query: Optional[str] = None):
+        """
+        Set the dataframe and capture its schema and sample.
+
+        Args:
+            df (pd.DataFrame): The pandas dataframe to set.
+            sql_query (Optional[str]): The SQL query used to generate the dataframe.
+
+        Returns:
+            None
+        """
+
+        # Check df
+        assert isinstance(df, pd.DataFrame), "The dataframe must be a pandas dataframe."
+        assert not df.empty, "The dataframe must not be empty."
+
+        if sql_query:
+            assert isinstance(sql_query, str), "The SQL query must be a string."
+
+        self.df = df
+
+        # Capture df.info() output
+        buffer = StringIO()
+        df.info(buf=buffer)
+        self.df_info = buffer.getvalue()
+
+        # Capture df.head() as string representation
+        self.df_head = df.head().to_string()
+
+        # Store SQL query if provided
+        self.sql_query = sql_query
+
+        # Initialize execution environment
+        self.execution_env = PlotAgentExecutionEnvironment(df)
+
+        # Initialize the agent with tools
+        self._initialize_agent()
+        # Reset graph messages for a fresh session with this dataframe
+        self._graph_messages = []
+        if self.debug:
+            self._logger.debug("set_df() initialized execution environment and graph")
+
+    def execute_plotly_code(self, generated_code: str) -> str:
+        """
+        Execute the provided Plotly code and return the result.
+
+        Args:
+            generated_code (str): The Plotly code to execute.
+
+        Returns:
+            str: The result of the execution.
+        """
+        assert isinstance(generated_code, str), "The generated code must be a string."
+
+        if not self.execution_env:
+            return "Error: No dataframe has been set. Please set a dataframe first."
+
+        # Store this as the last generated code
+        self.generated_code = generated_code
+
+        # Execute the generated code
+        code_execution_result = self.execution_env.execute_code(generated_code)
+
+        # Extract the results from the code execution
+        code_execution_success = code_execution_result.get("success", False)
+        code_execution_output = code_execution_result.get("output", "")
+        code_execution_error = code_execution_result.get("error", "")
+
+        # Check if the code executed successfully
+        if code_execution_success:
+            return f"Success: {code_execution_output}"
+        else:
+            return f"Error: {code_execution_error}\n{code_execution_output}"
+
+    def does_fig_exist(self, *args, **kwargs) -> str:
+        """
+        Check if a figure object is available for display.
+
+        Args:
+            *args: Any positional arguments (ignored)
+            **kwargs: Any keyword arguments (ignored)
+
+        Returns:
+            str: A message indicating whether a figure is available for display.
+        """
+        if not self.execution_env:
+            return "No execution environment has been initialized. Please set a dataframe first."
+
+        if self.execution_env.fig is not None:
+            return "A figure is available for display."
+        else:
+            return "No figure has been created yet."
+
+    def view_generated_code(self, *args, **kwargs) -> str:
+        """
+        View the generated code.
+        """
+        return self.generated_code or ""
+
+    def _initialize_agent(self):
+        """Initialize a LangGraph ReAct agent with tools and keep API compatibility."""
+
+        # Initialize the tools
+        tools = [
+            Tool.from_function(
+                func=self.execute_plotly_code,
+                name="execute_plotly_code",
+                description=(
+                    "Execute the provided Plotly code and return a result indicating "
+                    "if the code executed successfully and if a figure object was created."
+                ),
+                args_schema=GeneratedCodeInput,
+            ),
+            StructuredTool.from_function(
+                func=self.does_fig_exist,
+                name="does_fig_exist",
+                description=(
+                    "Check if a figure exists and is available for display. "
+                    "This tool takes no arguments and returns a string indicating "
+                    "if a figure is available for display or not."
+                ),
+                args_schema=DoesFigExistInput,
+            ),
+            StructuredTool.from_function(
+                func=self.view_generated_code,
+                name="view_generated_code",
+                description=(
+                    "View the generated code. "
+                    "This tool takes no arguments and returns the generated code as a string."
+                ),
+                args_schema=ViewGeneratedCodeInput,
+            ),
+        ]
+
+        # Prepare system prompt with dataframe information
+        sql_context = ""
+        if self.sql_query:
+            sql_context = (
+                "In case it is useful to help with the data understanding, the df was generated using the following SQL query:\n"
+                f"```sql\n{self.sql_query}\n```"
+            )
+
+        # Store formatted system instructions for the graph state modifier
+        self._system_message_content = self.system_prompt.format(
+            df_info=self.df_info,
+            df_head=self.df_head,
+            sql_context=sql_context,
+        )
+
+        # Create a ReAct agent graph with the provided tools and system prompt
+        self._graph = create_react_agent(
+            self.llm,
+            tools,
+            prompt=self._system_message_content,
+            debug=self.debug,
+        )
+
+        # Backwards-compatibility: expose under the old attribute name
+        self.agent_executor = self._graph
+
+    def process_message(self, user_message: str) -> str:
+        """Process a user message and return the agent's response."""
+        assert isinstance(user_message, str), "The user message must be a string."
+
+        if not self.agent_executor:
+            return "Please set a dataframe first using set_df() method."
+
+        # Add user message to outward-facing chat history
+        self.chat_history.append(HumanMessage(content=user_message))
+
+        # Reset generated_code
+        self.generated_code = None
+
+        # Short-circuit empty inputs to avoid graph recursion
+        if user_message.strip() == "":
+            ai_content = (
+                "Please provide a non-empty plotting request (e.g., 'scatter x vs y')."
+            )
+            self.chat_history.append(AIMessage(content=ai_content))
+            if self.debug:
+                self._logger.debug("empty message received; returning guidance without invoking graph")
+            return ai_content
+
+        # Short-circuit messages that are primarily raw code blocks without a visualization request
+        if "```" in user_message and not re.search(
+            r"\b(plot|chart|graph|visuali(s|z)e|figure|subplot|heatmap|bar|line|scatter)\b",
+            user_message,
+            flags=re.IGNORECASE,
+        ):
+            ai_content = (
+                "I see a code snippet. Please describe the visualization you want (e.g., 'line chart of y over x')."
+            )
+            self.chat_history.append(AIMessage(content=ai_content))
+            if self.debug:
+                self._logger.debug("code-only message received; returning guidance without invoking graph")
+            return ai_content
+
+        # Build graph messages (includes tool call/observation history)
+        graph_messages = [*self._graph_messages, HumanMessage(content=user_message)]
+        if self.debug:
+            self._logger.debug(f"process_message() user: {user_message}")
+            self._logger.debug(f"graph message count before invoke: {len(graph_messages)}")
+        # Invoke the LangGraph agent
+        result = self.agent_executor.invoke(
+            {"messages": graph_messages},
+            config={"recursion_limit": self.max_iterations},
+        )
+
+        # Extract the latest AI message from the returned messages
+        ai_messages = [m for m in result.get("messages", []) if isinstance(m, AIMessage)]
+        ai_content = ai_messages[-1].content if ai_messages else ""
+
+        # Persist full graph messages for future context
+        self._graph_messages = result.get("messages", [])
+        if self.debug:
+            self._logger.debug(f"graph message count after invoke: {len(self._graph_messages)}")
+
+        # Add agent response to outward-facing chat history
+        self.chat_history.append(AIMessage(content=ai_content))
+
+        # If the agent didn't execute the code via tool, but we have prior generated_code, execute it
+        if self.execution_env and self.execution_env.fig is None and self.generated_code is not None:
+            if self.debug:
+                self._logger.debug("executing stored generated_code because no fig exists yet")
+            exec_result = self.execution_env.execute_code(self.generated_code)
+            if self.debug:
+                self._logger.debug(f"execution result success={exec_result.get('success')} error={exec_result.get('error')!r}")
+
+        # If the assistant returned code in the message, execute it to update the figure
+        code_executed = False
+        if self.execution_env and isinstance(ai_content, str):
+            extracted_code = None
+            if "```python" in ai_content:
+                parts = ai_content.split("```python", 1)
+                extracted_code = parts[1].split("```", 1)[0].strip() if len(parts) > 1 else None
+            elif "```" in ai_content:
+                # Fallback: extract first generic fenced code block
+                parts = ai_content.split("```", 1)
+                if len(parts) > 1:
+                    extracted_code = parts[1].split("```", 1)[0].strip()
+            if extracted_code:
+                if (self.generated_code or "").strip() != extracted_code:
+                    self.generated_code = extracted_code
+                if self.debug:
+                    self._logger.debug("executing code extracted from AI message")
+                exec_result = self.execution_env.execute_code(extracted_code)
+                if self.debug:
+                    self._logger.debug(f"execution result success={exec_result.get('success')} error={exec_result.get('error')!r}")
+                code_executed = True
+
+        # If still no figure and no code was executed, run one guided retry to force tool usage
+        if self.execution_env and self.execution_env.fig is None and not code_executed:
+            if self.debug:
+                self._logger.debug("guided retry: prompting model to use execute_plotly_code tool")
+            guided_messages = [
+                *self._graph_messages,
+                HumanMessage(
+                    content=(
+                        "Please use the execute_plotly_code(generated_code) tool with the FULL code to "
+                        "create a variable named 'fig', then call does_fig_exist(). Return the final "
+                        "code in a fenced ```python block."
+                    )
+                ),
+            ]
+            retry_result = self.agent_executor.invoke(
+                {"messages": guided_messages},
+                config={"recursion_limit": max(3, self.max_iterations // 2)},
+            )
+            self._graph_messages = retry_result.get("messages", [])
+            retry_ai_messages = [
+                m for m in self._graph_messages if isinstance(m, AIMessage)
+            ]
+            retry_content = retry_ai_messages[-1].content if retry_ai_messages else ""
+            if isinstance(retry_content, str):
+                if "```python" in retry_content:
+                    parts = retry_content.split("```python", 1)
+                    retry_code = (
+                        parts[1].split("```", 1)[0].strip() if len(parts) > 1 else None
+                    )
+                elif "```" in retry_content:
+                    parts = retry_content.split("```", 1)
+                    retry_code = (
+                        parts[1].split("```", 1)[0].strip() if len(parts) > 1 else None
+                    )
+                else:
+                    retry_code = None
+                if retry_code:
+                    if (self.generated_code or "").strip() != retry_code:
+                        self.generated_code = retry_code
+                    if self.debug:
+                        self._logger.debug("executing code extracted from guided retry response")
+                    exec_result = self.execution_env.execute_code(retry_code)
+                    if self.debug:
+                        self._logger.debug(f"execution result success={exec_result.get('success')} error={exec_result.get('error')!r}")
+
+        return ai_content if isinstance(ai_content, str) else str(ai_content)
+
+    def get_figure(self):
+        """Return the current figure if one exists."""
+        if self.execution_env and self.execution_env.fig:
+            return self.execution_env.fig
+        return None
+
+    def reset_conversation(self):
+        """Reset the conversation history."""
+        self.chat_history = []
+        self.generated_code = None
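
The fallback path in `process_message` above hinges on pulling the first fenced code block out of the assistant's reply. The helper below restates that extraction as a standalone function; the function name and the sample message are mine, while the splitting logic mirrors the new agent.py (prefer a python-tagged fence, then fall back to the first generic fence):

```python
from typing import Optional

FENCE = "`" * 3  # three backticks, built programmatically so this example stays a valid fenced block
PY_FENCE = FENCE + "python"


def extract_first_code_block(ai_content: str) -> Optional[str]:
    """Mirror of the extraction in PlotAgent.process_message."""
    if PY_FENCE in ai_content:
        parts = ai_content.split(PY_FENCE, 1)
        return parts[1].split(FENCE, 1)[0].strip() if len(parts) > 1 else None
    if FENCE in ai_content:
        parts = ai_content.split(FENCE, 1)
        if len(parts) > 1:
            return parts[1].split(FENCE, 1)[0].strip()
    return None


# The extracted text is what the wrapper would pass to the sandbox.
message = f"Here is the code:\n{PY_FENCE}\nfig = px.line(df, x='x', y='y')\n{FENCE}"
print(extract_first_code_block(message))  # fig = px.line(df, x='x', y='y')
```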

--- plot_agent-0.3.1/plot_agent/execution.py
+++ plot_agent-0.4.0/plot_agent/execution.py
@@ -11,6 +11,7 @@ Security features:
 import ast
 import builtins
 import signal
+import threading
 import traceback
 from io import StringIO
 import contextlib
@@ -183,9 +184,16 @@ class PlotAgentExecutionEnvironment:
                 "success": False,
             }
 
-        # Set a timeout
-        signal.signal(signal.SIGALRM, _timeout_handler)
-        signal.alarm(self.TIMEOUT_SECONDS)
+        # Set a timeout only if running on the main thread; signals are not supported in worker threads
+        timeout_set = False
+        try:
+            if threading.current_thread() is threading.main_thread():
+                signal.signal(signal.SIGALRM, _timeout_handler)
+                signal.alarm(self.TIMEOUT_SECONDS)
+                timeout_set = True
+        except Exception:
+            # If setting the signal handler fails (e.g., not in main thread), proceed without timeout
+            timeout_set = False
 
         # Execute the code
         out_buf, err_buf = StringIO(), StringIO()
@@ -215,8 +223,12 @@ class PlotAgentExecutionEnvironment:
                 "success": False,
             }
         finally:
-            # Reset the timeout
-            signal.alarm(0)
+            # Reset the timeout if it was set
+            if timeout_set:
+                try:
+                    signal.alarm(0)
+                except Exception:
+                    pass
 
         # Get the `fig`
         fig = ns.get("fig")
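
The execution.py change guards SIGALRM behind a main-thread check because Python only allows installing signal handlers on the main thread. The sketch below shows the same pattern in isolation; `TIMEOUT_SECONDS`, the handler, and the wrapper function are illustrative stand-ins rather than the package's API, and SIGALRM is Unix-only, as in the package:

```python
import signal
import threading

TIMEOUT_SECONDS = 30  # illustrative value


def _timeout_handler(signum, frame):
    raise TimeoutError("code execution timed out")


def run_with_optional_timeout(fn):
    """Arm SIGALRM only on the main thread, mirroring the guarded pattern in execution.py."""
    timeout_set = False
    try:
        if threading.current_thread() is threading.main_thread():
            signal.signal(signal.SIGALRM, _timeout_handler)
            signal.alarm(TIMEOUT_SECONDS)
            timeout_set = True
    except Exception:
        timeout_set = False  # e.g. running in a worker thread: proceed without a timeout
    try:
        return fn()
    finally:
        if timeout_set:
            signal.alarm(0)  # always disarm the alarm once execution finishes


print(run_with_optional_timeout(lambda: sum(range(10))))  # 45
```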

--- plot_agent-0.3.1/plot_agent.egg-info/PKG-INFO
+++ plot_agent-0.4.0/plot_agent.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: plot-agent
-Version: 0.3.1
+Version: 0.4.0
 Summary: An AI-powered data visualization assistant using Plotly
 Author-email: andrewm4894 <andrewm4894@gmail.com>
 Project-URL: Homepage, https://github.com/andrewm4894/plot-agent
@@ -19,6 +19,8 @@ Dynamic: license-file
 
 An AI-powered data visualization assistant that helps users create Plotly visualizations in Python.
 
+Built on LangGraph with tool-calling to reliably execute generated Plotly code in a sandbox and keep the current `fig` in sync.
+
 ## Installation
 
 You can install the package using pip:
@@ -37,7 +39,7 @@ Here's a simple minimal example of how to use Plot Agent:
 import pandas as pd
 from plot_agent.agent import PlotAgent
 
-# ensure OPENAI_API_KEY is set
+# ensure OPENAI_API_KEY is set (env or .env); optional debug via PLOT_AGENT_DEBUG=1
 
 # Create a sample dataframe
 df = pd.DataFrame({
@@ -92,19 +94,72 @@ fig.update_layout(
 )
 ```
 
+## How it works
+
+```mermaid
+flowchart TD
+    A[User message] --> B{LangGraph ReAct Agent}
+    subgraph Tools
+        T1[execute_plotly_code<br/>- runs code in sandbox<br/>- returns success/fig/error]
+        T2[does_fig_exist]
+        T3[view_generated_code]
+    end
+    B -- tool call --> T1
+    T1 -- result --> B
+    B -- optional --> T2
+    B -- optional --> T3
+    B --> C[AI response]
+    C --> D{Agent wrapper}
+    D -- persist messages --> B
+    D -- extract code blocks --> E[Sandbox execution]
+    E --> F[fig]
+    F --> G[get_figure]
+```
+
+- The LangGraph agent plans and decides when to call tools.
+- The wrapper persists full graph messages between turns and executes any returned code blocks to keep `fig` updated.
+- A safe execution environment runs code with an allowlist and a main-thread-only timeout.
+
 ## Features
 
 - AI-powered visualization generation
 - Support for various Plotly chart types
 - Automatic data preprocessing
 - Interactive visualization capabilities
--
+- LangGraph-based tool calling and control flow
+- Debug logging via `PlotAgent(debug=True)` or `PLOT_AGENT_DEBUG=1`
 
 ## Requirements
 
 - Python 3.8 or higher
 - Dependencies are automatically installed with the package
 
+## Development
+
+- Run unit tests:
+
+```bash
+make test
+```
+
+- Execute all example notebooks:
+
+```bash
+make run-examples
+```
+
+- Execute with debug logs enabled:
+
+```bash
+make run-examples-debug
+```
+
+- Quick CLI repro that prints evolving code each step:
+
+```bash
+make run-example-script
+```
+
 ## License
 
 This project is licensed under the MIT License - see the LICENSE file for details.

--- plot_agent-0.3.1/plot_agent/agent.py
+++ /dev/null
@@ -1,269 +0,0 @@
-"""
-This module contains the PlotAgent class, which is used to generate Plotly code based on a user's plot description.
-"""
-
-import pandas as pd
-from io import StringIO
-from typing import Optional
-
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain_core.messages import AIMessage, HumanMessage
-from langchain_core.tools import Tool, StructuredTool
-from langchain.agents import AgentExecutor, create_openai_tools_agent
-from langchain_openai import ChatOpenAI
-
-from plot_agent.prompt import DEFAULT_SYSTEM_PROMPT
-from plot_agent.models import (
-    GeneratedCodeInput,
-    DoesFigExistInput,
-    ViewGeneratedCodeInput,
-)
-from plot_agent.execution import PlotAgentExecutionEnvironment
-
-
-class PlotAgent:
-    """
-    A class that uses an LLM to generate Plotly code based on a user's plot description.
-    """
-
-    def __init__(
-        self,
-        model="gpt-4o-mini",
-        system_prompt: Optional[str] = None,
-        verbose: bool = True,
-        max_iterations: int = 10,
-        early_stopping_method: str = "force",
-        handle_parsing_errors: bool = True,
-    ):
-        """
-        Initialize the PlotAgent.
-
-        Args:
-            model (str): The model to use for the LLM.
-            system_prompt (Optional[str]): The system prompt to use for the LLM.
-            verbose (bool): Whether to print verbose output from the agent.
-            max_iterations (int): Maximum number of iterations for the agent to take.
-            early_stopping_method (str): Method to use for early stopping.
-            handle_parsing_errors (bool): Whether to handle parsing errors gracefully.
-        """
-        self.llm = ChatOpenAI(model=model)
-        self.df = None
-        self.df_info = None
-        self.df_head = None
-        self.sql_query = None
-        self.execution_env = None
-        self.chat_history = []
-        self.agent_executor = None
-        self.generated_code = None
-        self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
-        self.verbose = verbose
-        self.max_iterations = max_iterations
-        self.early_stopping_method = early_stopping_method
-        self.handle_parsing_errors = handle_parsing_errors
-
-    def set_df(self, df: pd.DataFrame, sql_query: Optional[str] = None):
-        """
-        Set the dataframe and capture its schema and sample.
-
-        Args:
-            df (pd.DataFrame): The pandas dataframe to set.
-            sql_query (Optional[str]): The SQL query used to generate the dataframe.
-
-        Returns:
-            None
-        """
-
-        # Check df
-        assert isinstance(df, pd.DataFrame), "The dataframe must be a pandas dataframe."
-        assert not df.empty, "The dataframe must not be empty."
-
-        if sql_query:
-            assert isinstance(sql_query, str), "The SQL query must be a string."
-
-        self.df = df
-
-        # Capture df.info() output
-        buffer = StringIO()
-        df.info(buf=buffer)
-        self.df_info = buffer.getvalue()
-
-        # Capture df.head() as string representation
-        self.df_head = df.head().to_string()
-
-        # Store SQL query if provided
-        self.sql_query = sql_query
-
-        # Initialize execution environment
-        self.execution_env = PlotAgentExecutionEnvironment(df)
-
-        # Initialize the agent with tools
-        self._initialize_agent()
-
-    def execute_plotly_code(self, generated_code: str) -> str:
-        """
-        Execute the provided Plotly code and return the result.
-
-        Args:
-            generated_code (str): The Plotly code to execute.
-
-        Returns:
-            str: The result of the execution.
-        """
-        assert isinstance(generated_code, str), "The generated code must be a string."
-
-        if not self.execution_env:
-            return "Error: No dataframe has been set. Please set a dataframe first."
-
-        # Store this as the last generated code
-        self.generated_code = generated_code
-
-        # Execute the generated code
-        code_execution_result = self.execution_env.execute_code(generated_code)
-
-        # Extract the results from the code execution
-        code_execution_success = code_execution_result.get("success", False)
-        code_execution_output = code_execution_result.get("output", "")
-        code_execution_error = code_execution_result.get("error", "")
-
-        # Check if the code executed successfully
-        if code_execution_success:
-            return f"Success: {code_execution_output}"
-        else:
-            return f"Error: {code_execution_error}\n{code_execution_output}"
-
-    def does_fig_exist(self, *args, **kwargs) -> str:
-        """
-        Check if a figure object is available for display.
-
-        Args:
-            *args: Any positional arguments (ignored)
-            **kwargs: Any keyword arguments (ignored)
-
-        Returns:
-            str: A message indicating whether a figure is available for display.
-        """
-        if not self.execution_env:
-            return "No execution environment has been initialized. Please set a dataframe first."
-
-        if self.execution_env.fig is not None:
-            return "A figure is available for display."
-        else:
-            return "No figure has been created yet."
-
-    def view_generated_code(self, *args, **kwargs) -> str:
-        """
-        View the generated code.
-        """
-        return self.generated_code
-
-    def _initialize_agent(self):
-        """Initialize the LangChain agent with the necessary tools and prompt."""
-
-        # Initialize the tools
-        tools = [
-            Tool.from_function(
-                func=self.execute_plotly_code,
-                name="execute_plotly_code",
-                description=(
-                    "Execute the provided Plotly code and return a result indicating "
-                    "if the code executed successfully and if a figure object was created."
-                ),
-                args_schema=GeneratedCodeInput,
-            ),
-            StructuredTool.from_function(
-                func=self.does_fig_exist,
-                name="does_fig_exist",
-                description=(
-                    "Check if a figure exists and is available for display. "
-                    "This tool takes no arguments and returns a string indicating "
-                    "if a figure is available for display or not."
-                ),
-                args_schema=DoesFigExistInput,
-            ),
-            StructuredTool.from_function(
-                func=self.view_generated_code,
-                name="view_generated_code",
-                description=(
-                    "View the generated code. "
-                    "This tool takes no arguments and returns the generated code as a string."
-                ),
-                args_schema=ViewGeneratedCodeInput,
-            ),
-        ]
-
-        # Create system prompt with dataframe information
-        sql_context = ""
-        if self.sql_query:
-            sql_context = f"In case it is useful to help with the data understanding, the df was generated using the following SQL query:\n```sql\n{self.sql_query}\n```"
-
-        prompt = ChatPromptTemplate.from_messages(
-            [
-                (
-                    "system",
-                    self.system_prompt.format(
-                        df_info=self.df_info,
-                        df_head=self.df_head,
-                        sql_context=sql_context,
-                    ),
-                ),
-                MessagesPlaceholder(variable_name="chat_history"),
-                ("human", "{input}"),
-                MessagesPlaceholder(variable_name="agent_scratchpad"),
-            ]
-        )
-
-        agent = create_openai_tools_agent(self.llm, tools, prompt)
-        self.agent_executor = AgentExecutor(
-            agent=agent,
-            tools=tools,
-            verbose=self.verbose,
-            max_iterations=self.max_iterations,
-            early_stopping_method=self.early_stopping_method,
-            handle_parsing_errors=self.handle_parsing_errors,
-        )
-
-    def process_message(self, user_message: str) -> str:
-        """Process a user message and return the agent's response."""
-        assert isinstance(user_message, str), "The user message must be a string."
-
-        if not self.agent_executor:
-            return "Please set a dataframe first using set_df() method."
-
-        # Add user message to chat history
-        self.chat_history.append(HumanMessage(content=user_message))
-
-        # Reset generated_code
-        self.generated_code = None
-
-        # Get response from agent
-        response = self.agent_executor.invoke(
-            {"input": user_message, "chat_history": self.chat_history}
-        )
-
-        # Add agent response to chat history
-        self.chat_history.append(AIMessage(content=response["output"]))
-
-        # If the agent didn't execute the code, but did generate code, execute it directly
-        if self.execution_env.fig is None and self.generated_code is not None:
-            self.execution_env.execute_code(self.generated_code)
-
-        # If we can extract code from the response when no code was executed, try that too
-        if self.execution_env.fig is None and "```python" in response["output"]:
-            code_blocks = response["output"].split("```python")
-            if len(code_blocks) > 1:
-                generated_code = code_blocks[1].split("```")[0].strip()
-                self.execution_env.execute_code(generated_code)
-
-        # Return the agent's response
-        return response["output"]
-
-    def get_figure(self):
-        """Return the current figure if one exists."""
-        if self.execution_env and self.execution_env.fig:
-            return self.execution_env.fig
-        return None
-
-    def reset_conversation(self):
-        """Reset the conversation history."""
-        self.chat_history = []
-        self.generated_code = None

Files without changes: LICENSE, plot_agent/__init__.py, plot_agent/models.py, plot_agent/prompt.py, plot_agent.egg-info/SOURCES.txt, plot_agent.egg-info/dependency_links.txt, plot_agent.egg-info/top_level.txt, setup.cfg.
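
Taken together, the deleted and added agent.py show the shape of the migration: a LangChain AgentExecutor call in 0.3.1 versus a LangGraph ReAct graph call in 0.4.0. The sketch below contrasts only the invocation payloads and how the reply is read; the request text and the hand-built `result` are placeholders, not output from either version:

```python
from langchain_core.messages import AIMessage, HumanMessage

user_message = "line chart of y over x"  # placeholder request

# 0.3.1 (LangChain AgentExecutor): separate input and chat_history keys,
# reply read from response["output"].
old_payload = {"input": user_message, "chat_history": []}

# 0.4.0 (LangGraph ReAct agent): a single message list that also carries tool messages,
# plus an explicit recursion limit; reply is the last AIMessage in result["messages"].
new_payload = {"messages": [HumanMessage(content=user_message)]}
new_config = {"recursion_limit": 10}

# Hand-built stand-in for a graph result, to show how the new code reads the reply.
result = {"messages": [HumanMessage(content=user_message), AIMessage(content="Here is the code ...")]}
ai_messages = [m for m in result["messages"] if isinstance(m, AIMessage)]
print(ai_messages[-1].content if ai_messages else "")
```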