plot-agent 0.2.2__tar.gz → 0.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plot-agent
3
- Version: 0.2.2
3
+ Version: 0.3.1
4
4
  Summary: An AI-powered data visualization assistant using Plotly
5
5
  Author-email: andrewm4894 <andrewm4894@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/andrewm4894/plot-agent
@@ -35,7 +35,7 @@ Here's a simple minimal example of how to use Plot Agent:
35
35
 
36
36
  ```python
37
37
  import pandas as pd
38
- from plot_agent.agent import PlotlyAgent
38
+ from plot_agent.agent import PlotAgent
39
39
 
40
40
  # ensure OPENAI_API_KEY is set and available for langchain
41
41
 
@@ -46,7 +46,7 @@ df = pd.DataFrame({
46
46
  })
47
47
 
48
48
  # Initialize the agent
49
- agent = PlotlyAgent()
49
+ agent = PlotAgent()
50
50
 
51
51
  # Set the dataframe
52
52
  agent.set_df(df)
@@ -21,7 +21,7 @@ Here's a simple minimal example of how to use Plot Agent:
21
21
 
22
22
  ```python
23
23
  import pandas as pd
24
- from plot_agent.agent import PlotlyAgent
24
+ from plot_agent.agent import PlotAgent
25
25
 
26
26
  # ensure OPENAI_API_KEY is set and available for langchain
27
27
 
@@ -32,7 +32,7 @@ df = pd.DataFrame({
32
32
  })
33
33
 
34
34
  # Initialize the agent
35
- agent = PlotlyAgent()
35
+ agent = PlotAgent()
36
36
 
37
37
  # Set the dataframe
38
38
  agent.set_df(df)
@@ -1,3 +1,7 @@
1
+ """
2
+ This module contains the PlotAgent class, which is used to generate Plotly code based on a user's plot description.
3
+ """
4
+
1
5
  import pandas as pd
2
6
  from io import StringIO
3
7
  from typing import Optional
@@ -14,10 +18,10 @@ from plot_agent.models import (
14
18
  DoesFigExistInput,
15
19
  ViewGeneratedCodeInput,
16
20
  )
17
- from plot_agent.execution import PlotlyAgentExecutionEnvironment
21
+ from plot_agent.execution import PlotAgentExecutionEnvironment
18
22
 
19
23
 
20
- class PlotlyAgent:
24
+ class PlotAgent:
21
25
  """
22
26
  A class that uses an LLM to generate Plotly code based on a user's plot description.
23
27
  """
@@ -32,7 +36,7 @@ class PlotlyAgent:
32
36
  handle_parsing_errors: bool = True,
33
37
  ):
34
38
  """
35
- Initialize the PlotlyAgent.
39
+ Initialize the PlotAgent.
36
40
 
37
41
  Args:
38
42
  model (str): The model to use for the LLM.
@@ -90,7 +94,7 @@ class PlotlyAgent:
90
94
  self.sql_query = sql_query
91
95
 
92
96
  # Initialize execution environment
93
- self.execution_env = PlotlyAgentExecutionEnvironment(df)
97
+ self.execution_env = PlotAgentExecutionEnvironment(df)
94
98
 
95
99
  # Initialize the agent with tools
96
100
  self._initialize_agent()
@@ -123,7 +127,7 @@ class PlotlyAgent:
123
127
 
124
128
  # Check if the code executed successfully
125
129
  if code_execution_success:
126
- return f"Code executed successfully! A figure object was created.\n{code_execution_output}"
130
+ return f"Success: {code_execution_output}"
127
131
  else:
128
132
  return f"Error: {code_execution_error}\n{code_execution_output}"
129
133
 
@@ -160,19 +164,29 @@ class PlotlyAgent:
160
164
  Tool.from_function(
161
165
  func=self.execute_plotly_code,
162
166
  name="execute_plotly_code",
163
- description="Execute the provided Plotly code and return a result indicating if the code executed successfully and if a figure object was created.",
167
+ description=(
168
+ "Execute the provided Plotly code and return a result indicating "
169
+ "if the code executed successfully and if a figure object was created."
170
+ ),
164
171
  args_schema=GeneratedCodeInput,
165
172
  ),
166
173
  StructuredTool.from_function(
167
174
  func=self.does_fig_exist,
168
175
  name="does_fig_exist",
169
- description="Check if a figure exists and is available for display. This tool takes no arguments and returns a string indicating if a figure is available for display or not.",
176
+ description=(
177
+ "Check if a figure exists and is available for display. "
178
+ "This tool takes no arguments and returns a string indicating "
179
+ "if a figure is available for display or not."
180
+ ),
170
181
  args_schema=DoesFigExistInput,
171
182
  ),
172
183
  StructuredTool.from_function(
173
184
  func=self.view_generated_code,
174
185
  name="view_generated_code",
175
- description="View the generated code. This tool takes no arguments and returns the generated code as a string.",
186
+ description=(
187
+ "View the generated code. "
188
+ "This tool takes no arguments and returns the generated code as a string."
189
+ ),
176
190
  args_schema=ViewGeneratedCodeInput,
177
191
  ),
178
192
  ]
@@ -0,0 +1,238 @@
1
+ """
2
+ This module contains the PlotAgentExecutionEnvironment class, which is used to safely execute LLM‑generated plotting code and capture `fig`.
3
+
4
+ Security features:
5
+ • Only allow imports from a fixed list of packages (pandas, numpy,
6
+ matplotlib, plotly, sklearn, scipy)
7
+ • AST scan rejects any import outside that list and any __dunder__ access
8
+ • Sandbox builtins to include only a minimal safe set + our _safe_import
9
+ • Enforce a 60 second timeout via signal.alarm
10
+ """
11
+ import ast
12
+ import builtins
13
+ import signal
14
+ import traceback
15
+ from io import StringIO
16
+ import contextlib
17
+
18
+ import pandas as pd
19
+ import numpy as np
20
+ import matplotlib.pyplot as plt
21
+ import plotly.express as px
22
+ import plotly.graph_objects as go
23
+ from plotly.subplots import make_subplots
24
+
25
+
26
def _timeout_handler(signum, frame):
    """SIGALRM handler: abort the currently running generated code.

    Both arguments are supplied by the signal machinery and are ignored.

    Raises:
        TimeoutError: always.
    """
    message = "Code execution timed out"
    raise TimeoutError(message)
28
+
29
+
30
# Top-level package names that generated code may import. Submodules are
# allowed too, because only the first dotted component is checked.
_ALLOWED_MODULES = {
    "matplotlib",
    "numpy",
    "pandas",
    "plotly",
    "scipy",
    "sklearn",
}


# Keep a handle on the real import machinery; _safe_import delegates to it.
_orig_import = builtins.__import__


def _safe_import(name, globals=None, locals=None, fromlist=(), level=0):
    """
    Allowlisting stand-in for ``__import__``.

    Only the top-level package of ``name`` is checked against
    ``_ALLOWED_MODULES``. Allowed imports are forwarded unchanged, with all
    arguments, to the real ``__import__``.

    Raises:
        ImportError: if the top-level package is not allowlisted.
    """
    package, _, _ = name.partition(".")
    if package not in _ALLOWED_MODULES:
        raise ImportError(f"Import of module '{name}' is not allowed.")
    return _orig_import(name, globals, locals, fromlist, level)
54
+
55
+
56
class PlotAgentExecutionEnvironment:
    """
    Environment to execute LLM-generated plotting code and capture `fig`.

    Guard rails:
    • Only allow imports from a fixed list of packages (pandas, numpy,
      matplotlib, plotly, sklearn, scipy)
    • AST scan rejects any import outside that list and any __dunder__
      attribute access
    • Sandbox builtins to include only a minimal safe set + our _safe_import
      (a fresh copy per run, so one run cannot alter the next)
    • Enforce a TIMEOUT_SECONDS timeout via SIGALRM where possible (platforms
      that have SIGALRM, called from the main thread); elsewhere the code runs
      without a time limit instead of failing
    • Capture both stdout & stderr
    • Purge any old `fig` between runs

    NOTE(review): everything here runs in-process, so it is a best-effort
    guard rail, not a security boundary. For example, the allowed libraries
    can read and write files, and attribute chains through module objects are
    not inspected. Do not rely on it to contain hostile code.
    """

    TIMEOUT_SECONDS = 60

    # A lean set of builtins, plus our safe-import hook
    _SAFE_BUILTINS = {
        "abs": abs,
        "all": all,
        "any": any,
        "bin": bin,
        "bool": bool,
        "chr": chr,
        "dict": dict,
        "divmod": divmod,
        "enumerate": enumerate,
        "filter": filter,
        "float": float,
        "frozenset": frozenset,
        "int": int,
        "isinstance": isinstance,
        "len": len,
        "list": list,
        "map": map,
        "max": max,
        "min": min,
        "next": next,
        "pow": pow,
        "print": print,
        "range": range,
        "reversed": reversed,
        "round": round,
        "set": set,
        "slice": slice,
        "sorted": sorted,
        "str": str,
        "sum": sum,
        "tuple": tuple,
        "zip": zip,
        # basic exceptions so user code can raise/catch
        "BaseException": BaseException,
        "Exception": Exception,
        "ValueError": ValueError,
        "TypeError": TypeError,
        "NameError": NameError,
        "IndexError": IndexError,
        "KeyError": KeyError,
        "AttributeError": AttributeError,
        "ZeroDivisionError": ZeroDivisionError,
        # our import guard
        "__import__": _safe_import,
    }

    def __init__(self, df: pd.DataFrame):
        """
        Initialize the execution environment with a dataframe.

        Args:
            df: The dataframe exposed to generated code as ``df``.
        """
        self.df = df
        # Base namespace for both globals & locals; copied for every run.
        self._base_ns = {
            "__builtins__": self._SAFE_BUILTINS,
            "df": df,
            "pd": pd,
            "np": np,
            "plt": plt,
            "px": px,
            "go": go,
            "make_subplots": make_subplots,
        }
        self.fig = None

    def _validate_ast(self, node: ast.AST):
        """
        Walk the AST and enforce:
        • any Import/ImportFrom must be from _ALLOWED_MODULES
        • no __dunder__ attribute access

        Raises:
            ValueError: on the first violation found.
        """
        for child in ast.walk(node):
            if isinstance(child, ast.Import):
                for alias in child.names:
                    root = alias.name.split(".", 1)[0]
                    if root not in _ALLOWED_MODULES:
                        raise ValueError(f"Import of '{alias.name}' is not allowed.")
            elif isinstance(child, ast.ImportFrom):
                # `from . import x` has module=None (root ""), so it is rejected.
                root = (child.module or "").split(".", 1)[0]
                if root not in _ALLOWED_MODULES:
                    raise ValueError(f"Import-from of '{child.module}' is not allowed.")
            elif isinstance(child, ast.Attribute) and child.attr.startswith("__"):
                raise ValueError("Access to dunder attributes is forbidden.")

    def _arm_timeout(self):
        """
        Start a TIMEOUT_SECONDS SIGALRM timer whose handler raises TimeoutError.

        Returns:
            The previously installed SIGALRM handler, to pass to
            `_disarm_timeout`. Returns None when no timer could be started:
            SIGALRM does not exist on every platform, and `signal.signal`
            raises ValueError outside the main thread. In those cases the code
            runs without a time limit rather than crashing the caller.
        """
        if not hasattr(signal, "SIGALRM"):
            return None
        try:
            previous = signal.signal(signal.SIGALRM, _timeout_handler)
        except ValueError:
            return None
        signal.alarm(self.TIMEOUT_SECONDS)
        # signal.signal() returns None for a handler that was not installed
        # from Python; restore the default disposition in that case.
        return signal.SIG_DFL if previous is None else previous

    @staticmethod
    def _disarm_timeout(previous) -> None:
        """Cancel the timer and restore the handler saved by `_arm_timeout`."""
        if previous is None:
            return
        signal.alarm(0)
        signal.signal(signal.SIGALRM, previous)

    @staticmethod
    def _with_stderr(message: str, err_buf: StringIO) -> str:
        """Append any captured stderr to an error message."""
        captured = err_buf.getvalue()
        if not captured.strip():
            return message
        return f"{message}\nStderr:\n{captured}"

    @staticmethod
    def _failure(error: str, output: str = "") -> dict:
        """Build the result dict for a run that produced no figure."""
        return {"fig": None, "output": output, "error": error, "success": False}

    def execute_code(self, generated_code: str) -> dict:
        """
        Execute the user code in a locked-down namespace.

        Args:
            generated_code: Python source that should assign a figure to `fig`.

        Returns a dict with:
            - fig: The figure if created, else None
            - output: Captured stdout (on success: a status line followed by
              anything the code printed)
            - error: Exception text plus any captured stderr ("" on success)
            - success: True if fig was produced and no errors

        Side effect: ``self.fig`` is set to the returned ``fig``, so it is
        None unless this run succeeded.
        """
        # Forget the previous run's figure first, so every failure path below
        # leaves self.fig consistent with the `fig` it returns.
        self.fig = None

        try:
            tree = ast.parse(generated_code)
            self._validate_ast(tree)
        except SyntaxError as e:
            # Report syntax errors as such (with the offending line and caret)
            # rather than as a safety rejection, so the LLM can fix them.
            detail = "".join(traceback.format_exception_only(type(e), e))
            return self._failure(f"Syntax error in generated code:\n{detail}")
        except Exception as e:
            # Disallowed import / dunder access, or input ast.parse rejects.
            return self._failure(f"Code rejected on safety grounds: {e}")

        # A fresh namespace per run, so no `fig` (or any other name) leaks
        # between runs. The builtins dict is copied too, because user code can
        # mutate `__builtins__` and the class-level dict is shared by every
        # run and instance.
        # NOTE(review): `df` is shared, not copied, so code that mutates it in
        # place (e.g. adds a column) also changes self.df for later runs.
        ns = dict(self._base_ns)
        ns["__builtins__"] = dict(self._SAFE_BUILTINS)

        out_buf, err_buf = StringIO(), StringIO()
        previous_handler = self._arm_timeout()
        try:
            try:
                with contextlib.redirect_stdout(out_buf), contextlib.redirect_stderr(
                    err_buf
                ):
                    exec(generated_code, ns, ns)
            finally:
                # Disarm before building any result, so a late alarm cannot
                # fire inside the handlers below.
                self._disarm_timeout(previous_handler)
        except TimeoutError as te:
            tb = traceback.format_exc()
            return self._failure(
                self._with_stderr(f"Code execution timed out: {te}\n{tb}", err_buf),
                out_buf.getvalue(),
            )
        except Exception as e:
            tb = traceback.format_exc()
            return self._failure(
                self._with_stderr(f"Error executing code: {e}\n{tb}", err_buf),
                out_buf.getvalue(),
            )

        fig = ns.get("fig")
        if fig is None:
            return self._failure(
                self._with_stderr(
                    "No `fig` created. Assign your figure to a variable named `fig`.",
                    err_buf,
                ),
                out_buf.getvalue(),
            )

        self.fig = fig
        output = "Code executed successfully. 'fig' object was created."
        printed = out_buf.getvalue()
        if printed.strip():
            # Surface anything the code printed so the caller/agent can see it.
            output += f"\nOutput:\n{printed}"
        return {"fig": fig, "output": output, "error": "", "success": True}
@@ -1,14 +1,21 @@
1
+ """
2
+ This module contains the models for the PlotAgent.
3
+ """
4
+
1
5
  from pydantic import BaseModel, Field
2
6
 
3
7
 
4
- # Define input schemas for the tools
5
8
class PlotDescriptionInput(BaseModel):
    """Arguments schema: a single required `plot_description` string."""

    plot_description: str = Field(
        default=...,
        description="Description of the plot the user wants to create",
    )
9
14
 
10
15
 
11
16
  class GeneratedCodeInput(BaseModel):
17
+ """Model indicating that the generated_code function takes a generated_code argument."""
18
+
12
19
  generated_code: str = Field(
13
20
  ..., description="Python code that creates a Plotly figure"
14
21
  )
@@ -23,4 +30,4 @@ class DoesFigExistInput(BaseModel):
23
30
class ViewGeneratedCodeInput(BaseModel):
    """Empty arguments schema: the view_generated_code tool takes no arguments."""
@@ -1,3 +1,7 @@
1
+ """
2
+ This module contains the prompts for the PlotAgent.
3
+ """
4
+
1
5
  DEFAULT_SYSTEM_PROMPT = """
2
6
  You are an expert data visualization assistant that helps users create Plotly visualizations in Python.
3
7
  Your job is to generate Python and Plotly code based on the user's request that will create the desired visualization
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plot-agent
3
- Version: 0.2.2
3
+ Version: 0.3.1
4
4
  Summary: An AI-powered data visualization assistant using Plotly
5
5
  Author-email: andrewm4894 <andrewm4894@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/andrewm4894/plot-agent
@@ -35,7 +35,7 @@ Here's a simple minimal example of how to use Plot Agent:
35
35
 
36
36
  ```python
37
37
  import pandas as pd
38
- from plot_agent.agent import PlotlyAgent
38
+ from plot_agent.agent import PlotAgent
39
39
 
40
40
  # ensure OPENAI_API_KEY is set and available for langchain
41
41
 
@@ -46,7 +46,7 @@ df = pd.DataFrame({
46
46
  })
47
47
 
48
48
  # Initialize the agent
49
- agent = PlotlyAgent()
49
+ agent = PlotAgent()
50
50
 
51
51
  # Set the dataframe
52
52
  agent.set_df(df)
@@ -9,5 +9,4 @@ plot_agent/prompt.py
9
9
  plot_agent.egg-info/PKG-INFO
10
10
  plot_agent.egg-info/SOURCES.txt
11
11
  plot_agent.egg-info/dependency_links.txt
12
- plot_agent.egg-info/top_level.txt
13
- tests/test_plot_agent.py
12
+ plot_agent.egg-info/top_level.txt
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "plot-agent"
7
- version = "0.2.2"
7
+ version = "0.3.1"
8
8
  authors = [
9
9
  { name="andrewm4894", email="andrewm4894@gmail.com" },
10
10
  ]
@@ -1,90 +0,0 @@
1
- import sys
2
- from io import StringIO
3
- import traceback
4
- import pandas as pd
5
- import plotly.express as px
6
- import plotly.graph_objects as go
7
- import numpy as np
8
- import matplotlib.pyplot as plt
9
- from plotly.subplots import make_subplots
10
- from typing import Dict, Any
11
-
12
-
13
- class PlotlyAgentExecutionEnvironment:
14
- """
15
- Environment to safely execute plotly code and capture the fig object.
16
-
17
- Args:
18
- df (pd.DataFrame): The dataframe to use for the execution environment.
19
- """
20
-
21
- def __init__(self, df: pd.DataFrame):
22
- """
23
- Initialize the execution environment with the given dataframe.
24
-
25
- Args:
26
- df (pd.DataFrame): The dataframe to use for the execution environment.
27
- """
28
- self.df = df
29
- self.locals_dict = {
30
- "df": df,
31
- "px": px,
32
- "go": go,
33
- "pd": pd,
34
- "np": np,
35
- "plt": plt,
36
- "make_subplots": make_subplots,
37
- }
38
- self.output = None
39
- self.error = None
40
- self.fig = None
41
-
42
- def execute_code(self, generated_code: str) -> Dict[str, Any]:
43
- """
44
- Execute the provided code and capture the fig object if created.
45
-
46
- Args:
47
- generated_code (str): The code to execute.
48
-
49
- Returns:
50
- Dict[str, Any]: A dictionary containing the fig object, output, error, and success status.
51
- """
52
- self.output = None
53
- self.error = None
54
-
55
- # Capture stdout
56
- old_stdout = sys.stdout
57
- sys.stdout = mystdout = StringIO()
58
-
59
- try:
60
- # Execute the code
61
- exec(generated_code, globals(), self.locals_dict)
62
-
63
- # Check if a fig object was created
64
- if "fig" in self.locals_dict:
65
- self.fig = self.locals_dict["fig"]
66
- self.output = "Code executed successfully. 'fig' object was created."
67
- else:
68
- print(f"no fig object created: {generated_code}")
69
- self.error = "Code executed without errors, but no 'fig' object was created. Make sure your code creates a variable named 'fig'."
70
-
71
- except Exception as e:
72
- self.error = f"Error executing code: {str(e)}\n{traceback.format_exc()}"
73
-
74
- finally:
75
- # Restore stdout
76
- sys.stdout = old_stdout
77
- captured_output = mystdout.getvalue()
78
-
79
- if captured_output.strip():
80
- if self.output:
81
- self.output += f"\nOutput:\n{captured_output}"
82
- else:
83
- self.output = f"Output:\n{captured_output}"
84
-
85
- return {
86
- "fig": self.fig,
87
- "output": self.output,
88
- "error": self.error,
89
- "success": self.error is None and self.fig is not None,
90
- }
@@ -1,471 +0,0 @@
1
- import pytest
2
- import pandas as pd
3
- import numpy as np
4
- from plot_agent.agent import PlotlyAgent
5
- from langchain_core.messages import HumanMessage, AIMessage
6
-
7
-
8
- def test_plotly_agent_initialization():
9
- """Test that PlotlyAgent initializes correctly."""
10
- agent = PlotlyAgent()
11
- assert agent.llm is not None
12
- assert agent.df is None
13
- assert agent.df_info is None
14
- assert agent.df_head is None
15
- assert agent.sql_query is None
16
- assert agent.execution_env is None
17
- assert agent.chat_history == []
18
- assert agent.agent_executor is None
19
- assert agent.generated_code is None
20
-
21
-
22
- def test_set_df():
23
- """Test that set_df properly sets up the dataframe and environment."""
24
- # Create a sample dataframe
25
- df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
26
-
27
- agent = PlotlyAgent()
28
- agent.set_df(df)
29
-
30
- assert agent.df is not None
31
- assert agent.df_info is not None
32
- assert agent.df_head is not None
33
- assert agent.execution_env is not None
34
- assert agent.agent_executor is not None
35
-
36
-
37
- def test_execute_plotly_code():
38
- """Test that execute_plotly_code works with valid code."""
39
- df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
40
-
41
- agent = PlotlyAgent()
42
- agent.set_df(df)
43
-
44
- # Test with valid plotly code
45
- valid_code = """import plotly.express as px
46
- fig = px.scatter(df, x='x', y='y')"""
47
-
48
- result = agent.execute_plotly_code(valid_code)
49
- assert "Code executed successfully" in result
50
- assert agent.execution_env.fig is not None
51
-
52
-
53
- def test_execute_plotly_code_with_error():
54
- """Test that execute_plotly_code handles errors properly."""
55
- df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
56
-
57
- agent = PlotlyAgent()
58
- agent.set_df(df)
59
-
60
- # Test with invalid code
61
- invalid_code = """import plotly.express as px
62
- fig = px.scatter(df, x='non_existent_column', y='y')"""
63
-
64
- result = agent.execute_plotly_code(invalid_code)
65
- assert "Error" in result
66
- assert agent.execution_env.fig is None
67
-
68
-
69
- def test_does_fig_exist():
70
- """Test that does_fig_exist correctly reports figure existence."""
71
- df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
72
-
73
- agent = PlotlyAgent()
74
- agent.set_df(df)
75
-
76
- # Initially no figure should exist
77
- assert "No figure has been created yet" in agent.does_fig_exist()
78
-
79
- # Create a figure
80
- valid_code = """import plotly.express as px
81
- fig = px.scatter(df, x='x', y='y')"""
82
- agent.execute_plotly_code(valid_code)
83
-
84
- # Now a figure should exist
85
- assert "A figure is available for display" in agent.does_fig_exist()
86
-
87
-
88
- def test_reset_conversation():
89
- """Test that reset_conversation clears the chat history."""
90
- agent = PlotlyAgent()
91
- agent.chat_history = ["message1", "message2"]
92
- agent.reset_conversation()
93
- assert agent.chat_history == []
94
-
95
-
96
- def test_view_generated_code():
97
- """Test that view_generated_code returns the last generated code."""
98
- agent = PlotlyAgent()
99
- test_code = "test code"
100
- agent.generated_code = test_code
101
- assert agent.view_generated_code() == test_code
102
-
103
-
104
- def test_get_figure():
105
- """Test that get_figure returns the current figure if it exists."""
106
- df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
107
-
108
- agent = PlotlyAgent()
109
- agent.set_df(df)
110
-
111
- # Initially no figure should exist
112
- assert agent.get_figure() is None
113
-
114
- # Create a figure
115
- valid_code = """import plotly.express as px
116
- fig = px.scatter(df, x='x', y='y')"""
117
- agent.execute_plotly_code(valid_code)
118
-
119
- # Now a figure should exist
120
- assert agent.get_figure() is not None
121
-
122
-
123
- def test_process_message():
124
- """Test that process_message updates chat history and handles responses."""
125
- df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
126
-
127
- agent = PlotlyAgent()
128
- agent.set_df(df, sql_query="SELECT x, y FROM df")
129
-
130
- # Test processing a message
131
- response = agent.process_message("Create a scatter plot")
132
-
133
- # Check that chat history was updated
134
- assert len(agent.chat_history) == 2 # One human message and one AI message
135
- assert isinstance(agent.chat_history[0], HumanMessage)
136
- assert isinstance(agent.chat_history[1], AIMessage)
137
- assert agent.chat_history[0].content == "Create a scatter plot"
138
-
139
-
140
- def test_execute_plotly_code_without_df():
141
- """Test that execute_plotly_code handles the case when no dataframe is set."""
142
- agent = PlotlyAgent()
143
- result = agent.execute_plotly_code("some code")
144
- assert "Error" in result and "No dataframe has been set" in result
145
-
146
-
147
- def test_set_df_with_sql_query():
148
- """Test that set_df properly handles SQL query context."""
149
- df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
150
-
151
- sql_query = "SELECT x, y FROM table"
152
- agent = PlotlyAgent()
153
- agent.set_df(df, sql_query=sql_query)
154
-
155
- assert agent.sql_query == sql_query
156
-
157
-
158
- def test_agent_initialization_with_custom_prompt():
159
- """Test agent initialization with custom system prompt."""
160
- custom_prompt = "Custom system prompt for testing"
161
- agent = PlotlyAgent(system_prompt=custom_prompt)
162
- assert agent.system_prompt == custom_prompt
163
-
164
-
165
- def test_agent_initialization_with_different_model():
166
- """Test agent initialization with different model names."""
167
- agent = PlotlyAgent(model="gpt-3.5-turbo")
168
- assert agent.llm.model_name == "gpt-3.5-turbo"
169
-
170
-
171
- def test_agent_initialization_with_verbose():
172
- """Test agent initialization with verbose settings."""
173
- agent = PlotlyAgent(verbose=False)
174
- assert agent.verbose == False
175
- assert agent.agent_executor is None # Agent executor not initialized yet
176
-
177
-
178
- def test_agent_initialization_with_max_iterations():
179
- """Test agent initialization with different max iterations."""
180
- agent = PlotlyAgent(max_iterations=5)
181
- assert agent.max_iterations == 5
182
-
183
-
184
- def test_agent_initialization_with_early_stopping():
185
- """Test agent initialization with different early stopping methods."""
186
- agent = PlotlyAgent(early_stopping_method="generate")
187
- assert agent.early_stopping_method == "generate"
188
-
189
-
190
- def test_process_empty_message():
191
- """Test processing of empty messages."""
192
- df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
193
-
194
- agent = PlotlyAgent()
195
- agent.set_df(df)
196
-
197
- response = agent.process_message("")
198
- assert len(agent.chat_history) == 2 # Should still create chat history entries
199
- assert isinstance(agent.chat_history[0], HumanMessage)
200
- assert isinstance(agent.chat_history[1], AIMessage)
201
-
202
-
203
- def test_process_message_with_code_blocks():
204
- """Test processing messages that contain code blocks."""
205
- df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
206
-
207
- agent = PlotlyAgent()
208
- agent.set_df(df)
209
-
210
- message = "Here's some code:\n```python\nprint('test')\n```"
211
- response = agent.process_message(message)
212
- assert len(agent.chat_history) == 2
213
- assert "```python" in agent.chat_history[0].content
214
-
215
-
216
- def test_execution_environment_with_different_plot_types():
217
- """Test execution environment with different types of plots."""
218
- df = pd.DataFrame(
219
- {
220
- "x": [1, 2, 3, 4, 5],
221
- "y": [10, 20, 30, 40, 50],
222
- "category": ["A", "B", "A", "B", "A"],
223
- }
224
- )
225
-
226
- agent = PlotlyAgent()
227
- agent.set_df(df)
228
-
229
- # Test scatter plot
230
- scatter_code = """import plotly.express as px
231
- fig = px.scatter(df, x='x', y='y')"""
232
- result = agent.execute_plotly_code(scatter_code)
233
- assert "Code executed successfully" in result
234
- assert agent.execution_env.fig is not None
235
-
236
- # Test bar plot
237
- bar_code = """import plotly.express as px
238
- fig = px.bar(df, x='category', y='y')"""
239
- result = agent.execute_plotly_code(bar_code)
240
- assert "Code executed successfully" in result
241
- assert agent.execution_env.fig is not None
242
-
243
- # Test line plot
244
- line_code = """import plotly.express as px
245
- fig = px.line(df, x='x', y='y')"""
246
- result = agent.execute_plotly_code(line_code)
247
- assert "Code executed successfully" in result
248
- assert agent.execution_env.fig is not None
249
-
250
-
251
- def test_execution_environment_with_subplots():
252
- """Test execution environment with subplots."""
253
- df = pd.DataFrame(
254
- {"x": [1, 2, 3, 4, 5], "y1": [10, 20, 30, 40, 50], "y2": [50, 40, 30, 20, 10]}
255
- )
256
-
257
- agent = PlotlyAgent()
258
- agent.set_df(df)
259
-
260
- subplot_code = """import plotly.subplots as sp
261
- import plotly.graph_objects as go
262
- fig = sp.make_subplots(rows=1, cols=2)
263
- fig.add_trace(go.Scatter(x=df['x'], y=df['y1']), row=1, col=1)
264
- fig.add_trace(go.Scatter(x=df['x'], y=df['y2']), row=1, col=2)"""
265
-
266
- result = agent.execute_plotly_code(subplot_code)
267
- assert "Code executed successfully" in result
268
- assert agent.execution_env.fig is not None
269
-
270
-
271
- def test_execution_environment_with_data_preprocessing():
272
- """Test execution environment with data preprocessing steps."""
273
- df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
274
-
275
- agent = PlotlyAgent()
276
- agent.set_df(df)
277
-
278
- preprocessing_code = """import plotly.express as px
279
- # Preprocessing steps
280
- df['y_normalized'] = (df['y'] - df['y'].min()) / (df['y'].max() - df['y'].min())
281
- fig = px.scatter(df, x='x', y='y_normalized')"""
282
-
283
- result = agent.execute_plotly_code(preprocessing_code)
284
- assert "Code executed successfully" in result
285
- assert agent.execution_env.fig is not None
286
-
287
-
288
- def test_handle_syntax_error():
289
- """Test handling of syntax errors in generated code."""
290
- df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
291
-
292
- agent = PlotlyAgent()
293
- agent.set_df(df)
294
-
295
- invalid_code = """import plotly.express as px
296
- fig = px.scatter(df, x='x', y='y' # Missing closing parenthesis"""
297
-
298
- result = agent.execute_plotly_code(invalid_code)
299
- assert "Error" in result
300
- assert "SyntaxError" in result
301
- assert agent.execution_env.fig is None
302
-
303
-
304
- def test_handle_runtime_error():
305
- """Test handling of runtime errors in generated code."""
306
- df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
307
-
308
- agent = PlotlyAgent()
309
- agent.set_df(df)
310
-
311
- error_code = """import plotly.express as px
312
- fig = px.scatter(df, x='x', y='y', color='non_existent_column')"""
313
-
314
- result = agent.execute_plotly_code(error_code)
315
- assert "Error" in result
316
- assert agent.execution_env.fig is None
317
-
318
-
319
- def test_tool_interaction():
320
- """Test interaction between different tools."""
321
- df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
322
-
323
- agent = PlotlyAgent()
324
- agent.set_df(df)
325
-
326
- # First check if figure exists (should not)
327
- assert "No figure has been created yet" in agent.does_fig_exist()
328
-
329
- # Generate and execute code
330
- code = """import plotly.express as px
331
- fig = px.scatter(df, x='x', y='y')"""
332
- result = agent.execute_plotly_code(code)
333
- assert "Code executed successfully" in result
334
-
335
- # Check if figure exists (should now exist)
336
- assert "A figure is available for display" in agent.does_fig_exist()
337
-
338
- # View the generated code
339
- assert code in agent.view_generated_code()
340
-
341
-
342
- def test_tool_validation():
343
- """Test validation of tool inputs."""
344
- df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
345
-
346
- agent = PlotlyAgent()
347
- agent.set_df(df)
348
-
349
- # Test with invalid code (empty string)
350
- result = agent.execute_plotly_code("")
351
- assert "Error" in result
352
-
353
- # Test with invalid code (None)
354
- with pytest.raises(AssertionError):
355
- agent.execute_plotly_code(None)
356
-
357
-
358
- def test_tool_response_formatting():
359
- """Test formatting of tool responses."""
360
- df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
361
-
362
- agent = PlotlyAgent()
363
- agent.set_df(df)
364
-
365
- # Test execute_plotly_code response format
366
- code = """import plotly.express as px
367
- fig = px.scatter(df, x='x', y='y')"""
368
- result = agent.execute_plotly_code(code)
369
- assert isinstance(result, str)
370
- assert "Code executed successfully" in result
371
-
372
- # Test does_fig_exist response format
373
- result = agent.does_fig_exist()
374
- assert isinstance(result, str)
375
- assert "figure" in result.lower()
376
-
377
- # Test view_generated_code response format
378
- result = agent.view_generated_code()
379
- assert isinstance(result, str)
380
- assert code in result
381
-
382
-
383
- def test_memory_cleanup():
384
- """Test memory cleanup after multiple plot generations."""
385
- df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
386
-
387
- agent = PlotlyAgent()
388
- agent.set_df(df)
389
-
390
- # Generate multiple plots
391
- for i in range(5):
392
- code = f"""import plotly.express as px
393
- fig = px.scatter(df, x='x', y='y', title='Plot {i}')"""
394
- result = agent.execute_plotly_code(code)
395
- assert "Code executed successfully" in result
396
- assert agent.execution_env.fig is not None
397
-
398
- # Reset conversation and check memory
399
- agent.reset_conversation()
400
- assert len(agent.chat_history) == 0
401
- assert agent.generated_code is None
402
-
403
-
404
- def test_large_dataframe_handling():
405
- """Test handling of large dataframes."""
406
- # Create a large dataframe
407
- df = pd.DataFrame({"x": range(10000), "y": range(10000)})
408
-
409
- agent = PlotlyAgent()
410
- agent.set_df(df)
411
-
412
- # Test plot generation with large dataframe
413
- code = """import plotly.express as px
414
- fig = px.scatter(df, x='x', y='y')"""
415
- result = agent.execute_plotly_code(code)
416
- assert "Code executed successfully" in result
417
- assert agent.execution_env.fig is not None
418
-
419
-
420
- def test_input_validation():
421
- """Test validation of input parameters."""
422
- # Test invalid dataframe input
423
- with pytest.raises(AssertionError):
424
- agent = PlotlyAgent()
425
- agent.set_df("not a dataframe")
426
-
427
- # Test invalid SQL query input
428
- df = pd.DataFrame({"x": [1, 2, 3]})
429
- agent = PlotlyAgent()
430
- with pytest.raises(AssertionError):
431
- agent.set_df(df, sql_query=123) # SQL query should be string
432
-
433
- # Test invalid message input
434
- agent.set_df(df)
435
- with pytest.raises(AssertionError):
436
- agent.process_message(123) # Message should be string
437
-
438
- # Test invalid code input
439
- with pytest.raises(AssertionError):
440
- agent.execute_plotly_code(123) # Code should be string
441
-
442
-
443
- def test_complex_plot_handling():
444
- """Test handling of complex plots with multiple traces and layouts."""
445
- df = pd.DataFrame(
446
- {
447
- "x": [1, 2, 3, 4, 5],
448
- "y1": [10, 20, 30, 40, 50],
449
- "y2": [50, 40, 30, 20, 10],
450
- "category": ["A", "B", "A", "B", "A"],
451
- }
452
- )
453
-
454
- agent = PlotlyAgent()
455
- agent.set_df(df)
456
-
457
- complex_code = """import plotly.graph_objects as go
458
- fig = go.Figure()
459
- fig.add_trace(go.Scatter(x=df['x'], y=df['y1'], name='Trace 1'))
460
- fig.add_trace(go.Scatter(x=df['x'], y=df['y2'], name='Trace 2'))
461
- fig.update_layout(
462
- title='Complex Plot',
463
- xaxis_title='X Axis',
464
- yaxis_title='Y Axis',
465
- showlegend=True,
466
- template='plotly_white'
467
- )"""
468
-
469
- result = agent.execute_plotly_code(complex_code)
470
- assert "Code executed successfully" in result
471
- assert agent.execution_env.fig is not None
File without changes
File without changes