plot-agent 0.2.2__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {plot_agent-0.2.2 → plot_agent-0.3.0}/PKG-INFO +3 -3
- {plot_agent-0.2.2 → plot_agent-0.3.0}/README.md +2 -2
- {plot_agent-0.2.2 → plot_agent-0.3.0}/plot_agent/agent.py +4 -4
- {plot_agent-0.2.2 → plot_agent-0.3.0}/plot_agent/execution.py +1 -1
- {plot_agent-0.2.2 → plot_agent-0.3.0}/plot_agent/models.py +5 -2
- {plot_agent-0.2.2 → plot_agent-0.3.0}/plot_agent.egg-info/PKG-INFO +3 -3
- {plot_agent-0.2.2 → plot_agent-0.3.0}/pyproject.toml +1 -1
- {plot_agent-0.2.2 → plot_agent-0.3.0}/tests/test_plot_agent.py +33 -33
- {plot_agent-0.2.2 → plot_agent-0.3.0}/LICENSE +0 -0
- {plot_agent-0.2.2 → plot_agent-0.3.0}/plot_agent/__init__.py +0 -0
- {plot_agent-0.2.2 → plot_agent-0.3.0}/plot_agent/prompt.py +0 -0
- {plot_agent-0.2.2 → plot_agent-0.3.0}/plot_agent.egg-info/SOURCES.txt +0 -0
- {plot_agent-0.2.2 → plot_agent-0.3.0}/plot_agent.egg-info/dependency_links.txt +0 -0
- {plot_agent-0.2.2 → plot_agent-0.3.0}/plot_agent.egg-info/top_level.txt +0 -0
- {plot_agent-0.2.2 → plot_agent-0.3.0}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: plot-agent
|
|
3
|
-
Version: 0.2.2
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: An AI-powered data visualization assistant using Plotly
|
|
5
5
|
Author-email: andrewm4894 <andrewm4894@gmail.com>
|
|
6
6
|
Project-URL: Homepage, https://github.com/andrewm4894/plot-agent
|
|
@@ -35,7 +35,7 @@ Here's a simple minimal example of how to use Plot Agent:
|
|
|
35
35
|
|
|
36
36
|
```python
|
|
37
37
|
import pandas as pd
|
|
38
|
-
from plot_agent.agent import PlotlyAgent
|
|
38
|
+
from plot_agent.agent import PlotAgent
|
|
39
39
|
|
|
40
40
|
# ensure OPENAI_API_KEY is set and available for langchain
|
|
41
41
|
|
|
@@ -46,7 +46,7 @@ df = pd.DataFrame({
|
|
|
46
46
|
})
|
|
47
47
|
|
|
48
48
|
# Initialize the agent
|
|
49
|
-
agent = PlotlyAgent()
|
|
49
|
+
agent = PlotAgent()
|
|
50
50
|
|
|
51
51
|
# Set the dataframe
|
|
52
52
|
agent.set_df(df)
|
|
@@ -21,7 +21,7 @@ Here's a simple minimal example of how to use Plot Agent:
|
|
|
21
21
|
|
|
22
22
|
```python
|
|
23
23
|
import pandas as pd
|
|
24
|
-
from plot_agent.agent import PlotlyAgent
|
|
24
|
+
from plot_agent.agent import PlotAgent
|
|
25
25
|
|
|
26
26
|
# ensure OPENAI_API_KEY is set and available for langchain
|
|
27
27
|
|
|
@@ -32,7 +32,7 @@ df = pd.DataFrame({
|
|
|
32
32
|
})
|
|
33
33
|
|
|
34
34
|
# Initialize the agent
|
|
35
|
-
agent = PlotlyAgent()
|
|
35
|
+
agent = PlotAgent()
|
|
36
36
|
|
|
37
37
|
# Set the dataframe
|
|
38
38
|
agent.set_df(df)
|
|
@@ -14,10 +14,10 @@ from plot_agent.models import (
|
|
|
14
14
|
DoesFigExistInput,
|
|
15
15
|
ViewGeneratedCodeInput,
|
|
16
16
|
)
|
|
17
|
-
from plot_agent.execution import
|
|
17
|
+
from plot_agent.execution import PlotAgentExecutionEnvironment
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
class PlotlyAgent:
|
|
20
|
+
class PlotAgent:
|
|
21
21
|
"""
|
|
22
22
|
A class that uses an LLM to generate Plotly code based on a user's plot description.
|
|
23
23
|
"""
|
|
@@ -32,7 +32,7 @@ class PlotlyAgent:
|
|
|
32
32
|
handle_parsing_errors: bool = True,
|
|
33
33
|
):
|
|
34
34
|
"""
|
|
35
|
-
Initialize the PlotlyAgent.
|
|
35
|
+
Initialize the PlotAgent.
|
|
36
36
|
|
|
37
37
|
Args:
|
|
38
38
|
model (str): The model to use for the LLM.
|
|
@@ -90,7 +90,7 @@ class PlotlyAgent:
|
|
|
90
90
|
self.sql_query = sql_query
|
|
91
91
|
|
|
92
92
|
# Initialize execution environment
|
|
93
|
-
self.execution_env =
|
|
93
|
+
self.execution_env = PlotAgentExecutionEnvironment(df)
|
|
94
94
|
|
|
95
95
|
# Initialize the agent with tools
|
|
96
96
|
self._initialize_agent()
|
|
@@ -1,14 +1,17 @@
|
|
|
1
1
|
from pydantic import BaseModel, Field
|
|
2
2
|
|
|
3
3
|
|
|
4
|
-
# Define input schemas for the tools
|
|
5
4
|
class PlotDescriptionInput(BaseModel):
|
|
5
|
+
"""Model indicating that the plot_description function takes a plot_description argument."""
|
|
6
|
+
|
|
6
7
|
plot_description: str = Field(
|
|
7
8
|
..., description="Description of the plot the user wants to create"
|
|
8
9
|
)
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
class GeneratedCodeInput(BaseModel):
|
|
13
|
+
"""Model indicating that the generated_code function takes a generated_code argument."""
|
|
14
|
+
|
|
12
15
|
generated_code: str = Field(
|
|
13
16
|
..., description="Python code that creates a Plotly figure"
|
|
14
17
|
)
|
|
@@ -23,4 +26,4 @@ class DoesFigExistInput(BaseModel):
|
|
|
23
26
|
class ViewGeneratedCodeInput(BaseModel):
|
|
24
27
|
"""Model indicating that the view_generated_code function takes no arguments."""
|
|
25
28
|
|
|
26
|
-
pass
|
|
29
|
+
pass
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: plot-agent
|
|
3
|
-
Version: 0.2.2
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: An AI-powered data visualization assistant using Plotly
|
|
5
5
|
Author-email: andrewm4894 <andrewm4894@gmail.com>
|
|
6
6
|
Project-URL: Homepage, https://github.com/andrewm4894/plot-agent
|
|
@@ -35,7 +35,7 @@ Here's a simple minimal example of how to use Plot Agent:
|
|
|
35
35
|
|
|
36
36
|
```python
|
|
37
37
|
import pandas as pd
|
|
38
|
-
from plot_agent.agent import PlotlyAgent
|
|
38
|
+
from plot_agent.agent import PlotAgent
|
|
39
39
|
|
|
40
40
|
# ensure OPENAI_API_KEY is set and available for langchain
|
|
41
41
|
|
|
@@ -46,7 +46,7 @@ df = pd.DataFrame({
|
|
|
46
46
|
})
|
|
47
47
|
|
|
48
48
|
# Initialize the agent
|
|
49
|
-
agent = PlotlyAgent()
|
|
49
|
+
agent = PlotAgent()
|
|
50
50
|
|
|
51
51
|
# Set the dataframe
|
|
52
52
|
agent.set_df(df)
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
import pytest
|
|
2
2
|
import pandas as pd
|
|
3
3
|
import numpy as np
|
|
4
|
-
from plot_agent.agent import PlotlyAgent
|
|
4
|
+
from plot_agent.agent import PlotAgent
|
|
5
5
|
from langchain_core.messages import HumanMessage, AIMessage
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
def test_plotly_agent_initialization():
|
|
9
|
-
"""Test that
|
|
10
|
-
agent =
|
|
9
|
+
"""Test that PlotAgent initializes correctly."""
|
|
10
|
+
agent = PlotAgent()
|
|
11
11
|
assert agent.llm is not None
|
|
12
12
|
assert agent.df is None
|
|
13
13
|
assert agent.df_info is None
|
|
@@ -24,7 +24,7 @@ def test_set_df():
|
|
|
24
24
|
# Create a sample dataframe
|
|
25
25
|
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
26
26
|
|
|
27
|
-
agent =
|
|
27
|
+
agent = PlotAgent()
|
|
28
28
|
agent.set_df(df)
|
|
29
29
|
|
|
30
30
|
assert agent.df is not None
|
|
@@ -38,7 +38,7 @@ def test_execute_plotly_code():
|
|
|
38
38
|
"""Test that execute_plotly_code works with valid code."""
|
|
39
39
|
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
40
40
|
|
|
41
|
-
agent =
|
|
41
|
+
agent = PlotAgent()
|
|
42
42
|
agent.set_df(df)
|
|
43
43
|
|
|
44
44
|
# Test with valid plotly code
|
|
@@ -54,7 +54,7 @@ def test_execute_plotly_code_with_error():
|
|
|
54
54
|
"""Test that execute_plotly_code handles errors properly."""
|
|
55
55
|
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
56
56
|
|
|
57
|
-
agent =
|
|
57
|
+
agent = PlotAgent()
|
|
58
58
|
agent.set_df(df)
|
|
59
59
|
|
|
60
60
|
# Test with invalid code
|
|
@@ -70,7 +70,7 @@ def test_does_fig_exist():
|
|
|
70
70
|
"""Test that does_fig_exist correctly reports figure existence."""
|
|
71
71
|
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
72
72
|
|
|
73
|
-
agent =
|
|
73
|
+
agent = PlotAgent()
|
|
74
74
|
agent.set_df(df)
|
|
75
75
|
|
|
76
76
|
# Initially no figure should exist
|
|
@@ -87,7 +87,7 @@ fig = px.scatter(df, x='x', y='y')"""
|
|
|
87
87
|
|
|
88
88
|
def test_reset_conversation():
|
|
89
89
|
"""Test that reset_conversation clears the chat history."""
|
|
90
|
-
agent =
|
|
90
|
+
agent = PlotAgent()
|
|
91
91
|
agent.chat_history = ["message1", "message2"]
|
|
92
92
|
agent.reset_conversation()
|
|
93
93
|
assert agent.chat_history == []
|
|
@@ -95,7 +95,7 @@ def test_reset_conversation():
|
|
|
95
95
|
|
|
96
96
|
def test_view_generated_code():
|
|
97
97
|
"""Test that view_generated_code returns the last generated code."""
|
|
98
|
-
agent =
|
|
98
|
+
agent = PlotAgent()
|
|
99
99
|
test_code = "test code"
|
|
100
100
|
agent.generated_code = test_code
|
|
101
101
|
assert agent.view_generated_code() == test_code
|
|
@@ -105,7 +105,7 @@ def test_get_figure():
|
|
|
105
105
|
"""Test that get_figure returns the current figure if it exists."""
|
|
106
106
|
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
107
107
|
|
|
108
|
-
agent =
|
|
108
|
+
agent = PlotAgent()
|
|
109
109
|
agent.set_df(df)
|
|
110
110
|
|
|
111
111
|
# Initially no figure should exist
|
|
@@ -124,7 +124,7 @@ def test_process_message():
|
|
|
124
124
|
"""Test that process_message updates chat history and handles responses."""
|
|
125
125
|
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
126
126
|
|
|
127
|
-
agent =
|
|
127
|
+
agent = PlotAgent()
|
|
128
128
|
agent.set_df(df, sql_query="SELECT x, y FROM df")
|
|
129
129
|
|
|
130
130
|
# Test processing a message
|
|
@@ -139,7 +139,7 @@ def test_process_message():
|
|
|
139
139
|
|
|
140
140
|
def test_execute_plotly_code_without_df():
|
|
141
141
|
"""Test that execute_plotly_code handles the case when no dataframe is set."""
|
|
142
|
-
agent =
|
|
142
|
+
agent = PlotAgent()
|
|
143
143
|
result = agent.execute_plotly_code("some code")
|
|
144
144
|
assert "Error" in result and "No dataframe has been set" in result
|
|
145
145
|
|
|
@@ -149,7 +149,7 @@ def test_set_df_with_sql_query():
|
|
|
149
149
|
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
150
150
|
|
|
151
151
|
sql_query = "SELECT x, y FROM table"
|
|
152
|
-
agent =
|
|
152
|
+
agent = PlotAgent()
|
|
153
153
|
agent.set_df(df, sql_query=sql_query)
|
|
154
154
|
|
|
155
155
|
assert agent.sql_query == sql_query
|
|
@@ -158,32 +158,32 @@ def test_set_df_with_sql_query():
|
|
|
158
158
|
def test_agent_initialization_with_custom_prompt():
|
|
159
159
|
"""Test agent initialization with custom system prompt."""
|
|
160
160
|
custom_prompt = "Custom system prompt for testing"
|
|
161
|
-
agent =
|
|
161
|
+
agent = PlotAgent(system_prompt=custom_prompt)
|
|
162
162
|
assert agent.system_prompt == custom_prompt
|
|
163
163
|
|
|
164
164
|
|
|
165
165
|
def test_agent_initialization_with_different_model():
|
|
166
166
|
"""Test agent initialization with different model names."""
|
|
167
|
-
agent =
|
|
167
|
+
agent = PlotAgent(model="gpt-3.5-turbo")
|
|
168
168
|
assert agent.llm.model_name == "gpt-3.5-turbo"
|
|
169
169
|
|
|
170
170
|
|
|
171
171
|
def test_agent_initialization_with_verbose():
|
|
172
172
|
"""Test agent initialization with verbose settings."""
|
|
173
|
-
agent =
|
|
173
|
+
agent = PlotAgent(verbose=False)
|
|
174
174
|
assert agent.verbose == False
|
|
175
175
|
assert agent.agent_executor is None # Agent executor not initialized yet
|
|
176
176
|
|
|
177
177
|
|
|
178
178
|
def test_agent_initialization_with_max_iterations():
|
|
179
179
|
"""Test agent initialization with different max iterations."""
|
|
180
|
-
agent =
|
|
180
|
+
agent = PlotAgent(max_iterations=5)
|
|
181
181
|
assert agent.max_iterations == 5
|
|
182
182
|
|
|
183
183
|
|
|
184
184
|
def test_agent_initialization_with_early_stopping():
|
|
185
185
|
"""Test agent initialization with different early stopping methods."""
|
|
186
|
-
agent =
|
|
186
|
+
agent = PlotAgent(early_stopping_method="generate")
|
|
187
187
|
assert agent.early_stopping_method == "generate"
|
|
188
188
|
|
|
189
189
|
|
|
@@ -191,7 +191,7 @@ def test_process_empty_message():
|
|
|
191
191
|
"""Test processing of empty messages."""
|
|
192
192
|
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
193
193
|
|
|
194
|
-
agent =
|
|
194
|
+
agent = PlotAgent()
|
|
195
195
|
agent.set_df(df)
|
|
196
196
|
|
|
197
197
|
response = agent.process_message("")
|
|
@@ -204,7 +204,7 @@ def test_process_message_with_code_blocks():
|
|
|
204
204
|
"""Test processing messages that contain code blocks."""
|
|
205
205
|
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
206
206
|
|
|
207
|
-
agent =
|
|
207
|
+
agent = PlotAgent()
|
|
208
208
|
agent.set_df(df)
|
|
209
209
|
|
|
210
210
|
message = "Here's some code:\n```python\nprint('test')\n```"
|
|
@@ -223,7 +223,7 @@ def test_execution_environment_with_different_plot_types():
|
|
|
223
223
|
}
|
|
224
224
|
)
|
|
225
225
|
|
|
226
|
-
agent =
|
|
226
|
+
agent = PlotAgent()
|
|
227
227
|
agent.set_df(df)
|
|
228
228
|
|
|
229
229
|
# Test scatter plot
|
|
@@ -254,7 +254,7 @@ def test_execution_environment_with_subplots():
|
|
|
254
254
|
{"x": [1, 2, 3, 4, 5], "y1": [10, 20, 30, 40, 50], "y2": [50, 40, 30, 20, 10]}
|
|
255
255
|
)
|
|
256
256
|
|
|
257
|
-
agent =
|
|
257
|
+
agent = PlotAgent()
|
|
258
258
|
agent.set_df(df)
|
|
259
259
|
|
|
260
260
|
subplot_code = """import plotly.subplots as sp
|
|
@@ -272,7 +272,7 @@ def test_execution_environment_with_data_preprocessing():
|
|
|
272
272
|
"""Test execution environment with data preprocessing steps."""
|
|
273
273
|
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
274
274
|
|
|
275
|
-
agent =
|
|
275
|
+
agent = PlotAgent()
|
|
276
276
|
agent.set_df(df)
|
|
277
277
|
|
|
278
278
|
preprocessing_code = """import plotly.express as px
|
|
@@ -289,7 +289,7 @@ def test_handle_syntax_error():
|
|
|
289
289
|
"""Test handling of syntax errors in generated code."""
|
|
290
290
|
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
291
291
|
|
|
292
|
-
agent =
|
|
292
|
+
agent = PlotAgent()
|
|
293
293
|
agent.set_df(df)
|
|
294
294
|
|
|
295
295
|
invalid_code = """import plotly.express as px
|
|
@@ -305,7 +305,7 @@ def test_handle_runtime_error():
|
|
|
305
305
|
"""Test handling of runtime errors in generated code."""
|
|
306
306
|
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
307
307
|
|
|
308
|
-
agent =
|
|
308
|
+
agent = PlotAgent()
|
|
309
309
|
agent.set_df(df)
|
|
310
310
|
|
|
311
311
|
error_code = """import plotly.express as px
|
|
@@ -320,7 +320,7 @@ def test_tool_interaction():
|
|
|
320
320
|
"""Test interaction between different tools."""
|
|
321
321
|
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
322
322
|
|
|
323
|
-
agent =
|
|
323
|
+
agent = PlotAgent()
|
|
324
324
|
agent.set_df(df)
|
|
325
325
|
|
|
326
326
|
# First check if figure exists (should not)
|
|
@@ -343,7 +343,7 @@ def test_tool_validation():
|
|
|
343
343
|
"""Test validation of tool inputs."""
|
|
344
344
|
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
345
345
|
|
|
346
|
-
agent =
|
|
346
|
+
agent = PlotAgent()
|
|
347
347
|
agent.set_df(df)
|
|
348
348
|
|
|
349
349
|
# Test with invalid code (empty string)
|
|
@@ -359,7 +359,7 @@ def test_tool_response_formatting():
|
|
|
359
359
|
"""Test formatting of tool responses."""
|
|
360
360
|
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
361
361
|
|
|
362
|
-
agent =
|
|
362
|
+
agent = PlotAgent()
|
|
363
363
|
agent.set_df(df)
|
|
364
364
|
|
|
365
365
|
# Test execute_plotly_code response format
|
|
@@ -384,7 +384,7 @@ def test_memory_cleanup():
|
|
|
384
384
|
"""Test memory cleanup after multiple plot generations."""
|
|
385
385
|
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
|
|
386
386
|
|
|
387
|
-
agent =
|
|
387
|
+
agent = PlotAgent()
|
|
388
388
|
agent.set_df(df)
|
|
389
389
|
|
|
390
390
|
# Generate multiple plots
|
|
@@ -406,7 +406,7 @@ def test_large_dataframe_handling():
|
|
|
406
406
|
# Create a large dataframe
|
|
407
407
|
df = pd.DataFrame({"x": range(10000), "y": range(10000)})
|
|
408
408
|
|
|
409
|
-
agent =
|
|
409
|
+
agent = PlotAgent()
|
|
410
410
|
agent.set_df(df)
|
|
411
411
|
|
|
412
412
|
# Test plot generation with large dataframe
|
|
@@ -421,12 +421,12 @@ def test_input_validation():
|
|
|
421
421
|
"""Test validation of input parameters."""
|
|
422
422
|
# Test invalid dataframe input
|
|
423
423
|
with pytest.raises(AssertionError):
|
|
424
|
-
agent =
|
|
424
|
+
agent = PlotAgent()
|
|
425
425
|
agent.set_df("not a dataframe")
|
|
426
426
|
|
|
427
427
|
# Test invalid SQL query input
|
|
428
428
|
df = pd.DataFrame({"x": [1, 2, 3]})
|
|
429
|
-
agent =
|
|
429
|
+
agent = PlotAgent()
|
|
430
430
|
with pytest.raises(AssertionError):
|
|
431
431
|
agent.set_df(df, sql_query=123) # SQL query should be string
|
|
432
432
|
|
|
@@ -451,7 +451,7 @@ def test_complex_plot_handling():
|
|
|
451
451
|
}
|
|
452
452
|
)
|
|
453
453
|
|
|
454
|
-
agent =
|
|
454
|
+
agent = PlotAgent()
|
|
455
455
|
agent.set_df(df)
|
|
456
456
|
|
|
457
457
|
complex_code = """import plotly.graph_objects as go
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|