plot-agent 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
plot_agent/agent.py CHANGED
@@ -1,11 +1,18 @@
1
+ """
2
+ This module contains the PlotAgent class, which is used to generate Plotly code based on a user's plot description.
3
+ """
4
+
1
5
  import pandas as pd
2
6
  from io import StringIO
7
+ import os
8
+ import re
9
+ import logging
3
10
  from typing import Optional
11
+ from dotenv import load_dotenv
4
12
 
5
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
6
13
  from langchain_core.messages import AIMessage, HumanMessage
7
14
  from langchain_core.tools import Tool, StructuredTool
8
- from langchain.agents import AgentExecutor, create_openai_tools_agent
15
+ from langgraph.prebuilt import create_react_agent
9
16
  from langchain_openai import ChatOpenAI
10
17
 
11
18
  from plot_agent.prompt import DEFAULT_SYSTEM_PROMPT
@@ -24,12 +31,16 @@ class PlotAgent:
24
31
 
25
32
  def __init__(
26
33
  self,
27
- model="gpt-4o-mini",
34
+ model: str = "gpt-4o-mini",
28
35
  system_prompt: Optional[str] = None,
29
36
  verbose: bool = True,
30
37
  max_iterations: int = 10,
31
38
  early_stopping_method: str = "force",
32
39
  handle_parsing_errors: bool = True,
40
+ llm_temperature: float = 0.0,
41
+ llm_timeout: int = 60,
42
+ llm_max_retries: int = 1,
43
+ debug: bool = False,
33
44
  ):
34
45
  """
35
46
  Initialize the PlotAgent.
@@ -42,13 +53,43 @@ class PlotAgent:
42
53
  early_stopping_method (str): Method to use for early stopping.
43
54
  handle_parsing_errors (bool): Whether to handle parsing errors gracefully.
44
55
  """
45
- self.llm = ChatOpenAI(model=model)
56
+ # Load .env if present, then require a valid API key
57
+ load_dotenv()
58
+ openai_api_key = os.getenv("OPENAI_API_KEY")
59
+ if not openai_api_key:
60
+ raise RuntimeError(
61
+ "OPENAI_API_KEY is not set. Provide it via environment or a .env file."
62
+ )
63
+ self.debug = debug or os.getenv("PLOT_AGENT_DEBUG") == "1"
64
+
65
+ # Configure logger
66
+ self._logger = logging.getLogger("plot_agent")
67
+ if self.debug:
68
+ self._logger.setLevel(logging.DEBUG)
69
+ if not self._logger.handlers:
70
+ handler = logging.StreamHandler()
71
+ handler.setFormatter(
72
+ logging.Formatter(
73
+ "%(asctime)s [%(levelname)s] %(name)s: %(message)s",
74
+ datefmt="%H:%M:%S",
75
+ )
76
+ )
77
+ self._logger.addHandler(handler)
78
+
79
+ self.llm = ChatOpenAI(
80
+ model=model,
81
+ temperature=llm_temperature,
82
+ timeout=llm_timeout,
83
+ max_retries=llm_max_retries,
84
+ )
46
85
  self.df = None
47
86
  self.df_info = None
48
87
  self.df_head = None
49
88
  self.sql_query = None
50
89
  self.execution_env = None
51
90
  self.chat_history = []
91
+ # Internal graph-native message history, including tool messages
92
+ self._graph_messages = []
52
93
  self.agent_executor = None
53
94
  self.generated_code = None
54
95
  self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
@@ -94,6 +135,10 @@ class PlotAgent:
94
135
 
95
136
  # Initialize the agent with tools
96
137
  self._initialize_agent()
138
+ # Reset graph messages for a fresh session with this dataframe
139
+ self._graph_messages = []
140
+ if self.debug:
141
+ self._logger.debug("set_df() initialized execution environment and graph")
97
142
 
98
143
  def execute_plotly_code(self, generated_code: str) -> str:
99
144
  """
@@ -123,7 +168,7 @@ class PlotAgent:
123
168
 
124
169
  # Check if the code executed successfully
125
170
  if code_execution_success:
126
- return f"Code executed successfully! A figure object was created.\n{code_execution_output}"
171
+ return f"Success: {code_execution_output}"
127
172
  else:
128
173
  return f"Error: {code_execution_error}\n{code_execution_output}"
129
174
 
@@ -150,64 +195,69 @@ class PlotAgent:
150
195
  """
151
196
  View the generated code.
152
197
  """
153
- return self.generated_code
198
+ return self.generated_code or ""
154
199
 
155
200
  def _initialize_agent(self):
156
- """Initialize the LangChain agent with the necessary tools and prompt."""
201
+ """Initialize a LangGraph ReAct agent with tools and keep API compatibility."""
157
202
 
158
203
  # Initialize the tools
159
204
  tools = [
160
205
  Tool.from_function(
161
206
  func=self.execute_plotly_code,
162
207
  name="execute_plotly_code",
163
- description="Execute the provided Plotly code and return a result indicating if the code executed successfully and if a figure object was created.",
208
+ description=(
209
+ "Execute the provided Plotly code and return a result indicating "
210
+ "if the code executed successfully and if a figure object was created."
211
+ ),
164
212
  args_schema=GeneratedCodeInput,
165
213
  ),
166
214
  StructuredTool.from_function(
167
215
  func=self.does_fig_exist,
168
216
  name="does_fig_exist",
169
- description="Check if a figure exists and is available for display. This tool takes no arguments and returns a string indicating if a figure is available for display or not.",
217
+ description=(
218
+ "Check if a figure exists and is available for display. "
219
+ "This tool takes no arguments and returns a string indicating "
220
+ "if a figure is available for display or not."
221
+ ),
170
222
  args_schema=DoesFigExistInput,
171
223
  ),
172
224
  StructuredTool.from_function(
173
225
  func=self.view_generated_code,
174
226
  name="view_generated_code",
175
- description="View the generated code. This tool takes no arguments and returns the generated code as a string.",
227
+ description=(
228
+ "View the generated code. "
229
+ "This tool takes no arguments and returns the generated code as a string."
230
+ ),
176
231
  args_schema=ViewGeneratedCodeInput,
177
232
  ),
178
233
  ]
179
234
 
180
- # Create system prompt with dataframe information
235
+ # Prepare system prompt with dataframe information
181
236
  sql_context = ""
182
237
  if self.sql_query:
183
- sql_context = f"In case it is useful to help with the data understanding, the df was generated using the following SQL query:\n```sql\n{self.sql_query}\n```"
184
-
185
- prompt = ChatPromptTemplate.from_messages(
186
- [
187
- (
188
- "system",
189
- self.system_prompt.format(
190
- df_info=self.df_info,
191
- df_head=self.df_head,
192
- sql_context=sql_context,
193
- ),
194
- ),
195
- MessagesPlaceholder(variable_name="chat_history"),
196
- ("human", "{input}"),
197
- MessagesPlaceholder(variable_name="agent_scratchpad"),
198
- ]
238
+ sql_context = (
239
+ "In case it is useful to help with the data understanding, the df was generated using the following SQL query:\n"
240
+ f"```sql\n{self.sql_query}\n```"
241
+ )
242
+
243
+ # Store formatted system instructions for the graph state modifier
244
+ self._system_message_content = self.system_prompt.format(
245
+ df_info=self.df_info,
246
+ df_head=self.df_head,
247
+ sql_context=sql_context,
199
248
  )
200
249
 
201
- agent = create_openai_tools_agent(self.llm, tools, prompt)
202
- self.agent_executor = AgentExecutor(
203
- agent=agent,
204
- tools=tools,
205
- verbose=self.verbose,
206
- max_iterations=self.max_iterations,
207
- early_stopping_method=self.early_stopping_method,
208
- handle_parsing_errors=self.handle_parsing_errors,
250
+ # Create a ReAct agent graph with the provided tools and system prompt
251
+ self._graph = create_react_agent(
252
+ self.llm,
253
+ tools,
254
+ prompt=self._system_message_content,
255
+ debug=self.debug,
209
256
  )
210
257
 
258
+ # Backwards-compatibility: expose under the old attribute name
259
+ self.agent_executor = self._graph
260
+
211
261
  def process_message(self, user_message: str) -> str:
212
262
  """Process a user message and return the agent's response."""
213
263
  assert isinstance(user_message, str), "The user message must be a string."
@@ -215,33 +265,135 @@ class PlotAgent:
215
265
  if not self.agent_executor:
216
266
  return "Please set a dataframe first using set_df() method."
217
267
 
218
- # Add user message to chat history
268
+ # Add user message to outward-facing chat history
219
269
  self.chat_history.append(HumanMessage(content=user_message))
220
270
 
221
271
  # Reset generated_code
222
272
  self.generated_code = None
223
273
 
224
- # Get response from agent
225
- response = self.agent_executor.invoke(
226
- {"input": user_message, "chat_history": self.chat_history}
274
+ # Short-circuit empty inputs to avoid graph recursion
275
+ if user_message.strip() == "":
276
+ ai_content = (
277
+ "Please provide a non-empty plotting request (e.g., 'scatter x vs y')."
278
+ )
279
+ self.chat_history.append(AIMessage(content=ai_content))
280
+ if self.debug:
281
+ self._logger.debug("empty message received; returning guidance without invoking graph")
282
+ return ai_content
283
+
284
+ # Short-circuit messages that are primarily raw code blocks without a visualization request
285
+ if "```" in user_message and not re.search(
286
+ r"\b(plot|chart|graph|visuali(s|z)e|figure|subplot|heatmap|bar|line|scatter)\b",
287
+ user_message,
288
+ flags=re.IGNORECASE,
289
+ ):
290
+ ai_content = (
291
+ "I see a code snippet. Please describe the visualization you want (e.g., 'line chart of y over x')."
292
+ )
293
+ self.chat_history.append(AIMessage(content=ai_content))
294
+ if self.debug:
295
+ self._logger.debug("code-only message received; returning guidance without invoking graph")
296
+ return ai_content
297
+
298
+ # Build graph messages (includes tool call/observation history)
299
+ graph_messages = [*self._graph_messages, HumanMessage(content=user_message)]
300
+ if self.debug:
301
+ self._logger.debug(f"process_message() user: {user_message}")
302
+ self._logger.debug(f"graph message count before invoke: {len(graph_messages)}")
303
+ # Invoke the LangGraph agent
304
+ result = self.agent_executor.invoke(
305
+ {"messages": graph_messages},
306
+ config={"recursion_limit": self.max_iterations},
227
307
  )
228
308
 
229
- # Add agent response to chat history
230
- self.chat_history.append(AIMessage(content=response["output"]))
231
-
232
- # If the agent didn't execute the code, but did generate code, execute it directly
233
- if self.execution_env.fig is None and self.generated_code is not None:
234
- self.execution_env.execute_code(self.generated_code)
235
-
236
- # If we can extract code from the response when no code was executed, try that too
237
- if self.execution_env.fig is None and "```python" in response["output"]:
238
- code_blocks = response["output"].split("```python")
239
- if len(code_blocks) > 1:
240
- generated_code = code_blocks[1].split("```")[0].strip()
241
- self.execution_env.execute_code(generated_code)
242
-
243
- # Return the agent's response
244
- return response["output"]
309
+ # Extract the latest AI message from the returned messages
310
+ ai_messages = [m for m in result.get("messages", []) if isinstance(m, AIMessage)]
311
+ ai_content = ai_messages[-1].content if ai_messages else ""
312
+
313
+ # Persist full graph messages for future context
314
+ self._graph_messages = result.get("messages", [])
315
+ if self.debug:
316
+ self._logger.debug(f"graph message count after invoke: {len(self._graph_messages)}")
317
+
318
+ # Add agent response to outward-facing chat history
319
+ self.chat_history.append(AIMessage(content=ai_content))
320
+
321
+ # If the agent didn't execute the code via tool, but we have prior generated_code, execute it
322
+ if self.execution_env and self.execution_env.fig is None and self.generated_code is not None:
323
+ if self.debug:
324
+ self._logger.debug("executing stored generated_code because no fig exists yet")
325
+ exec_result = self.execution_env.execute_code(self.generated_code)
326
+ if self.debug:
327
+ self._logger.debug(f"execution result success={exec_result.get('success')} error={exec_result.get('error')!r}")
328
+
329
+ # If the assistant returned code in the message, execute it to update the figure
330
+ code_executed = False
331
+ if self.execution_env and isinstance(ai_content, str):
332
+ extracted_code = None
333
+ if "```python" in ai_content:
334
+ parts = ai_content.split("```python", 1)
335
+ extracted_code = parts[1].split("```", 1)[0].strip() if len(parts) > 1 else None
336
+ elif "```" in ai_content:
337
+ # Fallback: extract first generic fenced code block
338
+ parts = ai_content.split("```", 1)
339
+ if len(parts) > 1:
340
+ extracted_code = parts[1].split("```", 1)[0].strip()
341
+ if extracted_code:
342
+ if (self.generated_code or "").strip() != extracted_code:
343
+ self.generated_code = extracted_code
344
+ if self.debug:
345
+ self._logger.debug("executing code extracted from AI message")
346
+ exec_result = self.execution_env.execute_code(extracted_code)
347
+ if self.debug:
348
+ self._logger.debug(f"execution result success={exec_result.get('success')} error={exec_result.get('error')!r}")
349
+ code_executed = True
350
+
351
+ # If still no figure and no code was executed, run one guided retry to force tool usage
352
+ if self.execution_env and self.execution_env.fig is None and not code_executed:
353
+ if self.debug:
354
+ self._logger.debug("guided retry: prompting model to use execute_plotly_code tool")
355
+ guided_messages = [
356
+ *self._graph_messages,
357
+ HumanMessage(
358
+ content=(
359
+ "Please use the execute_plotly_code(generated_code) tool with the FULL code to "
360
+ "create a variable named 'fig', then call does_fig_exist(). Return the final "
361
+ "code in a fenced ```python block."
362
+ )
363
+ ),
364
+ ]
365
+ retry_result = self.agent_executor.invoke(
366
+ {"messages": guided_messages},
367
+ config={"recursion_limit": max(3, self.max_iterations // 2)},
368
+ )
369
+ self._graph_messages = retry_result.get("messages", [])
370
+ retry_ai_messages = [
371
+ m for m in self._graph_messages if isinstance(m, AIMessage)
372
+ ]
373
+ retry_content = retry_ai_messages[-1].content if retry_ai_messages else ""
374
+ if isinstance(retry_content, str):
375
+ if "```python" in retry_content:
376
+ parts = retry_content.split("```python", 1)
377
+ retry_code = (
378
+ parts[1].split("```", 1)[0].strip() if len(parts) > 1 else None
379
+ )
380
+ elif "```" in retry_content:
381
+ parts = retry_content.split("```", 1)
382
+ retry_code = (
383
+ parts[1].split("```", 1)[0].strip() if len(parts) > 1 else None
384
+ )
385
+ else:
386
+ retry_code = None
387
+ if retry_code:
388
+ if (self.generated_code or "").strip() != retry_code:
389
+ self.generated_code = retry_code
390
+ if self.debug:
391
+ self._logger.debug("executing code extracted from guided retry response")
392
+ exec_result = self.execution_env.execute_code(retry_code)
393
+ if self.debug:
394
+ self._logger.debug(f"execution result success={exec_result.get('success')} error={exec_result.get('error')!r}")
395
+
396
+ return ai_content if isinstance(ai_content, str) else str(ai_content)
245
397
 
246
398
  def get_figure(self):
247
399
  """Return the current figure if one exists."""
plot_agent/execution.py CHANGED
@@ -1,90 +1,250 @@
1
- import sys
2
- from io import StringIO
1
+ """
2
+ This module contains the PlotAgentExecutionEnvironment class, which is used to safely execute LLM‑generated plotting code and capture `fig`.
3
+
4
+ Security features:
5
+ • Only allow imports from a fixed list of packages (pandas, numpy,
6
+ matplotlib, plotly, sklearn, scipy)
7
+ • AST scan rejects any import outside that list and any __dunder__ access
8
+ • Sandbox builtins to include only a minimal safe set + our _safe_import
9
+ • Enforce a 60-second timeout via signal.alarm (main thread only)
10
+ """
11
+ import ast
12
+ import builtins
13
+ import signal
14
+ import threading
3
15
  import traceback
16
+ from io import StringIO
17
+ import contextlib
18
+
4
19
  import pandas as pd
5
- import plotly.express as px
6
- import plotly.graph_objects as go
7
20
  import numpy as np
8
21
  import matplotlib.pyplot as plt
22
+ import plotly.express as px
23
+ import plotly.graph_objects as go
9
24
  from plotly.subplots import make_subplots
10
- from typing import Dict, Any
25
+
26
+
27
def _timeout_handler(signum, frame):
    """Signal handler that aborts the running code by raising ``TimeoutError``.

    Both arguments are supplied by the ``signal`` machinery and are ignored.
    """
    timeout_message = "Code execution timed out"
    raise TimeoutError(timeout_message)
29
+
30
+
31
# Top-level packages that executed code is permitted to import
_ALLOWED_MODULES = {
    "pandas",
    "numpy",
    "matplotlib",
    "plotly",
    "sklearn",
    "scipy",
}


# Keep a handle on the interpreter's real import hook so the guard can delegate to it
_orig_import = builtins.__import__


def _safe_import(name, globals=None, locals=None, fromlist=(), level=0):
    """
    Drop-in replacement for ``__import__`` that enforces the allowlist.

    Only the first dotted component of ``name`` is checked, so any submodule
    of an allowed package is accepted. Allowed requests are forwarded
    unchanged to the real import hook; anything else raises ``ImportError``.
    """
    top_level_package = name.partition(".")[0]
    if top_level_package not in _ALLOWED_MODULES:
        raise ImportError(f"Import of module '{name}' is not allowed.")
    return _orig_import(name, globals, locals, fromlist, level)
11
55
 
12
56
 
13
57
class PlotAgentExecutionEnvironment:
    """
    Environment to execute LLM-generated plotting code and capture `fig`.

    Security features:
      • Only allow imports from a fixed list of packages (pandas, numpy,
        matplotlib, plotly, sklearn, scipy)
      • AST scan rejects any import outside that list and any attribute
        access whose name starts with a double underscore
      • Sandbox builtins to include only a minimal safe set + our
        _safe_import; every run receives its own copy of that mapping
      • Enforce a 60 second timeout via signal.alarm, but only when running
        on the main thread of a platform that provides SIGALRM
      • Redirect both stdout & stderr away from the console while the code runs
      • Start every run from a fresh namespace without a `fig` entry
    """

    TIMEOUT_SECONDS = 60

    # A lean set of builtins, plus our safe-import hook
    _SAFE_BUILTINS = {
        "abs": abs,
        "all": all,
        "any": any,
        "bin": bin,
        "bool": bool,
        "chr": chr,
        "dict": dict,
        "divmod": divmod,
        "enumerate": enumerate,
        "float": float,
        "int": int,
        "len": len,
        "list": list,
        "map": map,
        "max": max,
        "min": min,
        "next": next,
        "pow": pow,
        "print": print,
        "range": range,
        "reversed": reversed,
        "round": round,
        "set": set,
        "str": str,
        "sum": sum,
        "tuple": tuple,
        "zip": zip,
        # basic exceptions so user code can raise/catch
        "BaseException": BaseException,
        "Exception": Exception,
        "ValueError": ValueError,
        "TypeError": TypeError,
        "NameError": NameError,
        "IndexError": IndexError,
        "KeyError": KeyError,
        # our import guard
        "__import__": _safe_import,
    }

    def __init__(self, df: pd.DataFrame):
        """
        Initialize the execution environment with a dataframe.

        Args:
            df (pd.DataFrame): Exposed to the executed code under the name `df`.
        """
        self.df = df
        # Base namespace; each run works on a copy that serves as both globals & locals
        self._base_ns = {
            "__builtins__": self._SAFE_BUILTINS,
            "df": df,
            "pd": pd,
            "np": np,
            "plt": plt,
            "px": px,
            "go": go,
            "make_subplots": make_subplots,
        }
        # Figure captured by the most recent run that reached the `fig` lookup
        self.fig = None

    def _validate_ast(self, node: ast.AST):
        """
        Walk the AST and enforce:
          • any Import/ImportFrom must be from _ALLOWED_MODULES
          • no attribute access whose name starts with a double underscore

        Raises:
            ValueError: On the first node that violates either rule.
        """
        for child in ast.walk(node):
            if isinstance(child, ast.Import):
                for alias in child.names:
                    # Only the top-level package decides whether an import is allowed
                    root = alias.name.split(".", 1)[0]
                    if root not in _ALLOWED_MODULES:
                        raise ValueError(f"Import of '{alias.name}' is not allowed.")
            elif isinstance(child, ast.ImportFrom):
                # `module` is None for `from . import x`; that yields an empty root and is rejected
                root = (child.module or "").split(".", 1)[0]
                if root not in _ALLOWED_MODULES:
                    raise ValueError(f"Import-from of '{child.module}' is not allowed.")
            elif isinstance(child, ast.Attribute) and child.attr.startswith("__"):
                raise ValueError("Access to dunder attributes is forbidden.")

    def execute_code(self, generated_code: str) -> dict:
        """
        Execute the generated code in the restricted namespace.

        Args:
            generated_code (str): Python source expected to assign a figure to `fig`.

        Returns a dict with:
          - fig: The figure if created, else None
          - output: Captured stdout on failure; a fixed success message otherwise
          - error: Rejection, timeout or exception text; "" on success
          - success: True only if the code ran without raising and defined `fig`

        Notes:
          - stderr is redirected while the code runs, but its content is not returned.
          - `self.fig` is updated only when execution finishes without raising;
            a rejected, failed or timed-out run leaves the previous value in place.
        """
        # Fresh namespace per run, guaranteed to start without a `fig`
        ns = self._base_ns.copy()
        ns.pop("fig", None)
        # Give the run a private copy of the builtins mapping: `__builtins__` is reachable
        # by name from the executed code, and mutating the shared class-level dict would
        # otherwise affect every later run and every other environment instance.
        ns["__builtins__"] = dict(ns["__builtins__"])

        try:
            tree = ast.parse(generated_code)
            self._validate_ast(tree)
        except Exception as e:
            # Covers syntax errors as well as allowlist/dunder violations
            return {
                "fig": None,
                "output": "",
                "error": f"Code rejected on safety grounds: {e}",
                "success": False,
            }

        # Arm the timeout only where it can work: signal handlers must be installed from
        # the main thread, and SIGALRM does not exist on every platform.
        handler_installed = False
        timeout_set = False
        previous_handler = None
        if (
            hasattr(signal, "SIGALRM")
            and threading.current_thread() is threading.main_thread()
        ):
            try:
                previous_handler = signal.signal(signal.SIGALRM, _timeout_handler)
                handler_installed = True
                signal.alarm(self.TIMEOUT_SECONDS)
                timeout_set = True
            except (ValueError, OSError):
                # Best effort: proceed without a timeout if the handler cannot be installed
                timeout_set = False

        out_buf, err_buf = StringIO(), StringIO()
        try:
            with contextlib.redirect_stdout(out_buf), contextlib.redirect_stderr(
                err_buf
            ):
                # NOTE(security): exec() of model-generated code. The AST scan and the
                # restricted builtins reduce the attack surface, but this is not a hard
                # security boundary.
                exec(generated_code, ns, ns)
        except TimeoutError as te:
            tb = traceback.format_exc()
            return {
                "fig": None,
                "output": out_buf.getvalue(),
                "error": f"Code execution timed out: {te}\n{tb}",
                "success": False,
            }
        except Exception as e:
            tb = traceback.format_exc()
            return {
                "fig": None,
                "output": out_buf.getvalue(),
                "error": f"Error executing code: {e}\n{tb}",
                "success": False,
            }
        finally:
            # Cancel a pending alarm, then hand SIGALRM back to whoever owned it before,
            # so the host application's signal handling is not permanently replaced.
            if timeout_set:
                signal.alarm(0)
            if handler_installed:
                # signal.signal() reports None for handlers not installed from Python;
                # None cannot be passed back, so fall back to the default action.
                signal.signal(
                    signal.SIGALRM,
                    previous_handler if previous_handler is not None else signal.SIG_DFL,
                )

        # Pick up whatever the code bound to `fig` (None if it never assigned one)
        fig = ns.get("fig")
        self.fig = fig
        if fig is None:
            return {
                "fig": None,
                "output": out_buf.getvalue(),
                "error": "No `fig` created. Assign your figure to a variable named `fig`.",
                "success": False,
            }

        return {
            "fig": fig,
            "output": "Code executed successfully. 'fig' object was created.",
            "error": "",
            "success": True,
        }
plot_agent/models.py CHANGED
@@ -1,3 +1,7 @@
1
+ """
2
+ This module contains the models for the PlotAgent.
3
+ """
4
+
1
5
  from pydantic import BaseModel, Field
2
6
 
3
7
 
plot_agent/prompt.py CHANGED
@@ -1,3 +1,7 @@
1
+ """
2
+ This module contains the prompts for the PlotAgent.
3
+ """
4
+
1
5
  DEFAULT_SYSTEM_PROMPT = """
2
6
  You are an expert data visualization assistant that helps users create Plotly visualizations in Python.
3
7
  Your job is to generate Python and Plotly code based on the user's request that will create the desired visualization
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plot-agent
3
- Version: 0.3.0
3
+ Version: 0.4.0
4
4
  Summary: An AI-powered data visualization assistant using Plotly
5
5
  Author-email: andrewm4894 <andrewm4894@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/andrewm4894/plot-agent
@@ -19,6 +19,8 @@ Dynamic: license-file
19
19
 
20
20
  An AI-powered data visualization assistant that helps users create Plotly visualizations in Python.
21
21
 
22
+ Built on LangGraph with tool-calling to reliably execute generated Plotly code in a sandbox and keep the current `fig` in sync.
23
+
22
24
  ## Installation
23
25
 
24
26
  You can install the package using pip:
@@ -37,7 +39,7 @@ Here's a simple minimal example of how to use Plot Agent:
37
39
  import pandas as pd
38
40
  from plot_agent.agent import PlotAgent
39
41
 
40
- # ensure OPENAI_API_KEY is set and available for langchain
42
+ # ensure OPENAI_API_KEY is set (env or .env); optional debug via PLOT_AGENT_DEBUG=1
41
43
 
42
44
  # Create a sample dataframe
43
45
  df = pd.DataFrame({
@@ -92,19 +94,72 @@ fig.update_layout(
92
94
  )
93
95
  ```
94
96
 
97
+ ## How it works
98
+
99
+ ```mermaid
100
+ flowchart TD
101
+ A[User message] --> B{LangGraph ReAct Agent}
102
+ subgraph Tools
103
+ T1[execute_plotly_code<br/>- runs code in sandbox<br/>- returns success/fig/error]
104
+ T2[does_fig_exist]
105
+ T3[view_generated_code]
106
+ end
107
+ B -- tool call --> T1
108
+ T1 -- result --> B
109
+ B -- optional --> T2
110
+ B -- optional --> T3
111
+ B --> C[AI response]
112
+ C --> D{Agent wrapper}
113
+ D -- persist messages --> B
114
+ D -- extract code blocks --> E[Sandbox execution]
115
+ E --> F[fig]
116
+ F --> G[get_figure]
117
+ ```
118
+
119
+ - The LangGraph agent plans and decides when to call tools.
120
+ - The wrapper persists full graph messages between turns and executes any returned code blocks to keep `fig` updated.
121
+ - A safe execution environment runs code with an allowlist and a main-thread-only timeout.
122
+
95
123
  ## Features
96
124
 
97
125
  - AI-powered visualization generation
98
126
  - Support for various Plotly chart types
99
127
  - Automatic data preprocessing
100
128
  - Interactive visualization capabilities
101
- - Integration with LangChain for advanced AI capabilities
129
+ - LangGraph-based tool calling and control flow
130
+ - Debug logging via `PlotAgent(debug=True)` or `PLOT_AGENT_DEBUG=1`
102
131
 
103
132
  ## Requirements
104
133
 
105
134
  - Python 3.8 or higher
106
135
  - Dependencies are automatically installed with the package
107
136
 
137
+ ## Development
138
+
139
+ - Run unit tests:
140
+
141
+ ```bash
142
+ make test
143
+ ```
144
+
145
+ - Execute all example notebooks:
146
+
147
+ ```bash
148
+ make run-examples
149
+ ```
150
+
151
+ - Execute with debug logs enabled:
152
+
153
+ ```bash
154
+ make run-examples-debug
155
+ ```
156
+
157
+ - Quick CLI repro that prints evolving code each step:
158
+
159
+ ```bash
160
+ make run-example-script
161
+ ```
162
+
108
163
  ## License
109
164
 
110
165
  This project is licensed under the MIT License - see the LICENSE file for details.
@@ -0,0 +1,10 @@
1
+ plot_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ plot_agent/agent.py,sha256=9N33lBkXe3PSYP6SsQkX2X2CAu7n2UjkRZap8hgxwzo,16733
3
+ plot_agent/execution.py,sha256=FaMWKyFxQewrVx5tpHKLI7WafE9Q2ogvXhOVYZ4G3hw,8086
4
+ plot_agent/models.py,sha256=THdGGGfGmRZ5rtgXvjPcQxFRRTZVFoADEHI_lsMVha8,860
5
+ plot_agent/prompt.py,sha256=5hBlF7jdMrj6MiGEL7YmSDWFUfiCXyIZfZtf3NstKoo,3125
6
+ plot_agent-0.4.0.dist-info/licenses/LICENSE,sha256=A4DPih7wHrh4VMEG3p1PhorqdhjmGIo8nQdYNQL7daA,1062
7
+ plot_agent-0.4.0.dist-info/METADATA,sha256=iNH_2qd_k_Jp3RyWiFRitY4Vwx4ZHFNwMSFFodzMcfE,4152
8
+ plot_agent-0.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
9
+ plot_agent-0.4.0.dist-info/top_level.txt,sha256=KyOjpihUssx26Ra-37vKUQ71pI2qgJsHaRwXHJUhjzQ,11
10
+ plot_agent-0.4.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.1)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,10 +0,0 @@
1
- plot_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- plot_agent/agent.py,sha256=FPHma_u_77y3c5nBrCPS0VR8AWz-NFFRqURn0xHph-M,9584
3
- plot_agent/execution.py,sha256=Ewnz6Pb2EWM0pDAmeVoJ5RmhY5gmO1F7iOvERkj2D28,2770
4
- plot_agent/models.py,sha256=MkAaSELr54RfGRONKfiqCRA2ghlRbzTe5L5krVuqMsk,800
5
- plot_agent/prompt.py,sha256=HjRgbsAe8HHs8arQogvzOGQdThEWKRqQhtQyaUplxhQ,3064
6
- plot_agent-0.3.0.dist-info/licenses/LICENSE,sha256=A4DPih7wHrh4VMEG3p1PhorqdhjmGIo8nQdYNQL7daA,1062
7
- plot_agent-0.3.0.dist-info/METADATA,sha256=Cqw5gPpT2OM0fPaV_XLFJChEjsAFrVKfiwvxWxdVphM,2837
8
- plot_agent-0.3.0.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
9
- plot_agent-0.3.0.dist-info/top_level.txt,sha256=KyOjpihUssx26Ra-37vKUQ71pI2qgJsHaRwXHJUhjzQ,11
10
- plot_agent-0.3.0.dist-info/RECORD,,