plot-agent 0.1.1__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: plot-agent
- Version: 0.1.1
+ Version: 0.2.0
  Summary: An AI-powered data visualization assistant using Plotly
  Author-email: andrewm4894 <andrewm4894@gmail.com>
  Project-URL: Homepage, https://github.com/andrewm4894/plot-agent
@@ -14,6 +14,9 @@ Dynamic: license-file
 
  # Plot Agent
 
+ [![Tests](https://github.com/andrewm4894/plot-agent/actions/workflows/test.yml/badge.svg)](https://github.com/andrewm4894/plot-agent/actions/workflows/test.yml)
+ [![PyPI version](https://badge.fury.io/py/plot-agent.svg)](https://badge.fury.io/py/plot-agent)
+
  An AI-powered data visualization assistant that helps users create Plotly visualizations in Python.
 
  ## Installation
@@ -26,11 +29,15 @@ pip install plot-agent
 
  ## Usage
 
- Here's a simple example of how to use Plot Agent:
+ See more examples in [/examples/](https://nbviewer.org/github/andrewm4894/plot-agent/tree/main/examples/) (via nbviewer so that can see the charts easily).
+
+ Here's a simple minimal example of how to use Plot Agent:
 
  ```python
  import pandas as pd
- from plot_agent import PlotlyAgent
+ from plot_agent.agent import PlotlyAgent
+
+ # ensure OPENAI_API_KEY is set and available for langchain
 
  # Create a sample dataframe
  df = pd.DataFrame({
@@ -46,6 +53,43 @@ agent.set_df(df)
 
  # Process a visualization request
  response = agent.process_message("Create a line plot of x vs y")
+
+ # Print generated code
+ print(agent.generated_code)
+
+ # Get fig
+ fig = agent.get_figure()
+ fig.show()
+ ```
+
+ `agent.generated_code`:
+
+ ```python
+ import pandas as pd
+ import plotly.graph_objects as go
+
+ # Creating a line plot of x vs y
+ # Create a figure object
+ fig = go.Figure()
+
+ # Add a line trace to the figure
+ fig.add_trace(
+     go.Scatter(
+         x=df['x'], # The x values
+         y=df['y'], # The y values
+         mode='lines+markers', # Display both lines and markers
+         name='Line Plot', # Name of the trace
+         line=dict(color='blue', width=2) # Specify line color and width
+     )
+ )
+
+ # Adding titles and labels
+ fig.update_layout(
+     title='Line Plot of x vs y', # Plot title
+     xaxis_title='x', # x-axis label
+     yaxis_title='y', # y-axis label
+     template='plotly_white' # A clean layout
+ )
  ```
 
  ## Features
@@ -0,0 +1,96 @@
+ # Plot Agent
+
+ [![Tests](https://github.com/andrewm4894/plot-agent/actions/workflows/test.yml/badge.svg)](https://github.com/andrewm4894/plot-agent/actions/workflows/test.yml)
+ [![PyPI version](https://badge.fury.io/py/plot-agent.svg)](https://badge.fury.io/py/plot-agent)
+
+ An AI-powered data visualization assistant that helps users create Plotly visualizations in Python.
+
+ ## Installation
+
+ You can install the package using pip:
+
+ ```bash
+ pip install plot-agent
+ ```
+
+ ## Usage
+
+ See more examples in [/examples/](https://nbviewer.org/github/andrewm4894/plot-agent/tree/main/examples/) (via nbviewer so that can see the charts easily).
+
+ Here's a simple minimal example of how to use Plot Agent:
+
+ ```python
+ import pandas as pd
+ from plot_agent.agent import PlotlyAgent
+
+ # ensure OPENAI_API_KEY is set and available for langchain
+
+ # Create a sample dataframe
+ df = pd.DataFrame({
+     'x': [1, 2, 3, 4, 5],
+     'y': [10, 20, 30, 40, 50]
+ })
+
+ # Initialize the agent
+ agent = PlotlyAgent()
+
+ # Set the dataframe
+ agent.set_df(df)
+
+ # Process a visualization request
+ response = agent.process_message("Create a line plot of x vs y")
+
+ # Print generated code
+ print(agent.generated_code)
+
+ # Get fig
+ fig = agent.get_figure()
+ fig.show()
+ ```
+
+ `agent.generated_code`:
+
+ ```python
+ import pandas as pd
+ import plotly.graph_objects as go
+
+ # Creating a line plot of x vs y
+ # Create a figure object
+ fig = go.Figure()
+
+ # Add a line trace to the figure
+ fig.add_trace(
+     go.Scatter(
+         x=df['x'], # The x values
+         y=df['y'], # The y values
+         mode='lines+markers', # Display both lines and markers
+         name='Line Plot', # Name of the trace
+         line=dict(color='blue', width=2) # Specify line color and width
+     )
+ )
+
+ # Adding titles and labels
+ fig.update_layout(
+     title='Line Plot of x vs y', # Plot title
+     xaxis_title='x', # x-axis label
+     yaxis_title='y', # y-axis label
+     template='plotly_white' # A clean layout
+ )
+ ```
+
+ ## Features
+
+ - AI-powered visualization generation
+ - Support for various Plotly chart types
+ - Automatic data preprocessing
+ - Interactive visualization capabilities
+ - Integration with LangChain for advanced AI capabilities
+
+ ## Requirements
+
+ - Python 3.8 or higher
+ - Dependencies are automatically installed with the package
+
+ ## License
+
+ This project is licensed under the MIT License - see the LICENSE file for details.
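The README's minimal example assumes `OPENAI_API_KEY` is already in the environment, because `PlotlyAgent()` builds a `ChatOpenAI` client in its constructor, so a missing key tends to surface as an error from inside LangChain or the OpenAI client. A small stdlib guard like the sketch below fails early with a readable message instead. The helper name `require_openai_key` is hypothetical and not part of plot-agent.

```python
import os
from typing import Mapping


def require_openai_key(env: Mapping[str, str] = os.environ) -> str:
    """Return OPENAI_API_KEY from `env`, or raise a clear error if it is missing."""
    # Hypothetical helper, not part of plot-agent; reads os.environ by default.
    key = env.get("OPENAI_API_KEY", "").strip()
    if not key:
        raise RuntimeError(
            "OPENAI_API_KEY is not set; export it before creating PlotlyAgent()."
        )
    return key


# Usage sketch: call the guard before constructing the agent.
# require_openai_key()
# agent = PlotlyAgent()
```

Passing the mapping explicitly keeps the guard testable without touching the real process environment.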
@@ -7,79 +7,19 @@ from plotly.subplots import make_subplots
  from io import StringIO
  import traceback
  import sys
- import re
  from typing import Dict, List, Optional, Any
 
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
  from langchain_core.messages import AIMessage, HumanMessage
- from langchain_core.tools import Tool
- from pydantic import BaseModel, Field
+ from langchain_core.tools import Tool, StructuredTool
  from langchain.agents import AgentExecutor, create_openai_tools_agent
  from langchain_openai import ChatOpenAI
-
-
- DEFAULT_SYSTEM_PROMPT = """
- You are an expert data visualization assistant that helps users create Plotly visualizations in Python.
- Your job is to generate Python and Plotly code based on the user's request that will create the desired visualization
- of their pandas DataFrame (df).
-
- You have access to a pandas df with the following information:
-
- df.info():
- ```plaintext
- {df_info}
- ```
-
- df.head():
- ```plaintext
- {df_head}
- ```
-
- {sql_context}
-
- NOTES:
- - You must use the execute_plotly_code(generated_code) tool to test and run your code
- - You must paste the full code, not just a reference to the code
- - You must not use fig.show() in your code as it will be executed in a headless environment
- - If you need to do any data cleaning or wrangling, do it in the code before generating the plotly code as preprocessing steps assume the data is in the pandas df object
- - Do not use fig.show() in your code as it will be executed in a headless environment.
-
- IMPORTANT CODE FORMATTING INSTRUCTIONS:
- 1. Include thorough, detailed comments in your code to explain what each section does
- 2. Use descriptive variable names
- 3. DO NOT include fig.show() in your code - the visualization will be rendered externally
- 4. Ensure your code creates a variable named 'fig' that contains the Plotly figure object
- 5. Structure your code with proper spacing for readability
-
- When a user asks for a visualization:
- 1. YOU MUST ALWAYS use the execute_plotly_code(generated_code) tool to test and run your code
- 2. If there are errors, fix the code and run it again with execute_plotly_code(generated_code)
- 3. Check that a figure object is available using does_fig_exist(). does_fig_exist() takes no arguments.
-
- IMPORTANT: The code you generate MUST be executed using the execute_plotly_code tool or no figure will be created!
- YOU MUST CALL execute_plotly_code WITH THE FULL CODE, NOT JUST A REFERENCE TO THE CODE.
-
- YOUR WORKFLOW MUST BE:
- 1. execute_plotly_code(generated_code) → 2. check that a figure object is available using does_fig_exist() → 3. (if needed) fix and execute again
-
- Always return the final working code (with all the comments) to the user along with an explanation of what the visualization shows.
- Make sure to follow best practices for data visualization, such as appropriate chart types, labels, and colors.
-
- Remember that users may want to iterate on their visualizations, so be responsive to requests for changes.
- """
-
-
- # Define input schemas for the tools
- class PlotDescriptionInput(BaseModel):
-     plot_description: str = Field(
-         ..., description="Description of the plot the user wants to create"
-     )
-
-
- class GeneratedCodeInput(BaseModel):
-     generated_code: str = Field(
-         ..., description="Python code that creates a Plotly figure"
-     )
+ from plot_agent.prompt import DEFAULT_SYSTEM_PROMPT
+ from plot_agent.models import (
+     GeneratedCodeInput,
+     DoesFigExistInput,
+     ViewGeneratedCodeInput,
+ )
 
 
  class PlotlyAgentExecutionEnvironment:
@@ -100,13 +40,6 @@ class PlotlyAgentExecutionEnvironment:
          self.error = None
          self.fig = None
 
-     def preprocess_code(self, generated_code: str) -> str:
-         """Preprocess code to remove fig.show() calls."""
-         # Remove fig.show() calls
-         generated_code = re.sub(r"fig\.show\(\s*\)", "", generated_code)
-         generated_code = re.sub(r"fig\.show\(.*\)", "", generated_code)
-         return generated_code
-
      def execute_code(self, generated_code: str) -> Dict[str, Any]:
          """
          Execute the provided code and capture the fig object if created.
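Version 0.1.1 stripped `fig.show()` calls from the generated code before running it; 0.2.0 removes that step, and its prompt still tells the model not to call `fig.show()`. The stdlib sketch below reproduces the two removed substitutions, wrapped in a hypothetical `strip_fig_show` function, to show what they did. It also exposes a side effect of the greedy `.*` in the second pattern: it deletes everything up to the last `)` on the same line, not just the call itself.

```python
import re


def strip_fig_show(generated_code: str) -> str:
    """Apply the two substitutions from the removed 0.1.1 preprocess_code step."""
    # Bare calls: `fig.show()` or `fig.show(  )`.
    generated_code = re.sub(r"fig\.show\(\s*\)", "", generated_code)
    # Calls with arguments. `.*` is greedy and stops only at the last `)` on the line.
    generated_code = re.sub(r"fig\.show\(.*\)", "", generated_code)
    return generated_code


code = "fig = build()\nfig.show()\nfig.show(renderer='png'); print('done')"
print(repr(strip_fig_show(code)))  # 'fig = build()\n\n'
```

Note that the trailing `print('done')` on the third line is removed along with the `fig.show(...)` call.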
@@ -120,23 +53,20 @@ class PlotlyAgentExecutionEnvironment:
          self.output = None
          self.error = None
 
-         # Preprocess code to remove fig.show() calls
-         processed_code = self.preprocess_code(generated_code)
-
          # Capture stdout
          old_stdout = sys.stdout
          sys.stdout = mystdout = StringIO()
 
          try:
              # Execute the code
-             exec(processed_code, globals(), self.locals_dict)
+             exec(generated_code, globals(), self.locals_dict)
 
              # Check if a fig object was created
              if "fig" in self.locals_dict:
                  self.fig = self.locals_dict["fig"]
-                 self.output = "Code executed successfully. Figure object was created."
+                 self.output = "Code executed successfully. 'fig' object was created."
              else:
-                 print(f"no fig object created: {processed_code}")
+                 print(f"no fig object created: {generated_code}")
                  self.error = "Code executed without errors, but no 'fig' object was created. Make sure your code creates a variable named 'fig'."
 
          except Exception as e:
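The execution environment runs the model's code with `exec`, temporarily points `sys.stdout` at a `StringIO`, and treats the run as successful only if a variable named `fig` appears in its locals dict. The self-contained sketch below shows the same pattern. It uses `contextlib.redirect_stdout` instead of reassigning `sys.stdout` by hand, a plain dict in place of a Plotly figure, and its own error-message format; these are choices of the sketch, and `run_generated_code` is a hypothetical name, not the package's code.

```python
import contextlib
import io
from typing import Any, Dict


def run_generated_code(generated_code: str, locals_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Run code in a namespace, capture its stdout, and look for a `fig` variable."""
    result: Dict[str, Any] = {"fig": None, "output": None, "error": None, "stdout": ""}
    buffer = io.StringIO()
    try:
        # redirect_stdout restores sys.stdout even if exec raises.
        with contextlib.redirect_stdout(buffer):
            # Separate globals and locals dicts, like exec(code, globals(), locals_dict).
            exec(generated_code, {}, locals_dict)
        if "fig" in locals_dict:
            result["fig"] = locals_dict["fig"]
            result["output"] = "Code executed successfully. 'fig' object was created."
        else:
            result["error"] = (
                "Code executed without errors, but no 'fig' object was created."
            )
    except Exception as exc:  # report any failure back to the caller
        result["error"] = f"Error: {type(exc).__name__}: {exc}"
    result["stdout"] = buffer.getvalue()
    return result


# Usage: a dict stands in for a Plotly figure so no third-party packages are needed.
ok = run_generated_code("print('building')\nfig = {'data': []}", {})
missing = run_generated_code("x = 1", {})
failed = run_generated_code("fig = undefined_name", {})
print(ok["output"], "|", ok["stdout"].strip())
print(missing["error"])
print(failed["error"])
```

The context manager is the main design difference from swapping `sys.stdout` manually: restoring the original stream does not depend on a `finally` block being written correctly.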
@@ -166,13 +96,25 @@ class PlotlyAgent:
      A class that uses an LLM to generate Plotly code based on a user's plot description.
      """
 
-     def __init__(self, model="gpt-4o-mini", system_prompt: Optional[str] = None):
+     def __init__(
+         self,
+         model="gpt-4o-mini",
+         system_prompt: Optional[str] = None,
+         verbose: bool = True,
+         max_iterations: int = 10,
+         early_stopping_method: str = "force",
+         handle_parsing_errors: bool = True,
+     ):
          """
          Initialize the PlotlyAgent.
 
          Args:
              model (str): The model to use for the LLM.
              system_prompt (Optional[str]): The system prompt to use for the LLM.
+             verbose (bool): Whether to print verbose output from the agent.
+             max_iterations (int): Maximum number of iterations for the agent to take.
+             early_stopping_method (str): Method to use for early stopping.
+             handle_parsing_errors (bool): Whether to handle parsing errors gracefully.
          """
          self.llm = ChatOpenAI(model=model)
          self.df = None
@@ -184,6 +126,10 @@ class PlotlyAgent:
          self.agent_executor = None
          self.generated_code = None
          self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
+         self.verbose = verbose
+         self.max_iterations = max_iterations
+         self.early_stopping_method = early_stopping_method
+         self.handle_parsing_errors = handle_parsing_errors
 
      def set_df(self, df: pd.DataFrame, sql_query: Optional[str] = None):
          """
@@ -257,6 +203,12 @@ class PlotlyAgent:
          else:
              return "No figure has been created yet."
 
+     def view_generated_code(self, *args, **kwargs) -> str:
+         """
+         View the generated code.
+         """
+         return self.generated_code
+
      def _initialize_agent(self):
          """Initialize the LangChain agent with the necessary tools and prompt."""
          tools = [
@@ -266,11 +218,17 @@ class PlotlyAgent:
                  description="Execute the provided Plotly code and return the result",
                  args_schema=GeneratedCodeInput,
              ),
-             Tool.from_function(
+             StructuredTool.from_function(
                  func=self.does_fig_exist,
                  name="does_fig_exist",
-                 description="Check if a figure exists and is available for display",
-                 args_schema=None,
+                 description="Check if a figure exists and is available for display. This tool takes no arguments.",
+                 args_schema=DoesFigExistInput,
+             ),
+             StructuredTool.from_function(
+                 func=self.view_generated_code,
+                 name="view_generated_code",
+                 description="View the generated code.",
+                 args_schema=ViewGeneratedCodeInput,
              ),
          ]
 
@@ -299,10 +257,10 @@ class PlotlyAgent:
          self.agent_executor = AgentExecutor(
              agent=agent,
              tools=tools,
-             verbose=True,
-             max_iterations=10,
-             early_stopping_method="force",
-             handle_parsing_errors=True,
+             verbose=self.verbose,
+             max_iterations=self.max_iterations,
+             early_stopping_method=self.early_stopping_method,
+             handle_parsing_errors=self.handle_parsing_errors,
          )
 
      def process_message(self, user_message: str) -> str:
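In 0.2.0, `does_fig_exist` moves from `Tool.from_function(..., args_schema=None)` to `StructuredTool.from_function` with an empty pydantic schema, and the new `view_generated_code` accepts and ignores `*args, **kwargs`. How LangChain invokes these tools is not shown in this diff. The sketch below therefore uses a hypothetical `dispatch` function (not a LangChain API) that decodes a model's JSON arguments and calls the tool. It only illustrates the plain-Python point behind a tolerant signature: a strictly zero-argument function raises `TypeError` when it receives a stray argument, while one declared with `*args, **kwargs` does not.

```python
import json
from typing import Any, Callable


def dispatch(tool: Callable[..., str], raw_args: str) -> str:
    """Hypothetical dispatcher: decode JSON arguments from a model and call the tool."""
    args: Any = json.loads(raw_args) if raw_args.strip() else {}
    if isinstance(args, dict):
        return tool(**args)  # keyword arguments, possibly none
    return tool(args)  # anything else becomes a single positional argument


def strict_view_code() -> str:
    """Zero-argument tool: any argument at all raises TypeError."""
    return "fig = ..."


def tolerant_view_code(*args: Any, **kwargs: Any) -> str:
    """Same signature shape as view_generated_code in 0.2.0; extra arguments are ignored."""
    return "fig = ..."


for raw in ["", "{}", '{"code": "latest"}', '"latest"']:
    try:
        strict = dispatch(strict_view_code, raw)
    except TypeError as exc:
        strict = f"TypeError ({exc})"
    tolerant = dispatch(tolerant_view_code, raw)
    print(repr(raw), "-> strict:", strict, "| tolerant:", tolerant)
```

The empty schema models in `plot_agent/models.py` document the "no arguments" contract to the model; the tolerant signature is a fallback for when the model passes something anyway.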
@@ -0,0 +1,26 @@
+ from pydantic import BaseModel, Field
+
+
+ # Define input schemas for the tools
+ class PlotDescriptionInput(BaseModel):
+     plot_description: str = Field(
+         ..., description="Description of the plot the user wants to create"
+     )
+
+
+ class GeneratedCodeInput(BaseModel):
+     generated_code: str = Field(
+         ..., description="Python code that creates a Plotly figure"
+     )
+
+
+ class DoesFigExistInput(BaseModel):
+     """Model indicating that the does_fig_exist function takes no arguments."""
+
+     pass
+
+
+ class ViewGeneratedCodeInput(BaseModel):
+     """Model indicating that the view_generated_code function takes no arguments."""
+
+     pass
@@ -0,0 +1,57 @@
+ DEFAULT_SYSTEM_PROMPT = """
+ You are an expert data visualization assistant that helps users create Plotly visualizations in Python.
+ Your job is to generate Python and Plotly code based on the user's request that will create the desired visualization
+ of their pandas DataFrame (df).
+
+ You have access to a pandas df with the following information:
+
+ df.info():
+ ```plaintext
+ {df_info}
+ ```
+
+ df.head():
+ ```plaintext
+ {df_head}
+ ```
+
+ {sql_context}
+
+ NOTES:
+ - You must use the execute_plotly_code(generated_code) tool run your code and use the does_fig_exist() tool to check that a fig object is available for display.
+ - You must paste the full code, not just a reference to the code.
+ - You must not use fig.show() in your code as it will ultimately be executed elsewhere in a headless environment.
+ - If you need to do any data cleaning or wrangling, do it in the code before generating the plotly code as preprocessing steps assume the data is in the pandas 'df' object.
+
+ TOOLS:
+ - execute_plotly_code(generated_code) to execute the generated code.
+ - does_fig_exist() to check that a fig object is available for display. This tool takes no arguments.
+ - view_generated_code() to view the generated code if need to fix it. This tool takes no arguments.
+
+ IMPORTANT CODE FORMATTING INSTRUCTIONS:
+ 1. Include thorough, detailed comments in your code to explain what each section does.
+ 2. Use descriptive variable names.
+ 3. DO NOT include fig.show() in your code - the visualization will be rendered externally.
+ 4. Ensure your code creates a variable named 'fig' that contains the Plotly figure object.
+
+ When a user asks for a visualization:
+ 1. YOU MUST ALWAYS use the execute_plotly_code(generated_code) tool to test and run your code.
+ 2. If there are errors, view the generated code using view_generated_code() and fix the code.
+ 3. Check that a figure object is available using does_fig_exist(). does_fig_exist() takes no arguments.
+ 4. If the figure object is not available, repeat the process until it is available.
+
+ IMPORTANT: The code you generate MUST be executed using the execute_plotly_code tool or no figure will be created!
+ YOU MUST CALL execute_plotly_code WITH THE FULL CODE, NOT JUST A REFERENCE TO THE CODE.
+
+ YOUR WORKFLOW MUST BE:
+ 1. execute_plotly_code(generated_code) to make sure the code is ran and a figure object is created.
+ 2. check that a figure object is available using does_fig_exist() to make sure the figure object was created.
+ 3. if there are errors, view the generated code using view_generated_code() to see what went wrong.
+ 4. fix the code and execute it again with execute_plotly_code(generated_code) to make sure the figure object is created.
+ 5. repeat until the figure object is available.
+
+ Always return the final working code (with all the comments) to the user along with an explanation of what the visualization shows.
+ Make sure to follow best practices for data visualization, such as appropriate chart types, labels, and colors.
+
+ Remember that users may want to iterate on their visualizations, so be responsive to requests for changes.
+ """
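`DEFAULT_SYSTEM_PROMPT`, now in `plot_agent/prompt.py`, contains three `{...}` placeholders (`df_info`, `df_head`, `sql_context`) that must be filled before the prompt is sent. The stdlib sketch below takes a trimmed excerpt of the template (the plaintext fence lines are omitted), uses `string.Formatter` to list the fields it expects, and fills them with stand-in strings. The values, including an empty `sql_context` when no SQL query is given, are assumptions; how `set_df` actually builds them is not part of this diff. If you build `df_info` yourself, note that pandas' `DataFrame.info()` prints rather than returns, but accepts a `buf` argument such as an `io.StringIO`.

```python
from string import Formatter
from typing import Set

# Trimmed excerpt of DEFAULT_SYSTEM_PROMPT; the plaintext fence lines are omitted.
PROMPT_EXCERPT = (
    "You have access to a pandas df with the following information:\n\n"
    "df.info():\n{df_info}\n\n"
    "df.head():\n{df_head}\n\n"
    "{sql_context}\n"
)


def template_fields(template: str) -> Set[str]:
    """Return the placeholder names a str.format-style template expects."""
    return {field for _, field, _, _ in Formatter().parse(template) if field}


print(sorted(template_fields(PROMPT_EXCERPT)))  # ['df_head', 'df_info', 'sql_context']

# Stand-in values, assumed to be derived from the DataFrame passed to set_df.
filled = PROMPT_EXCERPT.format(
    df_info="RangeIndex: 5 entries, 0 to 4 (stand-in text)",
    df_head="   x   y\n0  1  10\n1  2  20",
    sql_context="",  # assumption: empty when no SQL query is supplied
)
print(filled)
```

Listing the fields up front is a cheap way to catch a template edit that adds or renames a placeholder before it fails at formatting time.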
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: plot-agent
- Version: 0.1.1
+ Version: 0.2.0
  Summary: An AI-powered data visualization assistant using Plotly
  Author-email: andrewm4894 <andrewm4894@gmail.com>
  Project-URL: Homepage, https://github.com/andrewm4894/plot-agent
@@ -14,6 +14,9 @@ Dynamic: license-file
 
  # Plot Agent
 
+ [![Tests](https://github.com/andrewm4894/plot-agent/actions/workflows/test.yml/badge.svg)](https://github.com/andrewm4894/plot-agent/actions/workflows/test.yml)
+ [![PyPI version](https://badge.fury.io/py/plot-agent.svg)](https://badge.fury.io/py/plot-agent)
+
  An AI-powered data visualization assistant that helps users create Plotly visualizations in Python.
 
  ## Installation
@@ -26,11 +29,15 @@ pip install plot-agent
 
  ## Usage
 
- Here's a simple example of how to use Plot Agent:
+ See more examples in [/examples/](https://nbviewer.org/github/andrewm4894/plot-agent/tree/main/examples/) (via nbviewer so that can see the charts easily).
+
+ Here's a simple minimal example of how to use Plot Agent:
 
  ```python
  import pandas as pd
- from plot_agent import PlotlyAgent
+ from plot_agent.agent import PlotlyAgent
+
+ # ensure OPENAI_API_KEY is set and available for langchain
 
  # Create a sample dataframe
  df = pd.DataFrame({
@@ -46,6 +53,43 @@ agent.set_df(df)
 
  # Process a visualization request
  response = agent.process_message("Create a line plot of x vs y")
+
+ # Print generated code
+ print(agent.generated_code)
+
+ # Get fig
+ fig = agent.get_figure()
+ fig.show()
+ ```
+
+ `agent.generated_code`:
+
+ ```python
+ import pandas as pd
+ import plotly.graph_objects as go
+
+ # Creating a line plot of x vs y
+ # Create a figure object
+ fig = go.Figure()
+
+ # Add a line trace to the figure
+ fig.add_trace(
+     go.Scatter(
+         x=df['x'], # The x values
+         y=df['y'], # The y values
+         mode='lines+markers', # Display both lines and markers
+         name='Line Plot', # Name of the trace
+         line=dict(color='blue', width=2) # Specify line color and width
+     )
+ )
+
+ # Adding titles and labels
+ fig.update_layout(
+     title='Line Plot of x vs y', # Plot title
+     xaxis_title='x', # x-axis label
+     yaxis_title='y', # y-axis label
+     template='plotly_white' # A clean layout
+ )
  ```
 
  ## Features
@@ -2,8 +2,11 @@ LICENSE
  README.md
  pyproject.toml
  plot_agent/__init__.py
- plot_agent/plotly_agent.py
+ plot_agent/agent.py
+ plot_agent/models.py
+ plot_agent/prompt.py
  plot_agent.egg-info/PKG-INFO
  plot_agent.egg-info/SOURCES.txt
  plot_agent.egg-info/dependency_links.txt
- plot_agent.egg-info/top_level.txt
+ plot_agent.egg-info/top_level.txt
+ tests/test_plotly_agent.py
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
  [project]
  name = "plot-agent"
- version = "0.1.1"
+ version = "0.2.0"
  authors = [
      { name="andrewm4894", email="andrewm4894@gmail.com" },
  ]
@@ -0,0 +1,152 @@
+ import pytest
+ import pandas as pd
+ import numpy as np
+ from plot_agent.agent import PlotlyAgent
+ from langchain_core.messages import HumanMessage, AIMessage
+
+ def test_plotly_agent_initialization():
+     """Test that PlotlyAgent initializes correctly."""
+     agent = PlotlyAgent()
+     assert agent.llm is not None
+     assert agent.df is None
+     assert agent.df_info is None
+     assert agent.df_head is None
+     assert agent.sql_query is None
+     assert agent.execution_env is None
+     assert agent.chat_history == []
+     assert agent.agent_executor is None
+     assert agent.generated_code is None
+
+ def test_set_df():
+     """Test that set_df properly sets up the dataframe and environment."""
+     # Create a sample dataframe
+     df = pd.DataFrame({
+         'x': [1, 2, 3, 4, 5],
+         'y': [10, 20, 30, 40, 50]
+     })
+
+     agent = PlotlyAgent()
+     agent.set_df(df)
+
+     assert agent.df is not None
+     assert agent.df_info is not None
+     assert agent.df_head is not None
+     assert agent.execution_env is not None
+     assert agent.agent_executor is not None
+
+ def test_execute_plotly_code():
+     """Test that execute_plotly_code works with valid code."""
+     df = pd.DataFrame({
+         'x': [1, 2, 3, 4, 5],
+         'y': [10, 20, 30, 40, 50]
+     })
+
+     agent = PlotlyAgent()
+     agent.set_df(df)
+
+     # Test with valid plotly code
+     valid_code = """import plotly.express as px
+ fig = px.scatter(df, x='x', y='y')"""
+
+     result = agent.execute_plotly_code(valid_code)
+     assert "Code executed successfully" in result
+     assert agent.execution_env.fig is not None
+
+ def test_execute_plotly_code_with_error():
+     """Test that execute_plotly_code handles errors properly."""
+     df = pd.DataFrame({
+         'x': [1, 2, 3, 4, 5],
+         'y': [10, 20, 30, 40, 50]
+     })
+
+     agent = PlotlyAgent()
+     agent.set_df(df)
+
+     # Test with invalid code
+     invalid_code = """import plotly.express as px
+ fig = px.scatter(df, x='non_existent_column', y='y')"""
+
+     result = agent.execute_plotly_code(invalid_code)
+     assert "Error" in result
+     assert agent.execution_env.fig is None
+
+ def test_does_fig_exist():
+     """Test that does_fig_exist correctly reports figure existence."""
+     df = pd.DataFrame({
+         'x': [1, 2, 3, 4, 5],
+         'y': [10, 20, 30, 40, 50]
+     })
+
+     agent = PlotlyAgent()
+     agent.set_df(df)
+
+     # Initially no figure should exist
+     assert "No figure has been created yet" in agent.does_fig_exist()
+
+     # Create a figure
+     valid_code = """import plotly.express as px
+ fig = px.scatter(df, x='x', y='y')"""
+     agent.execute_plotly_code(valid_code)
+
+     # Now a figure should exist
+     assert "A figure is available for display" in agent.does_fig_exist()
+
+ def test_reset_conversation():
+     """Test that reset_conversation clears the chat history."""
+     agent = PlotlyAgent()
+     agent.chat_history = ["message1", "message2"]
+     agent.reset_conversation()
+     assert agent.chat_history == []
+
+ def test_view_generated_code():
+     """Test that view_generated_code returns the last generated code."""
+     agent = PlotlyAgent()
+     test_code = "test code"
+     agent.generated_code = test_code
+     assert agent.view_generated_code() == test_code
+
+ def test_get_figure():
+     """Test that get_figure returns the current figure if it exists."""
+     df = pd.DataFrame({
+         'x': [1, 2, 3, 4, 5],
+         'y': [10, 20, 30, 40, 50]
+     })
+
+     agent = PlotlyAgent()
+     agent.set_df(df)
+
+     # Initially no figure should exist
+     assert agent.get_figure() is None
+
+     # Create a figure
+     valid_code = """import plotly.express as px
+ fig = px.scatter(df, x='x', y='y')"""
+     agent.execute_plotly_code(valid_code)
+
+     # Now a figure should exist
+     assert agent.get_figure() is not None
+
+ def test_process_message():
+     """Test that process_message updates chat history and handles responses."""
+     df = pd.DataFrame({
+         'x': [1, 2, 3, 4, 5],
+         'y': [10, 20, 30, 40, 50]
+     })
+
+     agent = PlotlyAgent()
+     agent.set_df(df)
+
+     # Test processing a message
+     response = agent.process_message("Create a scatter plot")
+
+     # Check that chat history was updated
+     assert len(agent.chat_history) == 2 # One human message and one AI message
+     assert isinstance(agent.chat_history[0], HumanMessage)
+     assert isinstance(agent.chat_history[1], AIMessage)
+     assert agent.chat_history[0].content == "Create a scatter plot"
+
+ def test_execute_plotly_code_without_df():
+     """Test that execute_plotly_code handles the case when no dataframe is set."""
+     agent = PlotlyAgent()
+     result = agent.execute_plotly_code("some code")
+     assert "Error" in result and "No dataframe has been set" in result
@@ -1,52 +0,0 @@
- # Plot Agent
-
- An AI-powered data visualization assistant that helps users create Plotly visualizations in Python.
-
- ## Installation
-
- You can install the package using pip:
-
- ```bash
- pip install plot-agent
- ```
-
- ## Usage
-
- Here's a simple example of how to use Plot Agent:
-
- ```python
- import pandas as pd
- from plot_agent import PlotlyAgent
-
- # Create a sample dataframe
- df = pd.DataFrame({
-     'x': [1, 2, 3, 4, 5],
-     'y': [10, 20, 30, 40, 50]
- })
-
- # Initialize the agent
- agent = PlotlyAgent()
-
- # Set the dataframe
- agent.set_df(df)
-
- # Process a visualization request
- response = agent.process_message("Create a line plot of x vs y")
- ```
-
- ## Features
-
- - AI-powered visualization generation
- - Support for various Plotly chart types
- - Automatic data preprocessing
- - Interactive visualization capabilities
- - Integration with LangChain for advanced AI capabilities
-
- ## Requirements
-
- - Python 3.8 or higher
- - Dependencies are automatically installed with the package
-
- ## License
-
- This project is licensed under the MIT License - see the LICENSE file for details.