plot-agent 0.2.2__tar.gz → 0.3.0__tar.gz

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between the package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plot-agent
3
- Version: 0.2.2
3
+ Version: 0.3.0
4
4
  Summary: An AI-powered data visualization assistant using Plotly
5
5
  Author-email: andrewm4894 <andrewm4894@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/andrewm4894/plot-agent
@@ -35,7 +35,7 @@ Here's a simple minimal example of how to use Plot Agent:
35
35
 
36
36
  ```python
37
37
  import pandas as pd
38
- from plot_agent.agent import PlotlyAgent
38
+ from plot_agent.agent import PlotAgent
39
39
 
40
40
  # ensure OPENAI_API_KEY is set and available for langchain
41
41
 
@@ -46,7 +46,7 @@ df = pd.DataFrame({
46
46
  })
47
47
 
48
48
  # Initialize the agent
49
- agent = PlotlyAgent()
49
+ agent = PlotAgent()
50
50
 
51
51
  # Set the dataframe
52
52
  agent.set_df(df)
@@ -21,7 +21,7 @@ Here's a simple minimal example of how to use Plot Agent:
21
21
 
22
22
  ```python
23
23
  import pandas as pd
24
- from plot_agent.agent import PlotlyAgent
24
+ from plot_agent.agent import PlotAgent
25
25
 
26
26
  # ensure OPENAI_API_KEY is set and available for langchain
27
27
 
@@ -32,7 +32,7 @@ df = pd.DataFrame({
32
32
  })
33
33
 
34
34
  # Initialize the agent
35
- agent = PlotlyAgent()
35
+ agent = PlotAgent()
36
36
 
37
37
  # Set the dataframe
38
38
  agent.set_df(df)
@@ -14,10 +14,10 @@ from plot_agent.models import (
14
14
  DoesFigExistInput,
15
15
  ViewGeneratedCodeInput,
16
16
  )
17
- from plot_agent.execution import PlotlyAgentExecutionEnvironment
17
+ from plot_agent.execution import PlotAgentExecutionEnvironment
18
18
 
19
19
 
20
- class PlotlyAgent:
20
+ class PlotAgent:
21
21
  """
22
22
  A class that uses an LLM to generate Plotly code based on a user's plot description.
23
23
  """
@@ -32,7 +32,7 @@ class PlotlyAgent:
32
32
  handle_parsing_errors: bool = True,
33
33
  ):
34
34
  """
35
- Initialize the PlotlyAgent.
35
+ Initialize the PlotAgent.
36
36
 
37
37
  Args:
38
38
  model (str): The model to use for the LLM.
@@ -90,7 +90,7 @@ class PlotlyAgent:
90
90
  self.sql_query = sql_query
91
91
 
92
92
  # Initialize execution environment
93
- self.execution_env = PlotlyAgentExecutionEnvironment(df)
93
+ self.execution_env = PlotAgentExecutionEnvironment(df)
94
94
 
95
95
  # Initialize the agent with tools
96
96
  self._initialize_agent()
@@ -10,7 +10,7 @@ from plotly.subplots import make_subplots
10
10
  from typing import Dict, Any
11
11
 
12
12
 
13
- class PlotlyAgentExecutionEnvironment:
13
+ class PlotAgentExecutionEnvironment:
14
14
  """
15
15
  Environment to safely execute plotly code and capture the fig object.
16
16
 
@@ -1,14 +1,17 @@
1
1
  from pydantic import BaseModel, Field
2
2
 
3
3
 
4
- # Define input schemas for the tools
5
4
  class PlotDescriptionInput(BaseModel):
5
+ """Model indicating that the plot_description function takes a plot_description argument."""
6
+
6
7
  plot_description: str = Field(
7
8
  ..., description="Description of the plot the user wants to create"
8
9
  )
9
10
 
10
11
 
11
12
  class GeneratedCodeInput(BaseModel):
13
+ """Model indicating that the generated_code function takes a generated_code argument."""
14
+
12
15
  generated_code: str = Field(
13
16
  ..., description="Python code that creates a Plotly figure"
14
17
  )
@@ -23,4 +26,4 @@ class DoesFigExistInput(BaseModel):
23
26
  class ViewGeneratedCodeInput(BaseModel):
24
27
  """Model indicating that the view_generated_code function takes no arguments."""
25
28
 
26
- pass
29
+ pass
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plot-agent
3
- Version: 0.2.2
3
+ Version: 0.3.0
4
4
  Summary: An AI-powered data visualization assistant using Plotly
5
5
  Author-email: andrewm4894 <andrewm4894@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/andrewm4894/plot-agent
@@ -35,7 +35,7 @@ Here's a simple minimal example of how to use Plot Agent:
35
35
 
36
36
  ```python
37
37
  import pandas as pd
38
- from plot_agent.agent import PlotlyAgent
38
+ from plot_agent.agent import PlotAgent
39
39
 
40
40
  # ensure OPENAI_API_KEY is set and available for langchain
41
41
 
@@ -46,7 +46,7 @@ df = pd.DataFrame({
46
46
  })
47
47
 
48
48
  # Initialize the agent
49
- agent = PlotlyAgent()
49
+ agent = PlotAgent()
50
50
 
51
51
  # Set the dataframe
52
52
  agent.set_df(df)
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "plot-agent"
7
- version = "0.2.2"
7
+ version = "0.3.0"
8
8
  authors = [
9
9
  { name="andrewm4894", email="andrewm4894@gmail.com" },
10
10
  ]
@@ -1,13 +1,13 @@
1
1
  import pytest
2
2
  import pandas as pd
3
3
  import numpy as np
4
- from plot_agent.agent import PlotlyAgent
4
+ from plot_agent.agent import PlotAgent
5
5
  from langchain_core.messages import HumanMessage, AIMessage
6
6
 
7
7
 
8
8
  def test_plotly_agent_initialization():
9
- """Test that PlotlyAgent initializes correctly."""
10
- agent = PlotlyAgent()
9
+ """Test that PlotAgent initializes correctly."""
10
+ agent = PlotAgent()
11
11
  assert agent.llm is not None
12
12
  assert agent.df is None
13
13
  assert agent.df_info is None
@@ -24,7 +24,7 @@ def test_set_df():
24
24
  # Create a sample dataframe
25
25
  df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
26
26
 
27
- agent = PlotlyAgent()
27
+ agent = PlotAgent()
28
28
  agent.set_df(df)
29
29
 
30
30
  assert agent.df is not None
@@ -38,7 +38,7 @@ def test_execute_plotly_code():
38
38
  """Test that execute_plotly_code works with valid code."""
39
39
  df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
40
40
 
41
- agent = PlotlyAgent()
41
+ agent = PlotAgent()
42
42
  agent.set_df(df)
43
43
 
44
44
  # Test with valid plotly code
@@ -54,7 +54,7 @@ def test_execute_plotly_code_with_error():
54
54
  """Test that execute_plotly_code handles errors properly."""
55
55
  df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
56
56
 
57
- agent = PlotlyAgent()
57
+ agent = PlotAgent()
58
58
  agent.set_df(df)
59
59
 
60
60
  # Test with invalid code
@@ -70,7 +70,7 @@ def test_does_fig_exist():
70
70
  """Test that does_fig_exist correctly reports figure existence."""
71
71
  df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
72
72
 
73
- agent = PlotlyAgent()
73
+ agent = PlotAgent()
74
74
  agent.set_df(df)
75
75
 
76
76
  # Initially no figure should exist
@@ -87,7 +87,7 @@ fig = px.scatter(df, x='x', y='y')"""
87
87
 
88
88
  def test_reset_conversation():
89
89
  """Test that reset_conversation clears the chat history."""
90
- agent = PlotlyAgent()
90
+ agent = PlotAgent()
91
91
  agent.chat_history = ["message1", "message2"]
92
92
  agent.reset_conversation()
93
93
  assert agent.chat_history == []
@@ -95,7 +95,7 @@ def test_reset_conversation():
95
95
 
96
96
  def test_view_generated_code():
97
97
  """Test that view_generated_code returns the last generated code."""
98
- agent = PlotlyAgent()
98
+ agent = PlotAgent()
99
99
  test_code = "test code"
100
100
  agent.generated_code = test_code
101
101
  assert agent.view_generated_code() == test_code
@@ -105,7 +105,7 @@ def test_get_figure():
105
105
  """Test that get_figure returns the current figure if it exists."""
106
106
  df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
107
107
 
108
- agent = PlotlyAgent()
108
+ agent = PlotAgent()
109
109
  agent.set_df(df)
110
110
 
111
111
  # Initially no figure should exist
@@ -124,7 +124,7 @@ def test_process_message():
124
124
  """Test that process_message updates chat history and handles responses."""
125
125
  df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
126
126
 
127
- agent = PlotlyAgent()
127
+ agent = PlotAgent()
128
128
  agent.set_df(df, sql_query="SELECT x, y FROM df")
129
129
 
130
130
  # Test processing a message
@@ -139,7 +139,7 @@ def test_process_message():
139
139
 
140
140
  def test_execute_plotly_code_without_df():
141
141
  """Test that execute_plotly_code handles the case when no dataframe is set."""
142
- agent = PlotlyAgent()
142
+ agent = PlotAgent()
143
143
  result = agent.execute_plotly_code("some code")
144
144
  assert "Error" in result and "No dataframe has been set" in result
145
145
 
@@ -149,7 +149,7 @@ def test_set_df_with_sql_query():
149
149
  df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
150
150
 
151
151
  sql_query = "SELECT x, y FROM table"
152
- agent = PlotlyAgent()
152
+ agent = PlotAgent()
153
153
  agent.set_df(df, sql_query=sql_query)
154
154
 
155
155
  assert agent.sql_query == sql_query
@@ -158,32 +158,32 @@ def test_set_df_with_sql_query():
158
158
  def test_agent_initialization_with_custom_prompt():
159
159
  """Test agent initialization with custom system prompt."""
160
160
  custom_prompt = "Custom system prompt for testing"
161
- agent = PlotlyAgent(system_prompt=custom_prompt)
161
+ agent = PlotAgent(system_prompt=custom_prompt)
162
162
  assert agent.system_prompt == custom_prompt
163
163
 
164
164
 
165
165
  def test_agent_initialization_with_different_model():
166
166
  """Test agent initialization with different model names."""
167
- agent = PlotlyAgent(model="gpt-3.5-turbo")
167
+ agent = PlotAgent(model="gpt-3.5-turbo")
168
168
  assert agent.llm.model_name == "gpt-3.5-turbo"
169
169
 
170
170
 
171
171
  def test_agent_initialization_with_verbose():
172
172
  """Test agent initialization with verbose settings."""
173
- agent = PlotlyAgent(verbose=False)
173
+ agent = PlotAgent(verbose=False)
174
174
  assert agent.verbose == False
175
175
  assert agent.agent_executor is None # Agent executor not initialized yet
176
176
 
177
177
 
178
178
  def test_agent_initialization_with_max_iterations():
179
179
  """Test agent initialization with different max iterations."""
180
- agent = PlotlyAgent(max_iterations=5)
180
+ agent = PlotAgent(max_iterations=5)
181
181
  assert agent.max_iterations == 5
182
182
 
183
183
 
184
184
  def test_agent_initialization_with_early_stopping():
185
185
  """Test agent initialization with different early stopping methods."""
186
- agent = PlotlyAgent(early_stopping_method="generate")
186
+ agent = PlotAgent(early_stopping_method="generate")
187
187
  assert agent.early_stopping_method == "generate"
188
188
 
189
189
 
@@ -191,7 +191,7 @@ def test_process_empty_message():
191
191
  """Test processing of empty messages."""
192
192
  df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
193
193
 
194
- agent = PlotlyAgent()
194
+ agent = PlotAgent()
195
195
  agent.set_df(df)
196
196
 
197
197
  response = agent.process_message("")
@@ -204,7 +204,7 @@ def test_process_message_with_code_blocks():
204
204
  """Test processing messages that contain code blocks."""
205
205
  df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
206
206
 
207
- agent = PlotlyAgent()
207
+ agent = PlotAgent()
208
208
  agent.set_df(df)
209
209
 
210
210
  message = "Here's some code:\n```python\nprint('test')\n```"
@@ -223,7 +223,7 @@ def test_execution_environment_with_different_plot_types():
223
223
  }
224
224
  )
225
225
 
226
- agent = PlotlyAgent()
226
+ agent = PlotAgent()
227
227
  agent.set_df(df)
228
228
 
229
229
  # Test scatter plot
@@ -254,7 +254,7 @@ def test_execution_environment_with_subplots():
254
254
  {"x": [1, 2, 3, 4, 5], "y1": [10, 20, 30, 40, 50], "y2": [50, 40, 30, 20, 10]}
255
255
  )
256
256
 
257
- agent = PlotlyAgent()
257
+ agent = PlotAgent()
258
258
  agent.set_df(df)
259
259
 
260
260
  subplot_code = """import plotly.subplots as sp
@@ -272,7 +272,7 @@ def test_execution_environment_with_data_preprocessing():
272
272
  """Test execution environment with data preprocessing steps."""
273
273
  df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
274
274
 
275
- agent = PlotlyAgent()
275
+ agent = PlotAgent()
276
276
  agent.set_df(df)
277
277
 
278
278
  preprocessing_code = """import plotly.express as px
@@ -289,7 +289,7 @@ def test_handle_syntax_error():
289
289
  """Test handling of syntax errors in generated code."""
290
290
  df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
291
291
 
292
- agent = PlotlyAgent()
292
+ agent = PlotAgent()
293
293
  agent.set_df(df)
294
294
 
295
295
  invalid_code = """import plotly.express as px
@@ -305,7 +305,7 @@ def test_handle_runtime_error():
305
305
  """Test handling of runtime errors in generated code."""
306
306
  df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
307
307
 
308
- agent = PlotlyAgent()
308
+ agent = PlotAgent()
309
309
  agent.set_df(df)
310
310
 
311
311
  error_code = """import plotly.express as px
@@ -320,7 +320,7 @@ def test_tool_interaction():
320
320
  """Test interaction between different tools."""
321
321
  df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
322
322
 
323
- agent = PlotlyAgent()
323
+ agent = PlotAgent()
324
324
  agent.set_df(df)
325
325
 
326
326
  # First check if figure exists (should not)
@@ -343,7 +343,7 @@ def test_tool_validation():
343
343
  """Test validation of tool inputs."""
344
344
  df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
345
345
 
346
- agent = PlotlyAgent()
346
+ agent = PlotAgent()
347
347
  agent.set_df(df)
348
348
 
349
349
  # Test with invalid code (empty string)
@@ -359,7 +359,7 @@ def test_tool_response_formatting():
359
359
  """Test formatting of tool responses."""
360
360
  df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
361
361
 
362
- agent = PlotlyAgent()
362
+ agent = PlotAgent()
363
363
  agent.set_df(df)
364
364
 
365
365
  # Test execute_plotly_code response format
@@ -384,7 +384,7 @@ def test_memory_cleanup():
384
384
  """Test memory cleanup after multiple plot generations."""
385
385
  df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
386
386
 
387
- agent = PlotlyAgent()
387
+ agent = PlotAgent()
388
388
  agent.set_df(df)
389
389
 
390
390
  # Generate multiple plots
@@ -406,7 +406,7 @@ def test_large_dataframe_handling():
406
406
  # Create a large dataframe
407
407
  df = pd.DataFrame({"x": range(10000), "y": range(10000)})
408
408
 
409
- agent = PlotlyAgent()
409
+ agent = PlotAgent()
410
410
  agent.set_df(df)
411
411
 
412
412
  # Test plot generation with large dataframe
@@ -421,12 +421,12 @@ def test_input_validation():
421
421
  """Test validation of input parameters."""
422
422
  # Test invalid dataframe input
423
423
  with pytest.raises(AssertionError):
424
- agent = PlotlyAgent()
424
+ agent = PlotAgent()
425
425
  agent.set_df("not a dataframe")
426
426
 
427
427
  # Test invalid SQL query input
428
428
  df = pd.DataFrame({"x": [1, 2, 3]})
429
- agent = PlotlyAgent()
429
+ agent = PlotAgent()
430
430
  with pytest.raises(AssertionError):
431
431
  agent.set_df(df, sql_query=123) # SQL query should be string
432
432
 
@@ -451,7 +451,7 @@ def test_complex_plot_handling():
451
451
  }
452
452
  )
453
453
 
454
- agent = PlotlyAgent()
454
+ agent = PlotAgent()
455
455
  agent.set_df(df)
456
456
 
457
457
  complex_code = """import plotly.graph_objects as go
File without changes
File without changes