plot-agent 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
plot_agent/agent.py CHANGED
@@ -1,3 +1,7 @@
1
+ """
2
+ This module contains the PlotAgent class, which is used to generate Plotly code based on a user's plot description.
3
+ """
4
+
1
5
  import pandas as pd
2
6
  from io import StringIO
3
7
  from typing import Optional
@@ -14,10 +18,10 @@ from plot_agent.models import (
14
18
  DoesFigExistInput,
15
19
  ViewGeneratedCodeInput,
16
20
  )
17
- from plot_agent.execution import PlotlyAgentExecutionEnvironment
21
+ from plot_agent.execution import PlotAgentExecutionEnvironment
18
22
 
19
23
 
20
- class PlotlyAgent:
24
+ class PlotAgent:
21
25
  """
22
26
  A class that uses an LLM to generate Plotly code based on a user's plot description.
23
27
  """
@@ -32,7 +36,7 @@ class PlotlyAgent:
32
36
  handle_parsing_errors: bool = True,
33
37
  ):
34
38
  """
35
- Initialize the PlotlyAgent.
39
+ Initialize the PlotAgent.
36
40
 
37
41
  Args:
38
42
  model (str): The model to use for the LLM.
@@ -90,7 +94,7 @@ class PlotlyAgent:
90
94
  self.sql_query = sql_query
91
95
 
92
96
  # Initialize execution environment
93
- self.execution_env = PlotlyAgentExecutionEnvironment(df)
97
+ self.execution_env = PlotAgentExecutionEnvironment(df)
94
98
 
95
99
  # Initialize the agent with tools
96
100
  self._initialize_agent()
@@ -123,7 +127,7 @@ class PlotlyAgent:
123
127
 
124
128
  # Check if the code executed successfully
125
129
  if code_execution_success:
126
- return f"Code executed successfully! A figure object was created.\n{code_execution_output}"
130
+ return f"Success: {code_execution_output}"
127
131
  else:
128
132
  return f"Error: {code_execution_error}\n{code_execution_output}"
129
133
 
@@ -160,19 +164,29 @@ class PlotlyAgent:
160
164
  Tool.from_function(
161
165
  func=self.execute_plotly_code,
162
166
  name="execute_plotly_code",
163
- description="Execute the provided Plotly code and return a result indicating if the code executed successfully and if a figure object was created.",
167
+ description=(
168
+ "Execute the provided Plotly code and return a result indicating "
169
+ "if the code executed successfully and if a figure object was created."
170
+ ),
164
171
  args_schema=GeneratedCodeInput,
165
172
  ),
166
173
  StructuredTool.from_function(
167
174
  func=self.does_fig_exist,
168
175
  name="does_fig_exist",
169
- description="Check if a figure exists and is available for display. This tool takes no arguments and returns a string indicating if a figure is available for display or not.",
176
+ description=(
177
+ "Check if a figure exists and is available for display. "
178
+ "This tool takes no arguments and returns a string indicating "
179
+ "if a figure is available for display or not."
180
+ ),
170
181
  args_schema=DoesFigExistInput,
171
182
  ),
172
183
  StructuredTool.from_function(
173
184
  func=self.view_generated_code,
174
185
  name="view_generated_code",
175
- description="View the generated code. This tool takes no arguments and returns the generated code as a string.",
186
+ description=(
187
+ "View the generated code. "
188
+ "This tool takes no arguments and returns the generated code as a string."
189
+ ),
176
190
  args_schema=ViewGeneratedCodeInput,
177
191
  ),
178
192
  ]
plot_agent/execution.py CHANGED
@@ -1,90 +1,238 @@
1
- import sys
2
- from io import StringIO
1
+ """
2
+ This module contains the PlotAgentExecutionEnvironment class, which is used to safely execute LLM‑generated plotting code and capture `fig`.
3
+
4
+ Security features:
5
+ • Only allow imports from a fixed list of packages (pandas, numpy,
6
+ matplotlib, plotly, sklearn)
7
+ • AST scan rejects any import outside that list and any __dunder__ access
8
+ • Sandbox builtins to include only a minimal safe set + our _safe_import
9
+ • Enforce a 60 second timeout via signal.alarm
10
+ """
11
+ import ast
12
+ import builtins
13
+ import signal
3
14
  import traceback
15
+ from io import StringIO
16
+ import contextlib
17
+
4
18
  import pandas as pd
5
- import plotly.express as px
6
- import plotly.graph_objects as go
7
19
  import numpy as np
8
20
  import matplotlib.pyplot as plt
21
+ import plotly.express as px
22
+ import plotly.graph_objects as go
9
23
  from plotly.subplots import make_subplots
10
- from typing import Dict, Any
11
24
 
12
25
 
13
- class PlotlyAgentExecutionEnvironment:
26
+ def _timeout_handler(signum, frame):
27
+ raise TimeoutError("Code execution timed out")
28
+
29
+
30
+ # List of allowed modules
31
+ _ALLOWED_MODULES = {
32
+ "pandas",
33
+ "numpy",
34
+ "matplotlib",
35
+ "plotly",
36
+ "sklearn",
37
+ "scipy",
38
+ }
39
+
40
+
41
+ # Wrap the real __import__ so only our allowlist can get in
42
+ _orig_import = builtins.__import__
43
+
44
+
45
+ def _safe_import(name, globals=None, locals=None, fromlist=(), level=0):
46
+ """
47
+ Wrap the real __import__ so only our allowlist can get in
48
+ """
49
+ root = name.split(".", 1)[0]
50
+ if root in _ALLOWED_MODULES:
51
+ return _orig_import(name, globals, locals, fromlist, level)
52
+ # If the module is not in the allowlist, raise an ImportError
53
+ raise ImportError(f"Import of module '{name}' is not allowed.")
54
+
55
+
56
+ class PlotAgentExecutionEnvironment:
14
57
  """
15
- Environment to safely execute plotly code and capture the fig object.
58
+ Environment to safely execute LLM‑generated plotting code and capture `fig`.
16
59
 
17
- Args:
18
- df (pd.DataFrame): The dataframe to use for the execution environment.
60
+ Security features:
61
+ Only allow imports from a fixed list of packages (pandas, numpy,
62
+ matplotlib, plotly, sklearn)
63
+ • AST scan rejects any import outside that list and any __dunder__ access
64
+ • Sandbox builtins to include only a minimal safe set + our _safe_import
65
+ • Enforce a 60 second timeout via signal.alarm
66
+ • Capture both stdout & stderr
67
+ • Purge any old `fig` between runs
19
68
  """
20
69
 
70
+ TIMEOUT_SECONDS = 60
71
+
72
+ # A lean set of builtins, plus our safe-import hook
73
+ _SAFE_BUILTINS = {
74
+ "abs": abs,
75
+ "all": all,
76
+ "any": any,
77
+ "bin": bin,
78
+ "bool": bool,
79
+ "chr": chr,
80
+ "dict": dict,
81
+ "divmod": divmod,
82
+ "enumerate": enumerate,
83
+ "float": float,
84
+ "int": int,
85
+ "len": len,
86
+ "list": list,
87
+ "map": map,
88
+ "max": max,
89
+ "min": min,
90
+ "next": next,
91
+ "pow": pow,
92
+ "print": print,
93
+ "range": range,
94
+ "reversed": reversed,
95
+ "round": round,
96
+ "set": set,
97
+ "str": str,
98
+ "sum": sum,
99
+ "tuple": tuple,
100
+ "zip": zip,
101
+ # basic exceptions so user code can raise/catch
102
+ "BaseException": BaseException,
103
+ "Exception": Exception,
104
+ "ValueError": ValueError,
105
+ "TypeError": TypeError,
106
+ "NameError": NameError,
107
+ "IndexError": IndexError,
108
+ "KeyError": KeyError,
109
+ # our import guard
110
+ "__import__": _safe_import,
111
+ }
112
+
21
113
  def __init__(self, df: pd.DataFrame):
22
114
  """
23
- Initialize the execution environment with the given dataframe.
24
-
25
- Args:
26
- df (pd.DataFrame): The dataframe to use for the execution environment.
115
+ Initialize the execution environment with a dataframe.
27
116
  """
28
117
  self.df = df
29
- self.locals_dict = {
118
+ # Base namespace for both globals & locals
119
+ self._base_ns = {
120
+ "__builtins__": self._SAFE_BUILTINS,
30
121
  "df": df,
31
- "px": px,
32
- "go": go,
33
122
  "pd": pd,
34
123
  "np": np,
35
124
  "plt": plt,
125
+ "px": px,
126
+ "go": go,
36
127
  "make_subplots": make_subplots,
37
128
  }
38
- self.output = None
39
- self.error = None
40
129
  self.fig = None
41
130
 
42
- def execute_code(self, generated_code: str) -> Dict[str, Any]:
131
+ def _validate_ast(self, node: ast.AST):
43
132
  """
44
- Execute the provided code and capture the fig object if created.
133
+ Walk the AST and enforce:
134
+ • any Import/ImportFrom must be from _ALLOWED_MODULES
135
+ • no __dunder__ attribute access
136
+ """
137
+ # Walk the AST and enforce:
138
+ for child in ast.walk(node):
139
+ # Check for imports
140
+ if isinstance(child, ast.Import):
141
+ # Check for imports
142
+ for alias in child.names:
143
+ root = alias.name.split(".", 1)[0]
144
+ # Check if the module is in the allowlist
145
+ if root not in _ALLOWED_MODULES:
146
+ raise ValueError(f"Import of '{alias.name}' is not allowed.")
147
+ # Check for import-froms
148
+ elif isinstance(child, ast.ImportFrom):
149
+ root = (child.module or "").split(".", 1)[0]
150
+ if root not in _ALLOWED_MODULES:
151
+ raise ValueError(f"Import-from of '{child.module}' is not allowed.")
152
+ # Check for dunder attribute access
153
+ elif isinstance(child, ast.Attribute) and child.attr.startswith("__"):
154
+ raise ValueError("Access to dunder attributes is forbidden.")
45
155
 
46
- Args:
47
- generated_code (str): The code to execute.
156
+ def execute_code(self, generated_code: str):
157
+ """
158
+ Execute the user code in a locked‑down sandbox.
48
159
 
49
- Returns:
50
- Dict[str, Any]: A dictionary containing the fig object, output, error, and success status.
160
+ Returns a dict with:
161
+ - fig: The figure if created, else None
162
+ - output: Captured stdout
163
+ - error: Captured stderr or exception text
164
+ - success: True if fig was produced and no errors
51
165
  """
52
- self.output = None
53
- self.error = None
54
166
 
55
- # Capture stdout
56
- old_stdout = sys.stdout
57
- sys.stdout = mystdout = StringIO()
167
+ # Copy the base namespace
168
+ ns = self._base_ns.copy()
169
+ # Purge any old `fig`
170
+ ns.pop("fig", None)
58
171
 
59
172
  try:
60
- # Execute the code
61
- exec(generated_code, globals(), self.locals_dict)
173
+ # Parse the generated code
174
+ tree = ast.parse(generated_code)
175
+ # Validate the AST
176
+ self._validate_ast(tree)
177
+ except Exception as e:
178
+ # If the code is rejected on safety grounds, return an error
179
+ return {
180
+ "fig": None,
181
+ "output": "",
182
+ "error": f"Code rejected on safety grounds: {e}",
183
+ "success": False,
184
+ }
62
185
 
63
- # Check if a fig object was created
64
- if "fig" in self.locals_dict:
65
- self.fig = self.locals_dict["fig"]
66
- self.output = "Code executed successfully. 'fig' object was created."
67
- else:
68
- print(f"no fig object created: {generated_code}")
69
- self.error = "Code executed without errors, but no 'fig' object was created. Make sure your code creates a variable named 'fig'."
186
+ # Set a timeout
187
+ signal.signal(signal.SIGALRM, _timeout_handler)
188
+ signal.alarm(self.TIMEOUT_SECONDS)
70
189
 
190
+ # Execute the code
191
+ out_buf, err_buf = StringIO(), StringIO()
192
+ try:
193
+ # Redirect stdout and stderr
194
+ with contextlib.redirect_stdout(out_buf), contextlib.redirect_stderr(
195
+ err_buf
196
+ ):
197
+ # Execute the code
198
+ exec(generated_code, ns, ns)
199
+ except TimeoutError as te:
200
+ # If the code execution timed out, return an error
201
+ tb = traceback.format_exc()
202
+ return {
203
+ "fig": None,
204
+ "output": out_buf.getvalue(),
205
+ "error": f"Code execution timed out: {te}\n{tb}",
206
+ "success": False,
207
+ }
71
208
  except Exception as e:
72
- self.error = f"Error executing code: {str(e)}\n{traceback.format_exc()}"
73
-
209
+ # If there was an error, return an error
210
+ tb = traceback.format_exc()
211
+ return {
212
+ "fig": None,
213
+ "output": out_buf.getvalue(),
214
+ "error": f"Error executing code: {e}\n{tb}",
215
+ "success": False,
216
+ }
74
217
  finally:
75
- # Restore stdout
76
- sys.stdout = old_stdout
77
- captured_output = mystdout.getvalue()
218
+ # Reset the timeout
219
+ signal.alarm(0)
78
220
 
79
- if captured_output.strip():
80
- if self.output:
81
- self.output += f"\nOutput:\n{captured_output}"
82
- else:
83
- self.output = f"Output:\n{captured_output}"
221
+ # Get the `fig`
222
+ fig = ns.get("fig")
223
+ self.fig = fig
224
+ if fig is None:
225
+ return {
226
+ "fig": None,
227
+ "output": out_buf.getvalue(),
228
+ "error": "No `fig` created. Assign your figure to a variable named `fig`.",
229
+ "success": False,
230
+ }
84
231
 
232
+ # Return the result
85
233
  return {
86
- "fig": self.fig,
87
- "output": self.output,
88
- "error": self.error,
89
- "success": self.error is None and self.fig is not None,
234
+ "fig": fig,
235
+ "output": "Code executed successfully. 'fig' object was created.",
236
+ "error": "",
237
+ "success": True,
90
238
  }
plot_agent/models.py CHANGED
@@ -1,14 +1,21 @@
1
+ """
2
+ This module contains the models for the PlotAgent.
3
+ """
4
+
1
5
  from pydantic import BaseModel, Field
2
6
 
3
7
 
4
- # Define input schemas for the tools
5
8
  class PlotDescriptionInput(BaseModel):
9
+ """Model indicating that the plot_description function takes a plot_description argument."""
10
+
6
11
  plot_description: str = Field(
7
12
  ..., description="Description of the plot the user wants to create"
8
13
  )
9
14
 
10
15
 
11
16
  class GeneratedCodeInput(BaseModel):
17
+ """Model indicating that the generated_code function takes a generated_code argument."""
18
+
12
19
  generated_code: str = Field(
13
20
  ..., description="Python code that creates a Plotly figure"
14
21
  )
@@ -23,4 +30,4 @@ class DoesFigExistInput(BaseModel):
23
30
  class ViewGeneratedCodeInput(BaseModel):
24
31
  """Model indicating that the view_generated_code function takes no arguments."""
25
32
 
26
- pass
33
+ pass
plot_agent/prompt.py CHANGED
@@ -1,3 +1,7 @@
1
+ """
2
+ This module contains the prompts for the PlotAgent.
3
+ """
4
+
1
5
  DEFAULT_SYSTEM_PROMPT = """
2
6
  You are an expert data visualization assistant that helps users create Plotly visualizations in Python.
3
7
  Your job is to generate Python and Plotly code based on the user's request that will create the desired visualization
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plot-agent
3
- Version: 0.2.2
3
+ Version: 0.3.1
4
4
  Summary: An AI-powered data visualization assistant using Plotly
5
5
  Author-email: andrewm4894 <andrewm4894@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/andrewm4894/plot-agent
@@ -35,7 +35,7 @@ Here's a simple minimal example of how to use Plot Agent:
35
35
 
36
36
  ```python
37
37
  import pandas as pd
38
- from plot_agent.agent import PlotlyAgent
38
+ from plot_agent.agent import PlotAgent
39
39
 
40
40
  # ensure OPENAI_API_KEY is set and available for langchain
41
41
 
@@ -46,7 +46,7 @@ df = pd.DataFrame({
46
46
  })
47
47
 
48
48
  # Initialize the agent
49
- agent = PlotlyAgent()
49
+ agent = PlotAgent()
50
50
 
51
51
  # Set the dataframe
52
52
  agent.set_df(df)
@@ -0,0 +1,10 @@
1
+ plot_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ plot_agent/agent.py,sha256=sIG8GMS2A8TP_3kRxbgefn-yZM4K_7niZQR6cJrhl4s,9872
3
+ plot_agent/execution.py,sha256=lQNyPzphPIdMQXxQkaf_g6oDZsU3dgF0or0ysKJm6FM,7537
4
+ plot_agent/models.py,sha256=THdGGGfGmRZ5rtgXvjPcQxFRRTZVFoADEHI_lsMVha8,860
5
+ plot_agent/prompt.py,sha256=5hBlF7jdMrj6MiGEL7YmSDWFUfiCXyIZfZtf3NstKoo,3125
6
+ plot_agent-0.3.1.dist-info/licenses/LICENSE,sha256=A4DPih7wHrh4VMEG3p1PhorqdhjmGIo8nQdYNQL7daA,1062
7
+ plot_agent-0.3.1.dist-info/METADATA,sha256=zkpeWRWczA_CzH7mahNtEuIvumBOhXaNGTiAcUIOQZQ,2837
8
+ plot_agent-0.3.1.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
9
+ plot_agent-0.3.1.dist-info/top_level.txt,sha256=KyOjpihUssx26Ra-37vKUQ71pI2qgJsHaRwXHJUhjzQ,11
10
+ plot_agent-0.3.1.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- plot_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- plot_agent/agent.py,sha256=mflUj-vE_x9_W7XTJ3-GgYFfvg3uXV3wkrUw7wRPlVM,9592
3
- plot_agent/execution.py,sha256=BBKBVGQDrg7BPWvPbLWOjkAAFRlfyu28QxClXuspd8o,2772
4
- plot_agent/models.py,sha256=ZOWWeYaqmnKJarXYyXnBQQ4nwUe71ae_gMul3bZXaWU,644
5
- plot_agent/prompt.py,sha256=HjRgbsAe8HHs8arQogvzOGQdThEWKRqQhtQyaUplxhQ,3064
6
- plot_agent-0.2.2.dist-info/licenses/LICENSE,sha256=A4DPih7wHrh4VMEG3p1PhorqdhjmGIo8nQdYNQL7daA,1062
7
- plot_agent-0.2.2.dist-info/METADATA,sha256=dLTSBMjvxg0U63r-EbN2tfI7-eZGYexq854L0kN8M6M,2841
8
- plot_agent-0.2.2.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
9
- plot_agent-0.2.2.dist-info/top_level.txt,sha256=KyOjpihUssx26Ra-37vKUQ71pI2qgJsHaRwXHJUhjzQ,11
10
- plot_agent-0.2.2.dist-info/RECORD,,