plot-agent 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plot_agent/agent.py +207 -55
- plot_agent/execution.py +211 -51
- plot_agent/models.py +4 -0
- plot_agent/prompt.py +4 -0
- {plot_agent-0.3.0.dist-info → plot_agent-0.4.0.dist-info}/METADATA +58 -3
- plot_agent-0.4.0.dist-info/RECORD +10 -0
- {plot_agent-0.3.0.dist-info → plot_agent-0.4.0.dist-info}/WHEEL +1 -1
- plot_agent-0.3.0.dist-info/RECORD +0 -10
- {plot_agent-0.3.0.dist-info → plot_agent-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {plot_agent-0.3.0.dist-info → plot_agent-0.4.0.dist-info}/top_level.txt +0 -0
plot_agent/agent.py
CHANGED
|
@@ -1,11 +1,18 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the PlotAgent class, which is used to generate Plotly code based on a user's plot description.
|
|
3
|
+
"""
|
|
4
|
+
|
|
1
5
|
import pandas as pd
|
|
2
6
|
from io import StringIO
|
|
7
|
+
import os
|
|
8
|
+
import re
|
|
9
|
+
import logging
|
|
3
10
|
from typing import Optional
|
|
11
|
+
from dotenv import load_dotenv
|
|
4
12
|
|
|
5
|
-
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
6
13
|
from langchain_core.messages import AIMessage, HumanMessage
|
|
7
14
|
from langchain_core.tools import Tool, StructuredTool
|
|
8
|
-
from
|
|
15
|
+
from langgraph.prebuilt import create_react_agent
|
|
9
16
|
from langchain_openai import ChatOpenAI
|
|
10
17
|
|
|
11
18
|
from plot_agent.prompt import DEFAULT_SYSTEM_PROMPT
|
|
@@ -24,12 +31,16 @@ class PlotAgent:
|
|
|
24
31
|
|
|
25
32
|
def __init__(
|
|
26
33
|
self,
|
|
27
|
-
model="gpt-4o-mini",
|
|
34
|
+
model: str = "gpt-4o-mini",
|
|
28
35
|
system_prompt: Optional[str] = None,
|
|
29
36
|
verbose: bool = True,
|
|
30
37
|
max_iterations: int = 10,
|
|
31
38
|
early_stopping_method: str = "force",
|
|
32
39
|
handle_parsing_errors: bool = True,
|
|
40
|
+
llm_temperature: float = 0.0,
|
|
41
|
+
llm_timeout: int = 60,
|
|
42
|
+
llm_max_retries: int = 1,
|
|
43
|
+
debug: bool = False,
|
|
33
44
|
):
|
|
34
45
|
"""
|
|
35
46
|
Initialize the PlotAgent.
|
|
@@ -42,13 +53,43 @@ class PlotAgent:
|
|
|
42
53
|
early_stopping_method (str): Method to use for early stopping.
|
|
43
54
|
handle_parsing_errors (bool): Whether to handle parsing errors gracefully.
|
|
44
55
|
"""
|
|
45
|
-
|
|
56
|
+
# Load .env if present, then require a valid API key
|
|
57
|
+
load_dotenv()
|
|
58
|
+
openai_api_key = os.getenv("OPENAI_API_KEY")
|
|
59
|
+
if not openai_api_key:
|
|
60
|
+
raise RuntimeError(
|
|
61
|
+
"OPENAI_API_KEY is not set. Provide it via environment or a .env file."
|
|
62
|
+
)
|
|
63
|
+
self.debug = debug or os.getenv("PLOT_AGENT_DEBUG") == "1"
|
|
64
|
+
|
|
65
|
+
# Configure logger
|
|
66
|
+
self._logger = logging.getLogger("plot_agent")
|
|
67
|
+
if self.debug:
|
|
68
|
+
self._logger.setLevel(logging.DEBUG)
|
|
69
|
+
if not self._logger.handlers:
|
|
70
|
+
handler = logging.StreamHandler()
|
|
71
|
+
handler.setFormatter(
|
|
72
|
+
logging.Formatter(
|
|
73
|
+
"%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
|
74
|
+
datefmt="%H:%M:%S",
|
|
75
|
+
)
|
|
76
|
+
)
|
|
77
|
+
self._logger.addHandler(handler)
|
|
78
|
+
|
|
79
|
+
self.llm = ChatOpenAI(
|
|
80
|
+
model=model,
|
|
81
|
+
temperature=llm_temperature,
|
|
82
|
+
timeout=llm_timeout,
|
|
83
|
+
max_retries=llm_max_retries,
|
|
84
|
+
)
|
|
46
85
|
self.df = None
|
|
47
86
|
self.df_info = None
|
|
48
87
|
self.df_head = None
|
|
49
88
|
self.sql_query = None
|
|
50
89
|
self.execution_env = None
|
|
51
90
|
self.chat_history = []
|
|
91
|
+
# Internal graph-native message history, including tool messages
|
|
92
|
+
self._graph_messages = []
|
|
52
93
|
self.agent_executor = None
|
|
53
94
|
self.generated_code = None
|
|
54
95
|
self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
|
|
@@ -94,6 +135,10 @@ class PlotAgent:
|
|
|
94
135
|
|
|
95
136
|
# Initialize the agent with tools
|
|
96
137
|
self._initialize_agent()
|
|
138
|
+
# Reset graph messages for a fresh session with this dataframe
|
|
139
|
+
self._graph_messages = []
|
|
140
|
+
if self.debug:
|
|
141
|
+
self._logger.debug("set_df() initialized execution environment and graph")
|
|
97
142
|
|
|
98
143
|
def execute_plotly_code(self, generated_code: str) -> str:
|
|
99
144
|
"""
|
|
@@ -123,7 +168,7 @@ class PlotAgent:
|
|
|
123
168
|
|
|
124
169
|
# Check if the code executed successfully
|
|
125
170
|
if code_execution_success:
|
|
126
|
-
return f"
|
|
171
|
+
return f"Success: {code_execution_output}"
|
|
127
172
|
else:
|
|
128
173
|
return f"Error: {code_execution_error}\n{code_execution_output}"
|
|
129
174
|
|
|
@@ -150,64 +195,69 @@ class PlotAgent:
|
|
|
150
195
|
"""
|
|
151
196
|
View the generated code.
|
|
152
197
|
"""
|
|
153
|
-
return self.generated_code
|
|
198
|
+
return self.generated_code or ""
|
|
154
199
|
|
|
155
200
|
def _initialize_agent(self):
|
|
156
|
-
"""Initialize
|
|
201
|
+
"""Initialize a LangGraph ReAct agent with tools and keep API compatibility."""
|
|
157
202
|
|
|
158
203
|
# Initialize the tools
|
|
159
204
|
tools = [
|
|
160
205
|
Tool.from_function(
|
|
161
206
|
func=self.execute_plotly_code,
|
|
162
207
|
name="execute_plotly_code",
|
|
163
|
-
description=
|
|
208
|
+
description=(
|
|
209
|
+
"Execute the provided Plotly code and return a result indicating "
|
|
210
|
+
"if the code executed successfully and if a figure object was created."
|
|
211
|
+
),
|
|
164
212
|
args_schema=GeneratedCodeInput,
|
|
165
213
|
),
|
|
166
214
|
StructuredTool.from_function(
|
|
167
215
|
func=self.does_fig_exist,
|
|
168
216
|
name="does_fig_exist",
|
|
169
|
-
description=
|
|
217
|
+
description=(
|
|
218
|
+
"Check if a figure exists and is available for display. "
|
|
219
|
+
"This tool takes no arguments and returns a string indicating "
|
|
220
|
+
"if a figure is available for display or not."
|
|
221
|
+
),
|
|
170
222
|
args_schema=DoesFigExistInput,
|
|
171
223
|
),
|
|
172
224
|
StructuredTool.from_function(
|
|
173
225
|
func=self.view_generated_code,
|
|
174
226
|
name="view_generated_code",
|
|
175
|
-
description=
|
|
227
|
+
description=(
|
|
228
|
+
"View the generated code. "
|
|
229
|
+
"This tool takes no arguments and returns the generated code as a string."
|
|
230
|
+
),
|
|
176
231
|
args_schema=ViewGeneratedCodeInput,
|
|
177
232
|
),
|
|
178
233
|
]
|
|
179
234
|
|
|
180
|
-
#
|
|
235
|
+
# Prepare system prompt with dataframe information
|
|
181
236
|
sql_context = ""
|
|
182
237
|
if self.sql_query:
|
|
183
|
-
sql_context =
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
),
|
|
194
|
-
),
|
|
195
|
-
MessagesPlaceholder(variable_name="chat_history"),
|
|
196
|
-
("human", "{input}"),
|
|
197
|
-
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
|
198
|
-
]
|
|
238
|
+
sql_context = (
|
|
239
|
+
"In case it is useful to help with the data understanding, the df was generated using the following SQL query:\n"
|
|
240
|
+
f"```sql\n{self.sql_query}\n```"
|
|
241
|
+
)
|
|
242
|
+
|
|
243
|
+
# Store formatted system instructions for the graph state modifier
|
|
244
|
+
self._system_message_content = self.system_prompt.format(
|
|
245
|
+
df_info=self.df_info,
|
|
246
|
+
df_head=self.df_head,
|
|
247
|
+
sql_context=sql_context,
|
|
199
248
|
)
|
|
200
249
|
|
|
201
|
-
agent
|
|
202
|
-
self.
|
|
203
|
-
|
|
204
|
-
tools
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
early_stopping_method=self.early_stopping_method,
|
|
208
|
-
handle_parsing_errors=self.handle_parsing_errors,
|
|
250
|
+
# Create a ReAct agent graph with the provided tools and system prompt
|
|
251
|
+
self._graph = create_react_agent(
|
|
252
|
+
self.llm,
|
|
253
|
+
tools,
|
|
254
|
+
prompt=self._system_message_content,
|
|
255
|
+
debug=self.debug,
|
|
209
256
|
)
|
|
210
257
|
|
|
258
|
+
# Backwards-compatibility: expose under the old attribute name
|
|
259
|
+
self.agent_executor = self._graph
|
|
260
|
+
|
|
211
261
|
def process_message(self, user_message: str) -> str:
|
|
212
262
|
"""Process a user message and return the agent's response."""
|
|
213
263
|
assert isinstance(user_message, str), "The user message must be a string."
|
|
@@ -215,33 +265,135 @@ class PlotAgent:
|
|
|
215
265
|
if not self.agent_executor:
|
|
216
266
|
return "Please set a dataframe first using set_df() method."
|
|
217
267
|
|
|
218
|
-
# Add user message to chat history
|
|
268
|
+
# Add user message to outward-facing chat history
|
|
219
269
|
self.chat_history.append(HumanMessage(content=user_message))
|
|
220
270
|
|
|
221
271
|
# Reset generated_code
|
|
222
272
|
self.generated_code = None
|
|
223
273
|
|
|
224
|
-
#
|
|
225
|
-
|
|
226
|
-
|
|
274
|
+
# Short-circuit empty inputs to avoid graph recursion
|
|
275
|
+
if user_message.strip() == "":
|
|
276
|
+
ai_content = (
|
|
277
|
+
"Please provide a non-empty plotting request (e.g., 'scatter x vs y')."
|
|
278
|
+
)
|
|
279
|
+
self.chat_history.append(AIMessage(content=ai_content))
|
|
280
|
+
if self.debug:
|
|
281
|
+
self._logger.debug("empty message received; returning guidance without invoking graph")
|
|
282
|
+
return ai_content
|
|
283
|
+
|
|
284
|
+
# Short-circuit messages that are primarily raw code blocks without a visualization request
|
|
285
|
+
if "```" in user_message and not re.search(
|
|
286
|
+
r"\b(plot|chart|graph|visuali(s|z)e|figure|subplot|heatmap|bar|line|scatter)\b",
|
|
287
|
+
user_message,
|
|
288
|
+
flags=re.IGNORECASE,
|
|
289
|
+
):
|
|
290
|
+
ai_content = (
|
|
291
|
+
"I see a code snippet. Please describe the visualization you want (e.g., 'line chart of y over x')."
|
|
292
|
+
)
|
|
293
|
+
self.chat_history.append(AIMessage(content=ai_content))
|
|
294
|
+
if self.debug:
|
|
295
|
+
self._logger.debug("code-only message received; returning guidance without invoking graph")
|
|
296
|
+
return ai_content
|
|
297
|
+
|
|
298
|
+
# Build graph messages (includes tool call/observation history)
|
|
299
|
+
graph_messages = [*self._graph_messages, HumanMessage(content=user_message)]
|
|
300
|
+
if self.debug:
|
|
301
|
+
self._logger.debug(f"process_message() user: {user_message}")
|
|
302
|
+
self._logger.debug(f"graph message count before invoke: {len(graph_messages)}")
|
|
303
|
+
# Invoke the LangGraph agent
|
|
304
|
+
result = self.agent_executor.invoke(
|
|
305
|
+
{"messages": graph_messages},
|
|
306
|
+
config={"recursion_limit": self.max_iterations},
|
|
227
307
|
)
|
|
228
308
|
|
|
229
|
-
#
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
309
|
+
# Extract the latest AI message from the returned messages
|
|
310
|
+
ai_messages = [m for m in result.get("messages", []) if isinstance(m, AIMessage)]
|
|
311
|
+
ai_content = ai_messages[-1].content if ai_messages else ""
|
|
312
|
+
|
|
313
|
+
# Persist full graph messages for future context
|
|
314
|
+
self._graph_messages = result.get("messages", [])
|
|
315
|
+
if self.debug:
|
|
316
|
+
self._logger.debug(f"graph message count after invoke: {len(self._graph_messages)}")
|
|
317
|
+
|
|
318
|
+
# Add agent response to outward-facing chat history
|
|
319
|
+
self.chat_history.append(AIMessage(content=ai_content))
|
|
320
|
+
|
|
321
|
+
# If the agent didn't execute the code via tool, but we have prior generated_code, execute it
|
|
322
|
+
if self.execution_env and self.execution_env.fig is None and self.generated_code is not None:
|
|
323
|
+
if self.debug:
|
|
324
|
+
self._logger.debug("executing stored generated_code because no fig exists yet")
|
|
325
|
+
exec_result = self.execution_env.execute_code(self.generated_code)
|
|
326
|
+
if self.debug:
|
|
327
|
+
self._logger.debug(f"execution result success={exec_result.get('success')} error={exec_result.get('error')!r}")
|
|
328
|
+
|
|
329
|
+
# If the assistant returned code in the message, execute it to update the figure
|
|
330
|
+
code_executed = False
|
|
331
|
+
if self.execution_env and isinstance(ai_content, str):
|
|
332
|
+
extracted_code = None
|
|
333
|
+
if "```python" in ai_content:
|
|
334
|
+
parts = ai_content.split("```python", 1)
|
|
335
|
+
extracted_code = parts[1].split("```", 1)[0].strip() if len(parts) > 1 else None
|
|
336
|
+
elif "```" in ai_content:
|
|
337
|
+
# Fallback: extract first generic fenced code block
|
|
338
|
+
parts = ai_content.split("```", 1)
|
|
339
|
+
if len(parts) > 1:
|
|
340
|
+
extracted_code = parts[1].split("```", 1)[0].strip()
|
|
341
|
+
if extracted_code:
|
|
342
|
+
if (self.generated_code or "").strip() != extracted_code:
|
|
343
|
+
self.generated_code = extracted_code
|
|
344
|
+
if self.debug:
|
|
345
|
+
self._logger.debug("executing code extracted from AI message")
|
|
346
|
+
exec_result = self.execution_env.execute_code(extracted_code)
|
|
347
|
+
if self.debug:
|
|
348
|
+
self._logger.debug(f"execution result success={exec_result.get('success')} error={exec_result.get('error')!r}")
|
|
349
|
+
code_executed = True
|
|
350
|
+
|
|
351
|
+
# If still no figure and no code was executed, run one guided retry to force tool usage
|
|
352
|
+
if self.execution_env and self.execution_env.fig is None and not code_executed:
|
|
353
|
+
if self.debug:
|
|
354
|
+
self._logger.debug("guided retry: prompting model to use execute_plotly_code tool")
|
|
355
|
+
guided_messages = [
|
|
356
|
+
*self._graph_messages,
|
|
357
|
+
HumanMessage(
|
|
358
|
+
content=(
|
|
359
|
+
"Please use the execute_plotly_code(generated_code) tool with the FULL code to "
|
|
360
|
+
"create a variable named 'fig', then call does_fig_exist(). Return the final "
|
|
361
|
+
"code in a fenced ```python block."
|
|
362
|
+
)
|
|
363
|
+
),
|
|
364
|
+
]
|
|
365
|
+
retry_result = self.agent_executor.invoke(
|
|
366
|
+
{"messages": guided_messages},
|
|
367
|
+
config={"recursion_limit": max(3, self.max_iterations // 2)},
|
|
368
|
+
)
|
|
369
|
+
self._graph_messages = retry_result.get("messages", [])
|
|
370
|
+
retry_ai_messages = [
|
|
371
|
+
m for m in self._graph_messages if isinstance(m, AIMessage)
|
|
372
|
+
]
|
|
373
|
+
retry_content = retry_ai_messages[-1].content if retry_ai_messages else ""
|
|
374
|
+
if isinstance(retry_content, str):
|
|
375
|
+
if "```python" in retry_content:
|
|
376
|
+
parts = retry_content.split("```python", 1)
|
|
377
|
+
retry_code = (
|
|
378
|
+
parts[1].split("```", 1)[0].strip() if len(parts) > 1 else None
|
|
379
|
+
)
|
|
380
|
+
elif "```" in retry_content:
|
|
381
|
+
parts = retry_content.split("```", 1)
|
|
382
|
+
retry_code = (
|
|
383
|
+
parts[1].split("```", 1)[0].strip() if len(parts) > 1 else None
|
|
384
|
+
)
|
|
385
|
+
else:
|
|
386
|
+
retry_code = None
|
|
387
|
+
if retry_code:
|
|
388
|
+
if (self.generated_code or "").strip() != retry_code:
|
|
389
|
+
self.generated_code = retry_code
|
|
390
|
+
if self.debug:
|
|
391
|
+
self._logger.debug("executing code extracted from guided retry response")
|
|
392
|
+
exec_result = self.execution_env.execute_code(retry_code)
|
|
393
|
+
if self.debug:
|
|
394
|
+
self._logger.debug(f"execution result success={exec_result.get('success')} error={exec_result.get('error')!r}")
|
|
395
|
+
|
|
396
|
+
return ai_content if isinstance(ai_content, str) else str(ai_content)
|
|
245
397
|
|
|
246
398
|
def get_figure(self):
|
|
247
399
|
"""Return the current figure if one exists."""
|
plot_agent/execution.py
CHANGED
|
@@ -1,90 +1,250 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
"""
|
|
2
|
+
This module contains the PlotAgentExecutionEnvironment class, which is used to safely execute LLM‑generated plotting code and capture `fig`.
|
|
3
|
+
|
|
4
|
+
Security features:
|
|
5
|
+
• Only allow imports from a fixed list of packages (pandas, numpy,
|
|
6
|
+
matplotlib, plotly, sklearn, scipy)
|
|
7
|
+
• AST scan rejects any import outside that list and any __dunder__ access
|
|
8
|
+
• Sandbox builtins to include only a minimal safe set + our _safe_import
|
|
9
|
+
• Enforce a 60-second timeout via signal.alarm (main thread only; no timeout is applied in worker threads)
|
|
10
|
+
"""
|
|
11
|
+
import ast
|
|
12
|
+
import builtins
|
|
13
|
+
import signal
|
|
14
|
+
import threading
|
|
3
15
|
import traceback
|
|
16
|
+
from io import StringIO
|
|
17
|
+
import contextlib
|
|
18
|
+
|
|
4
19
|
import pandas as pd
|
|
5
|
-
import plotly.express as px
|
|
6
|
-
import plotly.graph_objects as go
|
|
7
20
|
import numpy as np
|
|
8
21
|
import matplotlib.pyplot as plt
|
|
22
|
+
import plotly.express as px
|
|
23
|
+
import plotly.graph_objects as go
|
|
9
24
|
from plotly.subplots import make_subplots
|
|
10
|
-
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _timeout_handler(signum, frame):
|
|
28
|
+
raise TimeoutError("Code execution timed out")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# List of allowed modules
|
|
32
|
+
_ALLOWED_MODULES = {
|
|
33
|
+
"pandas",
|
|
34
|
+
"numpy",
|
|
35
|
+
"matplotlib",
|
|
36
|
+
"plotly",
|
|
37
|
+
"sklearn",
|
|
38
|
+
"scipy",
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Wrap the real __import__ so only our allowlist can get in
|
|
43
|
+
_orig_import = builtins.__import__
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _safe_import(name, globals=None, locals=None, fromlist=(), level=0):
|
|
47
|
+
"""
|
|
48
|
+
Wrap the real __import__ so only our allowlist can get in
|
|
49
|
+
"""
|
|
50
|
+
root = name.split(".", 1)[0]
|
|
51
|
+
if root in _ALLOWED_MODULES:
|
|
52
|
+
return _orig_import(name, globals, locals, fromlist, level)
|
|
53
|
+
# If the module is not in the allowlist, raise an ImportError
|
|
54
|
+
raise ImportError(f"Import of module '{name}' is not allowed.")
|
|
11
55
|
|
|
12
56
|
|
|
13
57
|
class PlotAgentExecutionEnvironment:
|
|
14
58
|
"""
|
|
15
|
-
Environment to safely execute
|
|
59
|
+
Environment to safely execute LLM‑generated plotting code and capture `fig`.
|
|
16
60
|
|
|
17
|
-
|
|
18
|
-
|
|
61
|
+
Security features:
|
|
62
|
+
• Only allow imports from a fixed list of packages (pandas, numpy,
|
|
63
|
+
matplotlib, plotly, sklearn, scipy)
|
|
64
|
+
• AST scan rejects any import outside that list and any __dunder__ access
|
|
65
|
+
• Sandbox builtins to include only a minimal safe set + our _safe_import
|
|
66
|
+
• Enforce a 60-second timeout via signal.alarm (main thread only; no timeout is applied in worker threads)
|
|
67
|
+
• Capture both stdout & stderr
|
|
68
|
+
• Purge any old `fig` between runs
|
|
19
69
|
"""
|
|
20
70
|
|
|
71
|
+
TIMEOUT_SECONDS = 60
|
|
72
|
+
|
|
73
|
+
# A lean set of builtins, plus our safe-import hook
|
|
74
|
+
_SAFE_BUILTINS = {
|
|
75
|
+
"abs": abs,
|
|
76
|
+
"all": all,
|
|
77
|
+
"any": any,
|
|
78
|
+
"bin": bin,
|
|
79
|
+
"bool": bool,
|
|
80
|
+
"chr": chr,
|
|
81
|
+
"dict": dict,
|
|
82
|
+
"divmod": divmod,
|
|
83
|
+
"enumerate": enumerate,
|
|
84
|
+
"float": float,
|
|
85
|
+
"int": int,
|
|
86
|
+
"len": len,
|
|
87
|
+
"list": list,
|
|
88
|
+
"map": map,
|
|
89
|
+
"max": max,
|
|
90
|
+
"min": min,
|
|
91
|
+
"next": next,
|
|
92
|
+
"pow": pow,
|
|
93
|
+
"print": print,
|
|
94
|
+
"range": range,
|
|
95
|
+
"reversed": reversed,
|
|
96
|
+
"round": round,
|
|
97
|
+
"set": set,
|
|
98
|
+
"str": str,
|
|
99
|
+
"sum": sum,
|
|
100
|
+
"tuple": tuple,
|
|
101
|
+
"zip": zip,
|
|
102
|
+
# basic exceptions so user code can raise/catch
|
|
103
|
+
"BaseException": BaseException,
|
|
104
|
+
"Exception": Exception,
|
|
105
|
+
"ValueError": ValueError,
|
|
106
|
+
"TypeError": TypeError,
|
|
107
|
+
"NameError": NameError,
|
|
108
|
+
"IndexError": IndexError,
|
|
109
|
+
"KeyError": KeyError,
|
|
110
|
+
# our import guard
|
|
111
|
+
"__import__": _safe_import,
|
|
112
|
+
}
|
|
113
|
+
|
|
21
114
|
def __init__(self, df: pd.DataFrame):
|
|
22
115
|
"""
|
|
23
|
-
Initialize the execution environment with
|
|
24
|
-
|
|
25
|
-
Args:
|
|
26
|
-
df (pd.DataFrame): The dataframe to use for the execution environment.
|
|
116
|
+
Initialize the execution environment with a dataframe.
|
|
27
117
|
"""
|
|
28
118
|
self.df = df
|
|
29
|
-
|
|
119
|
+
# Base namespace for both globals & locals
|
|
120
|
+
self._base_ns = {
|
|
121
|
+
"__builtins__": self._SAFE_BUILTINS,
|
|
30
122
|
"df": df,
|
|
31
|
-
"px": px,
|
|
32
|
-
"go": go,
|
|
33
123
|
"pd": pd,
|
|
34
124
|
"np": np,
|
|
35
125
|
"plt": plt,
|
|
126
|
+
"px": px,
|
|
127
|
+
"go": go,
|
|
36
128
|
"make_subplots": make_subplots,
|
|
37
129
|
}
|
|
38
|
-
self.output = None
|
|
39
|
-
self.error = None
|
|
40
130
|
self.fig = None
|
|
41
131
|
|
|
42
|
-
def
|
|
132
|
+
def _validate_ast(self, node: ast.AST):
|
|
133
|
+
"""
|
|
134
|
+
Walk the AST and enforce:
|
|
135
|
+
• any Import/ImportFrom must be from _ALLOWED_MODULES
|
|
136
|
+
• no __dunder__ attribute access
|
|
43
137
|
"""
|
|
44
|
-
|
|
138
|
+
# Walk every node of the AST, rejecting disallowed imports and dunder attribute access
|
|
139
|
+
for child in ast.walk(node):
|
|
140
|
+
# Check for imports
|
|
141
|
+
if isinstance(child, ast.Import):
|
|
142
|
+
# Check for imports
|
|
143
|
+
for alias in child.names:
|
|
144
|
+
root = alias.name.split(".", 1)[0]
|
|
145
|
+
# Check if the module is in the allowlist
|
|
146
|
+
if root not in _ALLOWED_MODULES:
|
|
147
|
+
raise ValueError(f"Import of '{alias.name}' is not allowed.")
|
|
148
|
+
# Check for import-froms
|
|
149
|
+
elif isinstance(child, ast.ImportFrom):
|
|
150
|
+
root = (child.module or "").split(".", 1)[0]
|
|
151
|
+
if root not in _ALLOWED_MODULES:
|
|
152
|
+
raise ValueError(f"Import-from of '{child.module}' is not allowed.")
|
|
153
|
+
# Check for dunder attribute access
|
|
154
|
+
elif isinstance(child, ast.Attribute) and child.attr.startswith("__"):
|
|
155
|
+
raise ValueError("Access to dunder attributes is forbidden.")
|
|
45
156
|
|
|
46
|
-
|
|
47
|
-
|
|
157
|
+
def execute_code(self, generated_code: str):
|
|
158
|
+
"""
|
|
159
|
+
Execute the user code in a locked‑down sandbox.
|
|
48
160
|
|
|
49
|
-
Returns:
|
|
50
|
-
|
|
161
|
+
Returns a dict with:
|
|
162
|
+
- fig: The figure if created, else None
|
|
163
|
+
- output: Captured stdout when execution fails or no fig is created; a fixed success message otherwise
|
|
164
|
+
- error: Rejection, exception (with traceback) or missing-fig message; empty string on success
|
|
165
|
+
- success: True if fig was produced and no errors
|
|
51
166
|
"""
|
|
52
|
-
self.output = None
|
|
53
|
-
self.error = None
|
|
54
167
|
|
|
55
|
-
#
|
|
56
|
-
|
|
57
|
-
|
|
168
|
+
# Copy the base namespace
|
|
169
|
+
ns = self._base_ns.copy()
|
|
170
|
+
# Purge any old `fig`
|
|
171
|
+
ns.pop("fig", None)
|
|
58
172
|
|
|
59
173
|
try:
|
|
60
|
-
#
|
|
61
|
-
|
|
174
|
+
# Parse the generated code
|
|
175
|
+
tree = ast.parse(generated_code)
|
|
176
|
+
# Validate the AST
|
|
177
|
+
self._validate_ast(tree)
|
|
178
|
+
except Exception as e:
|
|
179
|
+
# If the code fails to parse or is rejected by the AST safety check, return an error
|
|
180
|
+
return {
|
|
181
|
+
"fig": None,
|
|
182
|
+
"output": "",
|
|
183
|
+
"error": f"Code rejected on safety grounds: {e}",
|
|
184
|
+
"success": False,
|
|
185
|
+
}
|
|
62
186
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
187
|
+
# Set a timeout only if running on the main thread; signals are not supported in worker threads
|
|
188
|
+
timeout_set = False
|
|
189
|
+
try:
|
|
190
|
+
if threading.current_thread() is threading.main_thread():
|
|
191
|
+
signal.signal(signal.SIGALRM, _timeout_handler)
|
|
192
|
+
signal.alarm(self.TIMEOUT_SECONDS)
|
|
193
|
+
timeout_set = True
|
|
194
|
+
except Exception:
|
|
195
|
+
# If setting the signal handler fails (e.g., not in main thread), proceed without timeout
|
|
196
|
+
timeout_set = False
|
|
70
197
|
|
|
198
|
+
# Execute the code
|
|
199
|
+
out_buf, err_buf = StringIO(), StringIO()
|
|
200
|
+
try:
|
|
201
|
+
# Redirect stdout and stderr
|
|
202
|
+
with contextlib.redirect_stdout(out_buf), contextlib.redirect_stderr(
|
|
203
|
+
err_buf
|
|
204
|
+
):
|
|
205
|
+
# Execute the code
|
|
206
|
+
exec(generated_code, ns, ns)
|
|
207
|
+
except TimeoutError as te:
|
|
208
|
+
# If the code execution timed out, return an error
|
|
209
|
+
tb = traceback.format_exc()
|
|
210
|
+
return {
|
|
211
|
+
"fig": None,
|
|
212
|
+
"output": out_buf.getvalue(),
|
|
213
|
+
"error": f"Code execution timed out: {te}\n{tb}",
|
|
214
|
+
"success": False,
|
|
215
|
+
}
|
|
71
216
|
except Exception as e:
|
|
72
|
-
|
|
73
|
-
|
|
217
|
+
# If there was an error, return an error
|
|
218
|
+
tb = traceback.format_exc()
|
|
219
|
+
return {
|
|
220
|
+
"fig": None,
|
|
221
|
+
"output": out_buf.getvalue(),
|
|
222
|
+
"error": f"Error executing code: {e}\n{tb}",
|
|
223
|
+
"success": False,
|
|
224
|
+
}
|
|
74
225
|
finally:
|
|
75
|
-
#
|
|
76
|
-
|
|
77
|
-
|
|
226
|
+
# Reset the timeout if it was set
|
|
227
|
+
if timeout_set:
|
|
228
|
+
try:
|
|
229
|
+
signal.alarm(0)
|
|
230
|
+
except Exception:
|
|
231
|
+
pass
|
|
78
232
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
233
|
+
# Get the `fig`
|
|
234
|
+
fig = ns.get("fig")
|
|
235
|
+
self.fig = fig
|
|
236
|
+
if fig is None:
|
|
237
|
+
return {
|
|
238
|
+
"fig": None,
|
|
239
|
+
"output": out_buf.getvalue(),
|
|
240
|
+
"error": "No `fig` created. Assign your figure to a variable named `fig`.",
|
|
241
|
+
"success": False,
|
|
242
|
+
}
|
|
84
243
|
|
|
244
|
+
# Return the result
|
|
85
245
|
return {
|
|
86
|
-
"fig":
|
|
87
|
-
"output":
|
|
88
|
-
"error":
|
|
89
|
-
"success":
|
|
246
|
+
"fig": fig,
|
|
247
|
+
"output": "Code executed successfully. 'fig' object was created.",
|
|
248
|
+
"error": "",
|
|
249
|
+
"success": True,
|
|
90
250
|
}
|
plot_agent/models.py
CHANGED
plot_agent/prompt.py
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the prompts for the PlotAgent.
|
|
3
|
+
"""
|
|
4
|
+
|
|
1
5
|
DEFAULT_SYSTEM_PROMPT = """
|
|
2
6
|
You are an expert data visualization assistant that helps users create Plotly visualizations in Python.
|
|
3
7
|
Your job is to generate Python and Plotly code based on the user's request that will create the desired visualization
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: plot-agent
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.4.0
|
|
4
4
|
Summary: An AI-powered data visualization assistant using Plotly
|
|
5
5
|
Author-email: andrewm4894 <andrewm4894@gmail.com>
|
|
6
6
|
Project-URL: Homepage, https://github.com/andrewm4894/plot-agent
|
|
@@ -19,6 +19,8 @@ Dynamic: license-file
|
|
|
19
19
|
|
|
20
20
|
An AI-powered data visualization assistant that helps users create Plotly visualizations in Python.
|
|
21
21
|
|
|
22
|
+
Built on LangGraph with tool-calling to reliably execute generated Plotly code in a sandbox and keep the current `fig` in sync.
|
|
23
|
+
|
|
22
24
|
## Installation
|
|
23
25
|
|
|
24
26
|
You can install the package using pip:
|
|
@@ -37,7 +39,7 @@ Here's a simple minimal example of how to use Plot Agent:
|
|
|
37
39
|
import pandas as pd
|
|
38
40
|
from plot_agent.agent import PlotAgent
|
|
39
41
|
|
|
40
|
-
# ensure OPENAI_API_KEY is set
|
|
42
|
+
# ensure OPENAI_API_KEY is set (env or .env); optional debug via PLOT_AGENT_DEBUG=1
|
|
41
43
|
|
|
42
44
|
# Create a sample dataframe
|
|
43
45
|
df = pd.DataFrame({
|
|
@@ -92,19 +94,72 @@ fig.update_layout(
|
|
|
92
94
|
)
|
|
93
95
|
```
|
|
94
96
|
|
|
97
|
+
## How it works
|
|
98
|
+
|
|
99
|
+
```mermaid
|
|
100
|
+
flowchart TD
|
|
101
|
+
A[User message] --> B{LangGraph ReAct Agent}
|
|
102
|
+
subgraph Tools
|
|
103
|
+
T1[execute_plotly_code<br/>- runs code in sandbox<br/>- returns success/fig/error]
|
|
104
|
+
T2[does_fig_exist]
|
|
105
|
+
T3[view_generated_code]
|
|
106
|
+
end
|
|
107
|
+
B -- tool call --> T1
|
|
108
|
+
T1 -- result --> B
|
|
109
|
+
B -- optional --> T2
|
|
110
|
+
B -- optional --> T3
|
|
111
|
+
B --> C[AI response]
|
|
112
|
+
C --> D{Agent wrapper}
|
|
113
|
+
D -- persist messages --> B
|
|
114
|
+
D -- extract code blocks --> E[Sandbox execution]
|
|
115
|
+
E --> F[fig]
|
|
116
|
+
F --> G[get_figure]
|
|
117
|
+
```
|
|
118
|
+
|
|
119
|
+
- The LangGraph agent plans and decides when to call tools.
|
|
120
|
+
- The wrapper persists full graph messages between turns and executes any returned code blocks to keep `fig` updated.
|
|
121
|
+
- A safe execution environment runs code with an allowlist and a main-thread-only timeout.
|
|
122
|
+
|
|
95
123
|
## Features
|
|
96
124
|
|
|
97
125
|
- AI-powered visualization generation
|
|
98
126
|
- Support for various Plotly chart types
|
|
99
127
|
- Automatic data preprocessing
|
|
100
128
|
- Interactive visualization capabilities
|
|
101
|
-
-
|
|
129
|
+
- LangGraph-based tool calling and control flow
|
|
130
|
+
- Debug logging via `PlotAgent(debug=True)` or `PLOT_AGENT_DEBUG=1`
|
|
102
131
|
|
|
103
132
|
## Requirements
|
|
104
133
|
|
|
105
134
|
- Python 3.8 or higher
|
|
106
135
|
- Dependencies are automatically installed with the package
|
|
107
136
|
|
|
137
|
+
## Development
|
|
138
|
+
|
|
139
|
+
- Run unit tests:
|
|
140
|
+
|
|
141
|
+
```bash
|
|
142
|
+
make test
|
|
143
|
+
```
|
|
144
|
+
|
|
145
|
+
- Execute all example notebooks:
|
|
146
|
+
|
|
147
|
+
```bash
|
|
148
|
+
make run-examples
|
|
149
|
+
```
|
|
150
|
+
|
|
151
|
+
- Execute with debug logs enabled:
|
|
152
|
+
|
|
153
|
+
```bash
|
|
154
|
+
make run-examples-debug
|
|
155
|
+
```
|
|
156
|
+
|
|
157
|
+
- Quick CLI repro that prints evolving code each step:
|
|
158
|
+
|
|
159
|
+
```bash
|
|
160
|
+
make run-example-script
|
|
161
|
+
```
|
|
162
|
+
|
|
108
163
|
## License
|
|
109
164
|
|
|
110
165
|
This project is licensed under the MIT License - see the LICENSE file for details.
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
plot_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
plot_agent/agent.py,sha256=9N33lBkXe3PSYP6SsQkX2X2CAu7n2UjkRZap8hgxwzo,16733
|
|
3
|
+
plot_agent/execution.py,sha256=FaMWKyFxQewrVx5tpHKLI7WafE9Q2ogvXhOVYZ4G3hw,8086
|
|
4
|
+
plot_agent/models.py,sha256=THdGGGfGmRZ5rtgXvjPcQxFRRTZVFoADEHI_lsMVha8,860
|
|
5
|
+
plot_agent/prompt.py,sha256=5hBlF7jdMrj6MiGEL7YmSDWFUfiCXyIZfZtf3NstKoo,3125
|
|
6
|
+
plot_agent-0.4.0.dist-info/licenses/LICENSE,sha256=A4DPih7wHrh4VMEG3p1PhorqdhjmGIo8nQdYNQL7daA,1062
|
|
7
|
+
plot_agent-0.4.0.dist-info/METADATA,sha256=iNH_2qd_k_Jp3RyWiFRitY4Vwx4ZHFNwMSFFodzMcfE,4152
|
|
8
|
+
plot_agent-0.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
9
|
+
plot_agent-0.4.0.dist-info/top_level.txt,sha256=KyOjpihUssx26Ra-37vKUQ71pI2qgJsHaRwXHJUhjzQ,11
|
|
10
|
+
plot_agent-0.4.0.dist-info/RECORD,,
|
|
@@ -1,10 +0,0 @@
|
|
|
1
|
-
plot_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
plot_agent/agent.py,sha256=FPHma_u_77y3c5nBrCPS0VR8AWz-NFFRqURn0xHph-M,9584
|
|
3
|
-
plot_agent/execution.py,sha256=Ewnz6Pb2EWM0pDAmeVoJ5RmhY5gmO1F7iOvERkj2D28,2770
|
|
4
|
-
plot_agent/models.py,sha256=MkAaSELr54RfGRONKfiqCRA2ghlRbzTe5L5krVuqMsk,800
|
|
5
|
-
plot_agent/prompt.py,sha256=HjRgbsAe8HHs8arQogvzOGQdThEWKRqQhtQyaUplxhQ,3064
|
|
6
|
-
plot_agent-0.3.0.dist-info/licenses/LICENSE,sha256=A4DPih7wHrh4VMEG3p1PhorqdhjmGIo8nQdYNQL7daA,1062
|
|
7
|
-
plot_agent-0.3.0.dist-info/METADATA,sha256=Cqw5gPpT2OM0fPaV_XLFJChEjsAFrVKfiwvxWxdVphM,2837
|
|
8
|
-
plot_agent-0.3.0.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
|
|
9
|
-
plot_agent-0.3.0.dist-info/top_level.txt,sha256=KyOjpihUssx26Ra-37vKUQ71pI2qgJsHaRwXHJUhjzQ,11
|
|
10
|
-
plot_agent-0.3.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|