plot-agent 0.2.2__tar.gz → 0.3.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {plot_agent-0.2.2 → plot_agent-0.3.1}/PKG-INFO +3 -3
- {plot_agent-0.2.2 → plot_agent-0.3.1}/README.md +2 -2
- {plot_agent-0.2.2 → plot_agent-0.3.1}/plot_agent/agent.py +22 -8
- plot_agent-0.3.1/plot_agent/execution.py +238 -0
- {plot_agent-0.2.2 → plot_agent-0.3.1}/plot_agent/models.py +9 -2
- {plot_agent-0.2.2 → plot_agent-0.3.1}/plot_agent/prompt.py +4 -0
- {plot_agent-0.2.2 → plot_agent-0.3.1}/plot_agent.egg-info/PKG-INFO +3 -3
- {plot_agent-0.2.2 → plot_agent-0.3.1}/plot_agent.egg-info/SOURCES.txt +1 -2
- {plot_agent-0.2.2 → plot_agent-0.3.1}/pyproject.toml +1 -1
- plot_agent-0.2.2/plot_agent/execution.py +0 -90
- plot_agent-0.2.2/tests/test_plot_agent.py +0 -471
- {plot_agent-0.2.2 → plot_agent-0.3.1}/LICENSE +0 -0
- {plot_agent-0.2.2 → plot_agent-0.3.1}/plot_agent/__init__.py +0 -0
- {plot_agent-0.2.2 → plot_agent-0.3.1}/plot_agent.egg-info/dependency_links.txt +0 -0
- {plot_agent-0.2.2 → plot_agent-0.3.1}/plot_agent.egg-info/top_level.txt +0 -0
- {plot_agent-0.2.2 → plot_agent-0.3.1}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: plot-agent
|
|
3
|
-
Version: 0.2.2
|
|
3
|
+
Version: 0.3.1
|
|
4
4
|
Summary: An AI-powered data visualization assistant using Plotly
|
|
5
5
|
Author-email: andrewm4894 <andrewm4894@gmail.com>
|
|
6
6
|
Project-URL: Homepage, https://github.com/andrewm4894/plot-agent
|
|
@@ -35,7 +35,7 @@ Here's a simple minimal example of how to use Plot Agent:
|
|
|
35
35
|
|
|
36
36
|
```python
|
|
37
37
|
import pandas as pd
|
|
38
|
-
from plot_agent.agent import PlotlyAgent
|
|
38
|
+
from plot_agent.agent import PlotAgent
|
|
39
39
|
|
|
40
40
|
# ensure OPENAI_API_KEY is set and available for langchain
|
|
41
41
|
|
|
@@ -46,7 +46,7 @@ df = pd.DataFrame({
|
|
|
46
46
|
})
|
|
47
47
|
|
|
48
48
|
# Initialize the agent
|
|
49
|
-
agent = PlotlyAgent()
|
|
49
|
+
agent = PlotAgent()
|
|
50
50
|
|
|
51
51
|
# Set the dataframe
|
|
52
52
|
agent.set_df(df)
|
|
@@ -21,7 +21,7 @@ Here's a simple minimal example of how to use Plot Agent:
|
|
|
21
21
|
|
|
22
22
|
```python
|
|
23
23
|
import pandas as pd
|
|
24
|
-
from plot_agent.agent import PlotlyAgent
|
|
24
|
+
from plot_agent.agent import PlotAgent
|
|
25
25
|
|
|
26
26
|
# ensure OPENAI_API_KEY is set and available for langchain
|
|
27
27
|
|
|
@@ -32,7 +32,7 @@ df = pd.DataFrame({
|
|
|
32
32
|
})
|
|
33
33
|
|
|
34
34
|
# Initialize the agent
|
|
35
|
-
agent = PlotlyAgent()
|
|
35
|
+
agent = PlotAgent()
|
|
36
36
|
|
|
37
37
|
# Set the dataframe
|
|
38
38
|
agent.set_df(df)
|
|
@@ -1,3 +1,7 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the PlotAgent class, which is used to generate Plotly code based on a user's plot description.
|
|
3
|
+
"""
|
|
4
|
+
|
|
1
5
|
import pandas as pd
|
|
2
6
|
from io import StringIO
|
|
3
7
|
from typing import Optional
|
|
@@ -14,10 +18,10 @@ from plot_agent.models import (
|
|
|
14
18
|
DoesFigExistInput,
|
|
15
19
|
ViewGeneratedCodeInput,
|
|
16
20
|
)
|
|
17
|
-
from plot_agent.execution import PlotlyAgentExecutionEnvironment
|
|
21
|
+
from plot_agent.execution import PlotAgentExecutionEnvironment
|
|
18
22
|
|
|
19
23
|
|
|
20
|
-
class PlotlyAgent:
|
|
24
|
+
class PlotAgent:
|
|
21
25
|
"""
|
|
22
26
|
A class that uses an LLM to generate Plotly code based on a user's plot description.
|
|
23
27
|
"""
|
|
@@ -32,7 +36,7 @@ class PlotlyAgent:
|
|
|
32
36
|
handle_parsing_errors: bool = True,
|
|
33
37
|
):
|
|
34
38
|
"""
|
|
35
|
-
Initialize the PlotlyAgent.
|
|
39
|
+
Initialize the PlotAgent.
|
|
36
40
|
|
|
37
41
|
Args:
|
|
38
42
|
model (str): The model to use for the LLM.
|
|
@@ -90,7 +94,7 @@ class PlotlyAgent:
|
|
|
90
94
|
self.sql_query = sql_query
|
|
91
95
|
|
|
92
96
|
# Initialize execution environment
|
|
93
|
-
self.execution_env = PlotlyAgentExecutionEnvironment(df)
|
|
97
|
+
self.execution_env = PlotAgentExecutionEnvironment(df)
|
|
94
98
|
|
|
95
99
|
# Initialize the agent with tools
|
|
96
100
|
self._initialize_agent()
|
|
@@ -123,7 +127,7 @@ class PlotlyAgent:
|
|
|
123
127
|
|
|
124
128
|
# Check if the code executed successfully
|
|
125
129
|
if code_execution_success:
|
|
126
|
-
return f"
|
|
130
|
+
return f"Success: {code_execution_output}"
|
|
127
131
|
else:
|
|
128
132
|
return f"Error: {code_execution_error}\n{code_execution_output}"
|
|
129
133
|
|
|
@@ -160,19 +164,29 @@ class PlotlyAgent:
|
|
|
160
164
|
Tool.from_function(
|
|
161
165
|
func=self.execute_plotly_code,
|
|
162
166
|
name="execute_plotly_code",
|
|
163
|
-
description=
|
|
167
|
+
description=(
|
|
168
|
+
"Execute the provided Plotly code and return a result indicating "
|
|
169
|
+
"if the code executed successfully and if a figure object was created."
|
|
170
|
+
),
|
|
164
171
|
args_schema=GeneratedCodeInput,
|
|
165
172
|
),
|
|
166
173
|
StructuredTool.from_function(
|
|
167
174
|
func=self.does_fig_exist,
|
|
168
175
|
name="does_fig_exist",
|
|
169
|
-
description=
|
|
176
|
+
description=(
|
|
177
|
+
"Check if a figure exists and is available for display. "
|
|
178
|
+
"This tool takes no arguments and returns a string indicating "
|
|
179
|
+
"if a figure is available for display or not."
|
|
180
|
+
),
|
|
170
181
|
args_schema=DoesFigExistInput,
|
|
171
182
|
),
|
|
172
183
|
StructuredTool.from_function(
|
|
173
184
|
func=self.view_generated_code,
|
|
174
185
|
name="view_generated_code",
|
|
175
|
-
description=
|
|
186
|
+
description=(
|
|
187
|
+
"View the generated code. "
|
|
188
|
+
"This tool takes no arguments and returns the generated code as a string."
|
|
189
|
+
),
|
|
176
190
|
args_schema=ViewGeneratedCodeInput,
|
|
177
191
|
),
|
|
178
192
|
]
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the PlotAgentExecutionEnvironment class, which is used to safely execute LLM‑generated plotting code and capture `fig`.
|
|
3
|
+
|
|
4
|
+
Security features:
|
|
5
|
+
• Only allow imports from a fixed list of packages (pandas, numpy,
|
|
6
|
+
matplotlib, plotly, sklearn, scipy)
|
|
7
|
+
• AST scan rejects any import outside that list and any __dunder__ access
|
|
8
|
+
• Sandbox builtins to include only a minimal safe set + our _safe_import
|
|
9
|
+
• Enforce a 60 second timeout via signal.alarm
|
|
10
|
+
"""
|
|
11
|
+
import ast
|
|
12
|
+
import builtins
|
|
13
|
+
import signal
|
|
14
|
+
import traceback
|
|
15
|
+
from io import StringIO
|
|
16
|
+
import contextlib
|
|
17
|
+
|
|
18
|
+
import pandas as pd
|
|
19
|
+
import numpy as np
|
|
20
|
+
import matplotlib.pyplot as plt
|
|
21
|
+
import plotly.express as px
|
|
22
|
+
import plotly.graph_objects as go
|
|
23
|
+
from plotly.subplots import make_subplots
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _timeout_handler(signum, frame):
    """SIGALRM handler: abort the running sandboxed code with a TimeoutError."""
    raise TimeoutError("Code execution timed out")


# Top-level packages that sandboxed code is allowed to import.
_ALLOWED_MODULES = {
    "pandas",
    "numpy",
    "matplotlib",
    "plotly",
    "sklearn",
    "scipy",
}


# Keep a handle on the real import machinery; `_safe_import` delegates to it
# once the requested module has passed the allowlist check.
_orig_import = builtins.__import__


def _safe_import(name, globals=None, locals=None, fromlist=(), level=0):
    """
    Drop-in replacement for ``__import__`` that only admits allowlisted packages.

    Args:
        name: Dotted module name being imported.
        globals, locals, fromlist, level: Forwarded unchanged to the real
            ``__import__`` (parameter names mirror its signature).

    Returns:
        Whatever the real ``__import__`` returns for an allowed module.

    Raises:
        ImportError: If the import is relative (``level != 0``) or the
            top-level package of ``name`` is not in ``_ALLOWED_MODULES``.
    """
    # Sandboxed code has no parent package, so a relative import can never be
    # legitimate; reject it here instead of letting the real __import__ fail
    # with an unrelated error about missing globals.
    if level != 0:
        raise ImportError(f"Relative import of module '{name}' is not allowed.")
    root = name.split(".", 1)[0]
    if root in _ALLOWED_MODULES:
        return _orig_import(name, globals, locals, fromlist, level)
    # If the module is not in the allowlist, raise an ImportError
    raise ImportError(f"Import of module '{name}' is not allowed.")


class PlotAgentExecutionEnvironment:
    """
    Environment to safely execute LLM-generated plotting code and capture `fig`.

    Security features:
    • Only allow imports from a fixed list of packages (pandas, numpy,
      matplotlib, plotly, sklearn, scipy)
    • AST scan rejects any import outside that list, relative imports, and any
      __dunder__ attribute access
    • Sandbox builtins to include only a minimal safe set + our _safe_import;
      each run gets its own copy so executed code cannot alter later runs
    • Enforce a TIMEOUT_SECONDS timeout via signal.alarm when SIGALRM can be
      installed (main thread, platform with SIGALRM); otherwise no timeout
    • Capture both stdout & stderr
    • Purge any old `fig` between runs
    """

    TIMEOUT_SECONDS = 60

    # A lean set of builtins, plus our safe-import hook
    _SAFE_BUILTINS = {
        "abs": abs,
        "all": all,
        "any": any,
        "bin": bin,
        "bool": bool,
        "chr": chr,
        "dict": dict,
        "divmod": divmod,
        "enumerate": enumerate,
        "float": float,
        "int": int,
        "len": len,
        "list": list,
        "map": map,
        "max": max,
        "min": min,
        "next": next,
        "pow": pow,
        "print": print,
        "range": range,
        "reversed": reversed,
        "round": round,
        "set": set,
        "str": str,
        "sum": sum,
        "tuple": tuple,
        "zip": zip,
        # basic exceptions so user code can raise/catch
        "BaseException": BaseException,
        "Exception": Exception,
        "ValueError": ValueError,
        "TypeError": TypeError,
        "NameError": NameError,
        "IndexError": IndexError,
        "KeyError": KeyError,
        # our import guard
        "__import__": _safe_import,
    }

    def __init__(self, df: pd.DataFrame):
        """
        Initialize the execution environment with a dataframe.

        Args:
            df: Dataframe exposed to executed code under the name ``df``.
        """
        self.df = df
        # Base namespace for both globals & locals
        self._base_ns = {
            "__builtins__": self._SAFE_BUILTINS,
            "df": df,
            "pd": pd,
            "np": np,
            "plt": plt,
            "px": px,
            "go": go,
            "make_subplots": make_subplots,
        }
        # Figure produced by the most recent successful run, else None.
        self.fig = None

    @staticmethod
    def _failure(output: str, error: str, stderr: str = ""):
        """Build the result dict for a failed run, appending captured stderr if any."""
        if stderr.strip():
            error = f"{error}\nStderr:\n{stderr}"
        return {
            "fig": None,
            "output": output,
            "error": error,
            "success": False,
        }

    def _validate_ast(self, node: ast.AST):
        """
        Walk the AST and enforce:
        • any Import/ImportFrom must be from _ALLOWED_MODULES
        • no relative imports
        • no __dunder__ attribute access

        Raises:
            ValueError: On the first node that violates one of the rules.
        """
        for child in ast.walk(node):
            if isinstance(child, ast.Import):
                for alias in child.names:
                    root = alias.name.split(".", 1)[0]
                    if root not in _ALLOWED_MODULES:
                        raise ValueError(f"Import of '{alias.name}' is not allowed.")
            elif isinstance(child, ast.ImportFrom):
                # `from . import x` has module None; treat it as an empty root.
                root = (child.module or "").split(".", 1)[0]
                if child.level or root not in _ALLOWED_MODULES:
                    raise ValueError(f"Import-from of '{child.module}' is not allowed.")
            elif isinstance(child, ast.Attribute) and child.attr.startswith("__"):
                raise ValueError("Access to dunder attributes is forbidden.")

    def execute_code(self, generated_code: str):
        """
        Execute the user code in a locked-down sandbox.

        Args:
            generated_code: Python source expected to assign a figure to `fig`.

        Returns a dict with:
          - fig: The figure if created, else None
          - output: Captured stdout (prefixed by a success message on success)
          - error: Exception text, plus captured stderr when a run fails
          - success: True if fig was produced and no errors
        """
        # Fresh namespace per run. The builtins mapping is copied as well:
        # executed code can reach it through the name `__builtins__`, and must
        # not be able to mutate the class-level dict shared by every run.
        ns = self._base_ns.copy()
        ns["__builtins__"] = dict(self._SAFE_BUILTINS)
        ns.pop("fig", None)
        # Forget the previous figure up front so a failed run never leaves a
        # stale `fig` behind on the instance.
        self.fig = None

        try:
            tree = ast.parse(generated_code)
            self._validate_ast(tree)
        except Exception as e:
            # Covers both syntax errors and safety-rule violations.
            return self._failure("", f"Code rejected on safety grounds: {e}")

        out_buf, err_buf = StringIO(), StringIO()
        previous_handler = None
        alarm_armed = False
        try:
            try:
                previous_handler = signal.signal(signal.SIGALRM, _timeout_handler)
            except (AttributeError, ValueError):
                # No SIGALRM on this platform, or we are not on the main
                # thread: run without a timeout rather than crash the caller.
                alarm_armed = False
            else:
                alarm_armed = True
                signal.alarm(self.TIMEOUT_SECONDS)

            with contextlib.redirect_stdout(out_buf), contextlib.redirect_stderr(
                err_buf
            ):
                exec(generated_code, ns, ns)
        except TimeoutError as te:
            tb = traceback.format_exc()
            return self._failure(
                out_buf.getvalue(),
                f"Code execution timed out: {te}\n{tb}",
                err_buf.getvalue(),
            )
        except Exception as e:
            tb = traceback.format_exc()
            return self._failure(
                out_buf.getvalue(),
                f"Error executing code: {e}\n{tb}",
                err_buf.getvalue(),
            )
        finally:
            if alarm_armed:
                # Cancel the pending alarm and hand SIGALRM back to whoever
                # owned it before this run.
                signal.alarm(0)
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)

        fig = ns.get("fig")
        self.fig = fig
        captured = out_buf.getvalue()
        if fig is None:
            return self._failure(
                captured,
                "No `fig` created. Assign your figure to a variable named `fig`.",
                err_buf.getvalue(),
            )

        output = "Code executed successfully. 'fig' object was created."
        if captured.strip():
            # Surface anything the code printed so the caller can see it.
            output += f"\nOutput:\n{captured}"
        return {
            "fig": fig,
            "output": output,
            "error": "",
            "success": True,
        }
|
|
@@ -1,14 +1,21 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the models for the PlotAgent.
|
|
3
|
+
"""
|
|
4
|
+
|
|
1
5
|
from pydantic import BaseModel, Field
|
|
2
6
|
|
|
3
7
|
|
|
4
|
-
# Define input schemas for the tools
|
|
5
8
|
class PlotDescriptionInput(BaseModel):
|
|
9
|
+
"""Model indicating that the plot_description function takes a plot_description argument."""
|
|
10
|
+
|
|
6
11
|
plot_description: str = Field(
|
|
7
12
|
..., description="Description of the plot the user wants to create"
|
|
8
13
|
)
|
|
9
14
|
|
|
10
15
|
|
|
11
16
|
class GeneratedCodeInput(BaseModel):
|
|
17
|
+
"""Model indicating that the execute_plotly_code tool takes a generated_code argument."""
|
|
18
|
+
|
|
12
19
|
generated_code: str = Field(
|
|
13
20
|
..., description="Python code that creates a Plotly figure"
|
|
14
21
|
)
|
|
@@ -23,4 +30,4 @@ class DoesFigExistInput(BaseModel):
|
|
|
23
30
|
class ViewGeneratedCodeInput(BaseModel):
|
|
24
31
|
"""Model indicating that the view_generated_code function takes no arguments."""
|
|
25
32
|
|
|
26
|
-
pass
|
|
33
|
+
pass
|
|
@@ -1,3 +1,7 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the prompts for the PlotAgent.
|
|
3
|
+
"""
|
|
4
|
+
|
|
1
5
|
DEFAULT_SYSTEM_PROMPT = """
|
|
2
6
|
You are an expert data visualization assistant that helps users create Plotly visualizations in Python.
|
|
3
7
|
Your job is to generate Python and Plotly code based on the user's request that will create the desired visualization
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: plot-agent
|
|
3
|
-
Version: 0.2.2
|
|
3
|
+
Version: 0.3.1
|
|
4
4
|
Summary: An AI-powered data visualization assistant using Plotly
|
|
5
5
|
Author-email: andrewm4894 <andrewm4894@gmail.com>
|
|
6
6
|
Project-URL: Homepage, https://github.com/andrewm4894/plot-agent
|
|
@@ -35,7 +35,7 @@ Here's a simple minimal example of how to use Plot Agent:
|
|
|
35
35
|
|
|
36
36
|
```python
|
|
37
37
|
import pandas as pd
|
|
38
|
-
from plot_agent.agent import PlotlyAgent
|
|
38
|
+
from plot_agent.agent import PlotAgent
|
|
39
39
|
|
|
40
40
|
# ensure OPENAI_API_KEY is set and available for langchain
|
|
41
41
|
|
|
@@ -46,7 +46,7 @@ df = pd.DataFrame({
|
|
|
46
46
|
})
|
|
47
47
|
|
|
48
48
|
# Initialize the agent
|
|
49
|
-
agent = PlotlyAgent()
|
|
49
|
+
agent = PlotAgent()
|
|
50
50
|
|
|
51
51
|
# Set the dataframe
|
|
52
52
|
agent.set_df(df)
|
|
@@ -1,90 +0,0 @@
|
|
|
1
|
-
import sys
|
|
2
|
-
from io import StringIO
|
|
3
|
-
import traceback
|
|
4
|
-
import pandas as pd
|
|
5
|
-
import plotly.express as px
|
|
6
|
-
import plotly.graph_objects as go
|
|
7
|
-
import numpy as np
|
|
8
|
-
import matplotlib.pyplot as plt
|
|
9
|
-
from plotly.subplots import make_subplots
|
|
10
|
-
from typing import Dict, Any
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class PlotlyAgentExecutionEnvironment:
|
|
14
|
-
"""
|
|
15
|
-
Environment to safely execute plotly code and capture the fig object.
|
|
16
|
-
|
|
17
|
-
Args:
|
|
18
|
-
df (pd.DataFrame): The dataframe to use for the execution environment.
|
|
19
|
-
"""
|
|
20
|
-
|
|
21
|
-
def __init__(self, df: pd.DataFrame):
|
|
22
|
-
"""
|
|
23
|
-
Initialize the execution environment with the given dataframe.
|
|
24
|
-
|
|
25
|
-
Args:
|
|
26
|
-
df (pd.DataFrame): The dataframe to use for the execution environment.
|
|
27
|
-
"""
|
|
28
|
-
self.df = df
|
|
29
|
-
self.locals_dict = {
|
|
30
|
-
"df": df,
|
|
31
|
-
"px": px,
|
|
32
|
-
"go": go,
|
|
33
|
-
"pd": pd,
|
|
34
|
-
"np": np,
|
|
35
|
-
"plt": plt,
|
|
36
|
-
"make_subplots": make_subplots,
|
|
37
|
-
}
|
|
38
|
-
self.output = None
|
|
39
|
-
self.error = None
|
|
40
|
-
self.fig = None
|
|
41
|
-
|
|
42
|
-
def execute_code(self, generated_code: str) -> Dict[str, Any]:
|
|
43
|
-
"""
|
|
44
|
-
Execute the provided code and capture the fig object if created.
|
|
45
|
-
|
|
46
|
-
Args:
|
|
47
|
-
generated_code (str): The code to execute.
|
|
48
|
-
|
|
49
|
-
Returns:
|
|
50
|
-
Dict[str, Any]: A dictionary containing the fig object, output, error, and success status.
|
|
51
|
-
"""
|
|
52
|
-
self.output = None
|
|
53
|
-
self.error = None
|
|
54
|
-
|
|
55
|
-
# Capture stdout
|
|
56
|
-
old_stdout = sys.stdout
|
|
57
|
-
sys.stdout = mystdout = StringIO()
|
|
58
|
-
|
|
59
|
-
try:
|
|
60
|
-
# Execute the code
|
|
61
|
-
exec(generated_code, globals(), self.locals_dict)
|
|
62
|
-
|
|
63
|
-
# Check if a fig object was created
|
|
64
|
-
if "fig" in self.locals_dict:
|
|
65
|
-
self.fig = self.locals_dict["fig"]
|
|
66
|
-
self.output = "Code executed successfully. 'fig' object was created."
|
|
67
|
-
else:
|
|
68
|
-
print(f"no fig object created: {generated_code}")
|
|
69
|
-
self.error = "Code executed without errors, but no 'fig' object was created. Make sure your code creates a variable named 'fig'."
|
|
70
|
-
|
|
71
|
-
except Exception as e:
|
|
72
|
-
self.error = f"Error executing code: {str(e)}\n{traceback.format_exc()}"
|
|
73
|
-
|
|
74
|
-
finally:
|
|
75
|
-
# Restore stdout
|
|
76
|
-
sys.stdout = old_stdout
|
|
77
|
-
captured_output = mystdout.getvalue()
|
|
78
|
-
|
|
79
|
-
if captured_output.strip():
|
|
80
|
-
if self.output:
|
|
81
|
-
self.output += f"\nOutput:\n{captured_output}"
|
|
82
|
-
else:
|
|
83
|
-
self.output = f"Output:\n{captured_output}"
|
|
84
|
-
|
|
85
|
-
return {
|
|
86
|
-
"fig": self.fig,
|
|
87
|
-
"output": self.output,
|
|
88
|
-
"error": self.error,
|
|
89
|
-
"success": self.error is None and self.fig is not None,
|
|
90
|
-
}
|
|
@@ -1,471 +0,0 @@
|
|
|
1
|
-
import pytest
|
|
2
|
-
import pandas as pd
|
|
3
|
-
import numpy as np
|
|
4
|
-
from plot_agent.agent import PlotlyAgent
|
|
5
|
-
from langchain_core.messages import HumanMessage, AIMessage
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
def test_plotly_agent_initialization():
|
|
9
|
-
"""Test that PlotlyAgent initializes correctly."""
|
|
10
|
-
agent = PlotlyAgent()
|
|
11
|
-
assert agent.llm is not None
|
|
12
|
-
assert agent.df is None
|
|
13
|
-
assert agent.df_info is None
|
|
14
|
-
assert agent.df_head is None
|
|
15
|
-
assert agent.sql_query is None
|
|
16
|
-
assert agent.execution_env is None
|
|
17
|
-
assert agent.chat_history == []
|
|
18
|
-
assert agent.agent_executor is None
|
|
19
|
-
assert agent.generated_code is None
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def test_set_df():
|
|
23
|
-
"""Test that set_df properly sets up the dataframe and environment."""
|
|
24
|
-
# Create a sample dataframe
|
|
25
|
-
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
26
|
-
|
|
27
|
-
agent = PlotlyAgent()
|
|
28
|
-
agent.set_df(df)
|
|
29
|
-
|
|
30
|
-
assert agent.df is not None
|
|
31
|
-
assert agent.df_info is not None
|
|
32
|
-
assert agent.df_head is not None
|
|
33
|
-
assert agent.execution_env is not None
|
|
34
|
-
assert agent.agent_executor is not None
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def test_execute_plotly_code():
|
|
38
|
-
"""Test that execute_plotly_code works with valid code."""
|
|
39
|
-
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
40
|
-
|
|
41
|
-
agent = PlotlyAgent()
|
|
42
|
-
agent.set_df(df)
|
|
43
|
-
|
|
44
|
-
# Test with valid plotly code
|
|
45
|
-
valid_code = """import plotly.express as px
|
|
46
|
-
fig = px.scatter(df, x='x', y='y')"""
|
|
47
|
-
|
|
48
|
-
result = agent.execute_plotly_code(valid_code)
|
|
49
|
-
assert "Code executed successfully" in result
|
|
50
|
-
assert agent.execution_env.fig is not None
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
def test_execute_plotly_code_with_error():
|
|
54
|
-
"""Test that execute_plotly_code handles errors properly."""
|
|
55
|
-
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
56
|
-
|
|
57
|
-
agent = PlotlyAgent()
|
|
58
|
-
agent.set_df(df)
|
|
59
|
-
|
|
60
|
-
# Test with invalid code
|
|
61
|
-
invalid_code = """import plotly.express as px
|
|
62
|
-
fig = px.scatter(df, x='non_existent_column', y='y')"""
|
|
63
|
-
|
|
64
|
-
result = agent.execute_plotly_code(invalid_code)
|
|
65
|
-
assert "Error" in result
|
|
66
|
-
assert agent.execution_env.fig is None
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
def test_does_fig_exist():
|
|
70
|
-
"""Test that does_fig_exist correctly reports figure existence."""
|
|
71
|
-
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
72
|
-
|
|
73
|
-
agent = PlotlyAgent()
|
|
74
|
-
agent.set_df(df)
|
|
75
|
-
|
|
76
|
-
# Initially no figure should exist
|
|
77
|
-
assert "No figure has been created yet" in agent.does_fig_exist()
|
|
78
|
-
|
|
79
|
-
# Create a figure
|
|
80
|
-
valid_code = """import plotly.express as px
|
|
81
|
-
fig = px.scatter(df, x='x', y='y')"""
|
|
82
|
-
agent.execute_plotly_code(valid_code)
|
|
83
|
-
|
|
84
|
-
# Now a figure should exist
|
|
85
|
-
assert "A figure is available for display" in agent.does_fig_exist()
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
def test_reset_conversation():
|
|
89
|
-
"""Test that reset_conversation clears the chat history."""
|
|
90
|
-
agent = PlotlyAgent()
|
|
91
|
-
agent.chat_history = ["message1", "message2"]
|
|
92
|
-
agent.reset_conversation()
|
|
93
|
-
assert agent.chat_history == []
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
def test_view_generated_code():
|
|
97
|
-
"""Test that view_generated_code returns the last generated code."""
|
|
98
|
-
agent = PlotlyAgent()
|
|
99
|
-
test_code = "test code"
|
|
100
|
-
agent.generated_code = test_code
|
|
101
|
-
assert agent.view_generated_code() == test_code
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
def test_get_figure():
|
|
105
|
-
"""Test that get_figure returns the current figure if it exists."""
|
|
106
|
-
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
107
|
-
|
|
108
|
-
agent = PlotlyAgent()
|
|
109
|
-
agent.set_df(df)
|
|
110
|
-
|
|
111
|
-
# Initially no figure should exist
|
|
112
|
-
assert agent.get_figure() is None
|
|
113
|
-
|
|
114
|
-
# Create a figure
|
|
115
|
-
valid_code = """import plotly.express as px
|
|
116
|
-
fig = px.scatter(df, x='x', y='y')"""
|
|
117
|
-
agent.execute_plotly_code(valid_code)
|
|
118
|
-
|
|
119
|
-
# Now a figure should exist
|
|
120
|
-
assert agent.get_figure() is not None
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
def test_process_message():
|
|
124
|
-
"""Test that process_message updates chat history and handles responses."""
|
|
125
|
-
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
126
|
-
|
|
127
|
-
agent = PlotlyAgent()
|
|
128
|
-
agent.set_df(df, sql_query="SELECT x, y FROM df")
|
|
129
|
-
|
|
130
|
-
# Test processing a message
|
|
131
|
-
response = agent.process_message("Create a scatter plot")
|
|
132
|
-
|
|
133
|
-
# Check that chat history was updated
|
|
134
|
-
assert len(agent.chat_history) == 2 # One human message and one AI message
|
|
135
|
-
assert isinstance(agent.chat_history[0], HumanMessage)
|
|
136
|
-
assert isinstance(agent.chat_history[1], AIMessage)
|
|
137
|
-
assert agent.chat_history[0].content == "Create a scatter plot"
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
def test_execute_plotly_code_without_df():
|
|
141
|
-
"""Test that execute_plotly_code handles the case when no dataframe is set."""
|
|
142
|
-
agent = PlotlyAgent()
|
|
143
|
-
result = agent.execute_plotly_code("some code")
|
|
144
|
-
assert "Error" in result and "No dataframe has been set" in result
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
def test_set_df_with_sql_query():
|
|
148
|
-
"""Test that set_df properly handles SQL query context."""
|
|
149
|
-
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
150
|
-
|
|
151
|
-
sql_query = "SELECT x, y FROM table"
|
|
152
|
-
agent = PlotlyAgent()
|
|
153
|
-
agent.set_df(df, sql_query=sql_query)
|
|
154
|
-
|
|
155
|
-
assert agent.sql_query == sql_query
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
def test_agent_initialization_with_custom_prompt():
|
|
159
|
-
"""Test agent initialization with custom system prompt."""
|
|
160
|
-
custom_prompt = "Custom system prompt for testing"
|
|
161
|
-
agent = PlotlyAgent(system_prompt=custom_prompt)
|
|
162
|
-
assert agent.system_prompt == custom_prompt
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
def test_agent_initialization_with_different_model():
|
|
166
|
-
"""Test agent initialization with different model names."""
|
|
167
|
-
agent = PlotlyAgent(model="gpt-3.5-turbo")
|
|
168
|
-
assert agent.llm.model_name == "gpt-3.5-turbo"
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
def test_agent_initialization_with_verbose():
|
|
172
|
-
"""Test agent initialization with verbose settings."""
|
|
173
|
-
agent = PlotlyAgent(verbose=False)
|
|
174
|
-
assert agent.verbose == False
|
|
175
|
-
assert agent.agent_executor is None # Agent executor not initialized yet
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
def test_agent_initialization_with_max_iterations():
|
|
179
|
-
"""Test agent initialization with different max iterations."""
|
|
180
|
-
agent = PlotlyAgent(max_iterations=5)
|
|
181
|
-
assert agent.max_iterations == 5
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
def test_agent_initialization_with_early_stopping():
|
|
185
|
-
"""Test agent initialization with different early stopping methods."""
|
|
186
|
-
agent = PlotlyAgent(early_stopping_method="generate")
|
|
187
|
-
assert agent.early_stopping_method == "generate"
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
def test_process_empty_message():
|
|
191
|
-
"""Test processing of empty messages."""
|
|
192
|
-
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
193
|
-
|
|
194
|
-
agent = PlotlyAgent()
|
|
195
|
-
agent.set_df(df)
|
|
196
|
-
|
|
197
|
-
response = agent.process_message("")
|
|
198
|
-
assert len(agent.chat_history) == 2 # Should still create chat history entries
|
|
199
|
-
assert isinstance(agent.chat_history[0], HumanMessage)
|
|
200
|
-
assert isinstance(agent.chat_history[1], AIMessage)
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
def test_process_message_with_code_blocks():
|
|
204
|
-
"""Test processing messages that contain code blocks."""
|
|
205
|
-
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
206
|
-
|
|
207
|
-
agent = PlotlyAgent()
|
|
208
|
-
agent.set_df(df)
|
|
209
|
-
|
|
210
|
-
message = "Here's some code:\n```python\nprint('test')\n```"
|
|
211
|
-
response = agent.process_message(message)
|
|
212
|
-
assert len(agent.chat_history) == 2
|
|
213
|
-
assert "```python" in agent.chat_history[0].content
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
def test_execution_environment_with_different_plot_types():
|
|
217
|
-
"""Test execution environment with different types of plots."""
|
|
218
|
-
df = pd.DataFrame(
|
|
219
|
-
{
|
|
220
|
-
"x": [1, 2, 3, 4, 5],
|
|
221
|
-
"y": [10, 20, 30, 40, 50],
|
|
222
|
-
"category": ["A", "B", "A", "B", "A"],
|
|
223
|
-
}
|
|
224
|
-
)
|
|
225
|
-
|
|
226
|
-
agent = PlotlyAgent()
|
|
227
|
-
agent.set_df(df)
|
|
228
|
-
|
|
229
|
-
# Test scatter plot
|
|
230
|
-
scatter_code = """import plotly.express as px
|
|
231
|
-
fig = px.scatter(df, x='x', y='y')"""
|
|
232
|
-
result = agent.execute_plotly_code(scatter_code)
|
|
233
|
-
assert "Code executed successfully" in result
|
|
234
|
-
assert agent.execution_env.fig is not None
|
|
235
|
-
|
|
236
|
-
# Test bar plot
|
|
237
|
-
bar_code = """import plotly.express as px
|
|
238
|
-
fig = px.bar(df, x='category', y='y')"""
|
|
239
|
-
result = agent.execute_plotly_code(bar_code)
|
|
240
|
-
assert "Code executed successfully" in result
|
|
241
|
-
assert agent.execution_env.fig is not None
|
|
242
|
-
|
|
243
|
-
# Test line plot
|
|
244
|
-
line_code = """import plotly.express as px
|
|
245
|
-
fig = px.line(df, x='x', y='y')"""
|
|
246
|
-
result = agent.execute_plotly_code(line_code)
|
|
247
|
-
assert "Code executed successfully" in result
|
|
248
|
-
assert agent.execution_env.fig is not None
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
def test_execution_environment_with_subplots():
|
|
252
|
-
"""Test execution environment with subplots."""
|
|
253
|
-
df = pd.DataFrame(
|
|
254
|
-
{"x": [1, 2, 3, 4, 5], "y1": [10, 20, 30, 40, 50], "y2": [50, 40, 30, 20, 10]}
|
|
255
|
-
)
|
|
256
|
-
|
|
257
|
-
agent = PlotlyAgent()
|
|
258
|
-
agent.set_df(df)
|
|
259
|
-
|
|
260
|
-
subplot_code = """import plotly.subplots as sp
|
|
261
|
-
import plotly.graph_objects as go
|
|
262
|
-
fig = sp.make_subplots(rows=1, cols=2)
|
|
263
|
-
fig.add_trace(go.Scatter(x=df['x'], y=df['y1']), row=1, col=1)
|
|
264
|
-
fig.add_trace(go.Scatter(x=df['x'], y=df['y2']), row=1, col=2)"""
|
|
265
|
-
|
|
266
|
-
result = agent.execute_plotly_code(subplot_code)
|
|
267
|
-
assert "Code executed successfully" in result
|
|
268
|
-
assert agent.execution_env.fig is not None
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
def test_execution_environment_with_data_preprocessing():
|
|
272
|
-
"""Test execution environment with data preprocessing steps."""
|
|
273
|
-
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
274
|
-
|
|
275
|
-
agent = PlotlyAgent()
|
|
276
|
-
agent.set_df(df)
|
|
277
|
-
|
|
278
|
-
preprocessing_code = """import plotly.express as px
|
|
279
|
-
# Preprocessing steps
|
|
280
|
-
df['y_normalized'] = (df['y'] - df['y'].min()) / (df['y'].max() - df['y'].min())
|
|
281
|
-
fig = px.scatter(df, x='x', y='y_normalized')"""
|
|
282
|
-
|
|
283
|
-
result = agent.execute_plotly_code(preprocessing_code)
|
|
284
|
-
assert "Code executed successfully" in result
|
|
285
|
-
assert agent.execution_env.fig is not None
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
def test_handle_syntax_error():
|
|
289
|
-
"""Test handling of syntax errors in generated code."""
|
|
290
|
-
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
291
|
-
|
|
292
|
-
agent = PlotlyAgent()
|
|
293
|
-
agent.set_df(df)
|
|
294
|
-
|
|
295
|
-
invalid_code = """import plotly.express as px
|
|
296
|
-
fig = px.scatter(df, x='x', y='y' # Missing closing parenthesis"""
|
|
297
|
-
|
|
298
|
-
result = agent.execute_plotly_code(invalid_code)
|
|
299
|
-
assert "Error" in result
|
|
300
|
-
assert "SyntaxError" in result
|
|
301
|
-
assert agent.execution_env.fig is None
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
def test_handle_runtime_error():
|
|
305
|
-
"""Test handling of runtime errors in generated code."""
|
|
306
|
-
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
307
|
-
|
|
308
|
-
agent = PlotlyAgent()
|
|
309
|
-
agent.set_df(df)
|
|
310
|
-
|
|
311
|
-
error_code = """import plotly.express as px
|
|
312
|
-
fig = px.scatter(df, x='x', y='y', color='non_existent_column')"""
|
|
313
|
-
|
|
314
|
-
result = agent.execute_plotly_code(error_code)
|
|
315
|
-
assert "Error" in result
|
|
316
|
-
assert agent.execution_env.fig is None
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
def test_tool_interaction():
|
|
320
|
-
"""Test interaction between different tools."""
|
|
321
|
-
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
322
|
-
|
|
323
|
-
agent = PlotlyAgent()
|
|
324
|
-
agent.set_df(df)
|
|
325
|
-
|
|
326
|
-
# First check if figure exists (should not)
|
|
327
|
-
assert "No figure has been created yet" in agent.does_fig_exist()
|
|
328
|
-
|
|
329
|
-
# Generate and execute code
|
|
330
|
-
code = """import plotly.express as px
|
|
331
|
-
fig = px.scatter(df, x='x', y='y')"""
|
|
332
|
-
result = agent.execute_plotly_code(code)
|
|
333
|
-
assert "Code executed successfully" in result
|
|
334
|
-
|
|
335
|
-
# Check if figure exists (should now exist)
|
|
336
|
-
assert "A figure is available for display" in agent.does_fig_exist()
|
|
337
|
-
|
|
338
|
-
# View the generated code
|
|
339
|
-
assert code in agent.view_generated_code()
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
def test_tool_validation():
|
|
343
|
-
"""Test validation of tool inputs."""
|
|
344
|
-
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
345
|
-
|
|
346
|
-
agent = PlotlyAgent()
|
|
347
|
-
agent.set_df(df)
|
|
348
|
-
|
|
349
|
-
# Test with invalid code (empty string)
|
|
350
|
-
result = agent.execute_plotly_code("")
|
|
351
|
-
assert "Error" in result
|
|
352
|
-
|
|
353
|
-
# Test with invalid code (None)
|
|
354
|
-
with pytest.raises(AssertionError):
|
|
355
|
-
agent.execute_plotly_code(None)
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
def test_tool_response_formatting():
|
|
359
|
-
"""Test formatting of tool responses."""
|
|
360
|
-
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
361
|
-
|
|
362
|
-
agent = PlotlyAgent()
|
|
363
|
-
agent.set_df(df)
|
|
364
|
-
|
|
365
|
-
# Test execute_plotly_code response format
|
|
366
|
-
code = """import plotly.express as px
|
|
367
|
-
fig = px.scatter(df, x='x', y='y')"""
|
|
368
|
-
result = agent.execute_plotly_code(code)
|
|
369
|
-
assert isinstance(result, str)
|
|
370
|
-
assert "Code executed successfully" in result
|
|
371
|
-
|
|
372
|
-
# Test does_fig_exist response format
|
|
373
|
-
result = agent.does_fig_exist()
|
|
374
|
-
assert isinstance(result, str)
|
|
375
|
-
assert "figure" in result.lower()
|
|
376
|
-
|
|
377
|
-
# Test view_generated_code response format
|
|
378
|
-
result = agent.view_generated_code()
|
|
379
|
-
assert isinstance(result, str)
|
|
380
|
-
assert code in result
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
def test_memory_cleanup():
|
|
384
|
-
"""Test memory cleanup after multiple plot generations."""
|
|
385
|
-
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
386
|
-
|
|
387
|
-
agent = PlotlyAgent()
|
|
388
|
-
agent.set_df(df)
|
|
389
|
-
|
|
390
|
-
# Generate multiple plots
|
|
391
|
-
for i in range(5):
|
|
392
|
-
code = f"""import plotly.express as px
|
|
393
|
-
fig = px.scatter(df, x='x', y='y', title='Plot {i}')"""
|
|
394
|
-
result = agent.execute_plotly_code(code)
|
|
395
|
-
assert "Code executed successfully" in result
|
|
396
|
-
assert agent.execution_env.fig is not None
|
|
397
|
-
|
|
398
|
-
# Reset conversation and check memory
|
|
399
|
-
agent.reset_conversation()
|
|
400
|
-
assert len(agent.chat_history) == 0
|
|
401
|
-
assert agent.generated_code is None
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
def test_large_dataframe_handling():
|
|
405
|
-
"""Test handling of large dataframes."""
|
|
406
|
-
# Create a large dataframe
|
|
407
|
-
df = pd.DataFrame({"x": range(10000), "y": range(10000)})
|
|
408
|
-
|
|
409
|
-
agent = PlotlyAgent()
|
|
410
|
-
agent.set_df(df)
|
|
411
|
-
|
|
412
|
-
# Test plot generation with large dataframe
|
|
413
|
-
code = """import plotly.express as px
|
|
414
|
-
fig = px.scatter(df, x='x', y='y')"""
|
|
415
|
-
result = agent.execute_plotly_code(code)
|
|
416
|
-
assert "Code executed successfully" in result
|
|
417
|
-
assert agent.execution_env.fig is not None
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
def test_input_validation():
|
|
421
|
-
"""Test validation of input parameters."""
|
|
422
|
-
# Test invalid dataframe input
|
|
423
|
-
with pytest.raises(AssertionError):
|
|
424
|
-
agent = PlotlyAgent()
|
|
425
|
-
agent.set_df("not a dataframe")
|
|
426
|
-
|
|
427
|
-
# Test invalid SQL query input
|
|
428
|
-
df = pd.DataFrame({"x": [1, 2, 3]})
|
|
429
|
-
agent = PlotlyAgent()
|
|
430
|
-
with pytest.raises(AssertionError):
|
|
431
|
-
agent.set_df(df, sql_query=123) # SQL query should be string
|
|
432
|
-
|
|
433
|
-
# Test invalid message input
|
|
434
|
-
agent.set_df(df)
|
|
435
|
-
with pytest.raises(AssertionError):
|
|
436
|
-
agent.process_message(123) # Message should be string
|
|
437
|
-
|
|
438
|
-
# Test invalid code input
|
|
439
|
-
with pytest.raises(AssertionError):
|
|
440
|
-
agent.execute_plotly_code(123) # Code should be string
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
def test_complex_plot_handling():
|
|
444
|
-
"""Test handling of complex plots with multiple traces and layouts."""
|
|
445
|
-
df = pd.DataFrame(
|
|
446
|
-
{
|
|
447
|
-
"x": [1, 2, 3, 4, 5],
|
|
448
|
-
"y1": [10, 20, 30, 40, 50],
|
|
449
|
-
"y2": [50, 40, 30, 20, 10],
|
|
450
|
-
"category": ["A", "B", "A", "B", "A"],
|
|
451
|
-
}
|
|
452
|
-
)
|
|
453
|
-
|
|
454
|
-
agent = PlotlyAgent()
|
|
455
|
-
agent.set_df(df)
|
|
456
|
-
|
|
457
|
-
complex_code = """import plotly.graph_objects as go
|
|
458
|
-
fig = go.Figure()
|
|
459
|
-
fig.add_trace(go.Scatter(x=df['x'], y=df['y1'], name='Trace 1'))
|
|
460
|
-
fig.add_trace(go.Scatter(x=df['x'], y=df['y2'], name='Trace 2'))
|
|
461
|
-
fig.update_layout(
|
|
462
|
-
title='Complex Plot',
|
|
463
|
-
xaxis_title='X Axis',
|
|
464
|
-
yaxis_title='Y Axis',
|
|
465
|
-
showlegend=True,
|
|
466
|
-
template='plotly_white'
|
|
467
|
-
)"""
|
|
468
|
-
|
|
469
|
-
result = agent.execute_plotly_code(complex_code)
|
|
470
|
-
assert "Code executed successfully" in result
|
|
471
|
-
assert agent.execution_env.fig is not None
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|