plot-agent 0.3.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
plot_agent/__init__.py CHANGED
@@ -0,0 +1,5 @@
1
+ """plot-agent: LLM-powered Plotly visualization agent."""
2
+
3
+ from plot_agent.agent import PlotAgent
4
+
5
+ __all__ = ["PlotAgent"]
plot_agent/agent.py CHANGED
@@ -4,12 +4,15 @@ This module contains the PlotAgent class, which is used to generate Plotly code
4
4
 
5
5
  import pandas as pd
6
6
  from io import StringIO
7
- from typing import Optional
7
+ import os
8
+ import re
9
+ import logging
10
+ from typing import List, Optional, Union
11
+ from dotenv import load_dotenv
8
12
 
9
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
10
- from langchain_core.messages import AIMessage, HumanMessage
13
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
11
14
  from langchain_core.tools import Tool, StructuredTool
12
- from langchain.agents import AgentExecutor, create_openai_tools_agent
15
+ from langgraph.prebuilt import create_react_agent
13
16
  from langchain_openai import ChatOpenAI
14
17
 
15
18
  from plot_agent.prompt import DEFAULT_SYSTEM_PROMPT
@@ -17,9 +20,18 @@ from plot_agent.models import (
17
20
  GeneratedCodeInput,
18
21
  DoesFigExistInput,
19
22
  ViewGeneratedCodeInput,
23
+ CheckPlotOutputsInput,
20
24
  )
21
25
  from plot_agent.execution import PlotAgentExecutionEnvironment
22
26
 
27
+ # Optional PostHog integration
28
+ try:
29
+ from posthog import Posthog
30
+ from posthog.ai.langchain import CallbackHandler as PostHogCallbackHandler
31
+ POSTHOG_AVAILABLE = True
32
+ except ImportError:
33
+ POSTHOG_AVAILABLE = False
34
+
23
35
 
24
36
  class PlotAgent:
25
37
  """
@@ -28,12 +40,17 @@ class PlotAgent:
28
40
 
29
41
  def __init__(
30
42
  self,
31
- model="gpt-4o-mini",
43
+ model: str = "gpt-4o-mini",
32
44
  system_prompt: Optional[str] = None,
33
45
  verbose: bool = True,
34
46
  max_iterations: int = 10,
35
47
  early_stopping_method: str = "force",
36
48
  handle_parsing_errors: bool = True,
49
+ llm_temperature: float = 0.0,
50
+ llm_timeout: int = 60,
51
+ llm_max_retries: int = 1,
52
+ debug: bool = False,
53
+ include_plot_image: bool = False,
37
54
  ):
38
55
  """
39
56
  Initialize the PlotAgent.
@@ -45,14 +62,116 @@ class PlotAgent:
45
62
  max_iterations (int): Maximum number of iterations for the agent to take.
46
63
  early_stopping_method (str): Method to use for early stopping.
47
64
  handle_parsing_errors (bool): Whether to handle parsing errors gracefully.
65
+ llm_temperature (float): Temperature for LLM sampling.
66
+ llm_timeout (int): Timeout in seconds for LLM calls.
67
+ llm_max_retries (int): Maximum retries for LLM calls.
68
+ debug (bool): Enable debug logging.
69
+ include_plot_image (bool): Generate PNG image of plots for PostHog analytics.
48
70
  """
49
- self.llm = ChatOpenAI(model=model)
71
+ # Load .env if present, then require a valid API key
72
+ load_dotenv()
73
+ openai_api_key = os.getenv("OPENAI_API_KEY")
74
+ if not openai_api_key:
75
+ raise RuntimeError(
76
+ "OPENAI_API_KEY is not set. Provide it via environment or a .env file."
77
+ )
78
+ self.debug = debug or os.getenv("PLOT_AGENT_DEBUG") == "1"
79
+
80
+ # Configure logger
81
+ self._logger = logging.getLogger("plot_agent")
82
+ if self.debug:
83
+ self._logger.setLevel(logging.DEBUG)
84
+ if not self._logger.handlers:
85
+ handler = logging.StreamHandler()
86
+ handler.setFormatter(
87
+ logging.Formatter(
88
+ "%(asctime)s [%(levelname)s] %(name)s: %(message)s",
89
+ datefmt="%H:%M:%S",
90
+ )
91
+ )
92
+ self._logger.addHandler(handler)
93
+
94
+ # Initialize PostHog for LLM analytics (optional)
95
+ self.posthog_client = None
96
+ self.posthog_callback_handler = None
97
+ posthog_enabled = os.getenv("POSTHOG_ENABLED", "false").lower() == "true"
98
+
99
+ # Enable PostHog multimodal capture if include_plot_image is True
100
+ if include_plot_image and posthog_enabled:
101
+ os.environ["_INTERNAL_LLMA_MULTIMODAL"] = "true"
102
+
103
+ if posthog_enabled:
104
+ if not POSTHOG_AVAILABLE:
105
+ self._logger.warning(
106
+ "PostHog is enabled but the posthog package is not installed. "
107
+ "Install it with: pip install posthog"
108
+ )
109
+ else:
110
+ posthog_api_key = os.getenv("POSTHOG_API_KEY")
111
+ posthog_host = os.getenv("POSTHOG_HOST", "https://app.posthog.com")
112
+
113
+ if not posthog_api_key:
114
+ self._logger.warning(
115
+ "POSTHOG_ENABLED is true but POSTHOG_API_KEY is not set. "
116
+ "PostHog tracking will be disabled."
117
+ )
118
+ else:
119
+ try:
120
+ # Build super_properties for session tracking
121
+ super_properties = {}
122
+
123
+ # Add session ID from environment if provided
124
+ ai_session_id = os.getenv("POSTHOG_AI_SESSION_ID")
125
+ if ai_session_id:
126
+ super_properties["$ai_session_id"] = ai_session_id
127
+
128
+ # Initialize PostHog client with super_properties
129
+ self.posthog_client = Posthog(
130
+ posthog_api_key,
131
+ host=posthog_host,
132
+ super_properties=super_properties
133
+ )
134
+
135
+ # Build callback handler config
136
+ callback_config = {"client": self.posthog_client}
137
+
138
+ # Add optional distinct_id
139
+ distinct_id = os.getenv("POSTHOG_DISTINCT_ID")
140
+ if distinct_id:
141
+ callback_config["distinct_id"] = distinct_id
142
+
143
+ # Add privacy mode setting
144
+ privacy_mode = os.getenv("POSTHOG_PRIVACY_MODE", "false").lower() == "true"
145
+ callback_config["privacy_mode"] = privacy_mode
146
+
147
+ self.posthog_callback_handler = PostHogCallbackHandler(**callback_config)
148
+
149
+ if self.debug:
150
+ session_info = f"session_id={ai_session_id}" if ai_session_id else "no session"
151
+ self._logger.debug(
152
+ f"PostHog LLM analytics initialized (host={posthog_host}, "
153
+ f"distinct_id={distinct_id or 'anonymous'}, "
154
+ f"privacy_mode={privacy_mode}, {session_info})"
155
+ )
156
+ except Exception as e:
157
+ self._logger.error(f"Failed to initialize PostHog: {e}")
158
+ self.posthog_client = None
159
+ self.posthog_callback_handler = None
160
+
161
+ self.llm = ChatOpenAI(
162
+ model=model,
163
+ temperature=llm_temperature,
164
+ timeout=llm_timeout,
165
+ max_retries=llm_max_retries,
166
+ )
50
167
  self.df = None
51
168
  self.df_info = None
52
169
  self.df_head = None
53
170
  self.sql_query = None
54
171
  self.execution_env = None
55
172
  self.chat_history = []
173
+ # Internal graph-native message history, including tool messages
174
+ self._graph_messages = []
56
175
  self.agent_executor = None
57
176
  self.generated_code = None
58
177
  self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
@@ -60,6 +179,7 @@ class PlotAgent:
60
179
  self.max_iterations = max_iterations
61
180
  self.early_stopping_method = early_stopping_method
62
181
  self.handle_parsing_errors = handle_parsing_errors
182
+ self.include_plot_image = include_plot_image
63
183
 
64
184
  def set_df(self, df: pd.DataFrame, sql_query: Optional[str] = None):
65
185
  """
@@ -94,10 +214,16 @@ class PlotAgent:
94
214
  self.sql_query = sql_query
95
215
 
96
216
  # Initialize execution environment
97
- self.execution_env = PlotAgentExecutionEnvironment(df)
217
+ self.execution_env = PlotAgentExecutionEnvironment(
218
+ df, include_plot_image=self.include_plot_image
219
+ )
98
220
 
99
221
  # Initialize the agent with tools
100
222
  self._initialize_agent()
223
+ # Reset graph messages for a fresh session with this dataframe
224
+ self._graph_messages = []
225
+ if self.debug:
226
+ self._logger.debug("set_df() initialized execution environment and graph")
101
227
 
102
228
  def execute_plotly_code(self, generated_code: str) -> str:
103
229
  """
@@ -150,14 +276,52 @@ class PlotAgent:
150
276
  else:
151
277
  return "No figure has been created yet."
152
278
 
279
+ def check_plot_outputs(self, *args, **kwargs) -> str:
280
+ """
281
+ Check if all required plot outputs (fig, plot_title, plot_summary) are available.
282
+
283
+ Args:
284
+ *args: Any positional arguments (ignored)
285
+ **kwargs: Any keyword arguments (ignored)
286
+
287
+ Returns:
288
+ str: A message indicating which plot outputs are available.
289
+ """
290
+ if not self.execution_env:
291
+ return "No execution environment has been initialized. Please set a dataframe first."
292
+
293
+ available = []
294
+ missing = []
295
+
296
+ if self.execution_env.fig is not None:
297
+ available.append("fig")
298
+ else:
299
+ missing.append("fig")
300
+
301
+ if self.execution_env.plot_title is not None:
302
+ available.append("plot_title")
303
+ else:
304
+ missing.append("plot_title")
305
+
306
+ if self.execution_env.plot_summary is not None:
307
+ available.append("plot_summary")
308
+ else:
309
+ missing.append("plot_summary")
310
+
311
+ if not missing:
312
+ return "All required plot outputs are available: fig, plot_title, and plot_summary."
313
+ else:
314
+ status = f"Available: {', '.join(available) if available else 'none'}. Missing: {', '.join(missing)}."
315
+ return status
316
+
153
317
  def view_generated_code(self, *args, **kwargs) -> str:
154
318
  """
155
319
  View the generated code.
156
320
  """
157
- return self.generated_code
321
+ return self.generated_code or ""
158
322
 
159
323
  def _initialize_agent(self):
160
- """Initialize the LangChain agent with the necessary tools and prompt."""
324
+ """Initialize a LangGraph ReAct agent with tools and keep API compatibility."""
161
325
 
162
326
  # Initialize the tools
163
327
  tools = [
@@ -189,39 +353,43 @@ class PlotAgent:
189
353
  ),
190
354
  args_schema=ViewGeneratedCodeInput,
191
355
  ),
356
+ StructuredTool.from_function(
357
+ func=self.check_plot_outputs,
358
+ name="check_plot_outputs",
359
+ description=(
360
+ "Check if all required plot outputs (fig, plot_title, plot_summary) are available. "
361
+ "This tool takes no arguments and returns the status of all plot outputs."
362
+ ),
363
+ args_schema=CheckPlotOutputsInput,
364
+ ),
192
365
  ]
193
366
 
194
- # Create system prompt with dataframe information
367
+ # Prepare system prompt with dataframe information
195
368
  sql_context = ""
196
369
  if self.sql_query:
197
- sql_context = f"In case it is useful to help with the data understanding, the df was generated using the following SQL query:\n```sql\n{self.sql_query}\n```"
198
-
199
- prompt = ChatPromptTemplate.from_messages(
200
- [
201
- (
202
- "system",
203
- self.system_prompt.format(
204
- df_info=self.df_info,
205
- df_head=self.df_head,
206
- sql_context=sql_context,
207
- ),
208
- ),
209
- MessagesPlaceholder(variable_name="chat_history"),
210
- ("human", "{input}"),
211
- MessagesPlaceholder(variable_name="agent_scratchpad"),
212
- ]
370
+ sql_context = (
371
+ "In case it is useful to help with the data understanding, the df was generated using the following SQL query:\n"
372
+ f"```sql\n{self.sql_query}\n```"
373
+ )
374
+
375
+ # Store formatted system instructions for the graph state modifier
376
+ self._system_message_content = self.system_prompt.format(
377
+ df_info=self.df_info,
378
+ df_head=self.df_head,
379
+ sql_context=sql_context,
213
380
  )
214
381
 
215
- agent = create_openai_tools_agent(self.llm, tools, prompt)
216
- self.agent_executor = AgentExecutor(
217
- agent=agent,
218
- tools=tools,
219
- verbose=self.verbose,
220
- max_iterations=self.max_iterations,
221
- early_stopping_method=self.early_stopping_method,
222
- handle_parsing_errors=self.handle_parsing_errors,
382
+ # Create a ReAct agent graph with the provided tools and system prompt
383
+ self._graph = create_react_agent(
384
+ self.llm,
385
+ tools,
386
+ prompt=self._system_message_content,
387
+ debug=self.debug,
223
388
  )
224
389
 
390
+ # Backwards-compatibility: expose under the old attribute name
391
+ self.agent_executor = self._graph
392
+
225
393
  def process_message(self, user_message: str) -> str:
226
394
  """Process a user message and return the agent's response."""
227
395
  assert isinstance(user_message, str), "The user message must be a string."
@@ -229,33 +397,156 @@ class PlotAgent:
229
397
  if not self.agent_executor:
230
398
  return "Please set a dataframe first using set_df() method."
231
399
 
232
- # Add user message to chat history
400
+ # Add user message to outward-facing chat history
233
401
  self.chat_history.append(HumanMessage(content=user_message))
234
402
 
235
403
  # Reset generated_code
236
404
  self.generated_code = None
237
405
 
238
- # Get response from agent
239
- response = self.agent_executor.invoke(
240
- {"input": user_message, "chat_history": self.chat_history}
406
+ # Short-circuit empty inputs to avoid graph recursion
407
+ if user_message.strip() == "":
408
+ ai_content = (
409
+ "Please provide a non-empty plotting request (e.g., 'scatter x vs y')."
410
+ )
411
+ self.chat_history.append(AIMessage(content=ai_content))
412
+ if self.debug:
413
+ self._logger.debug("empty message received; returning guidance without invoking graph")
414
+ return ai_content
415
+
416
+ # Short-circuit messages that are primarily raw code blocks without a visualization request
417
+ if "```" in user_message and not re.search(
418
+ r"\b(plot|chart|graph|visuali(s|z)e|figure|subplot|heatmap|bar|line|scatter)\b",
419
+ user_message,
420
+ flags=re.IGNORECASE,
421
+ ):
422
+ ai_content = (
423
+ "I see a code snippet. Please describe the visualization you want (e.g., 'line chart of y over x')."
424
+ )
425
+ self.chat_history.append(AIMessage(content=ai_content))
426
+ if self.debug:
427
+ self._logger.debug("code-only message received; returning guidance without invoking graph")
428
+ return ai_content
429
+
430
+ # Build graph messages (includes tool call/observation history)
431
+ graph_messages = [*self._graph_messages, HumanMessage(content=user_message)]
432
+ if self.debug:
433
+ self._logger.debug(f"process_message() user: {user_message}")
434
+ self._logger.debug(f"graph message count before invoke: {len(graph_messages)}")
435
+
436
+ # Build config with optional PostHog callback
437
+ invoke_config = {"recursion_limit": self.max_iterations}
438
+ if self.posthog_callback_handler:
439
+ invoke_config["callbacks"] = [self.posthog_callback_handler]
440
+
441
+ # Invoke the LangGraph agent
442
+ result = self.agent_executor.invoke(
443
+ {"messages": graph_messages},
444
+ config=invoke_config,
241
445
  )
242
446
 
243
- # Add agent response to chat history
244
- self.chat_history.append(AIMessage(content=response["output"]))
245
-
246
- # If the agent didn't execute the code, but did generate code, execute it directly
247
- if self.execution_env.fig is None and self.generated_code is not None:
248
- self.execution_env.execute_code(self.generated_code)
249
-
250
- # If we can extract code from the response when no code was executed, try that too
251
- if self.execution_env.fig is None and "```python" in response["output"]:
252
- code_blocks = response["output"].split("```python")
253
- if len(code_blocks) > 1:
254
- generated_code = code_blocks[1].split("```")[0].strip()
255
- self.execution_env.execute_code(generated_code)
256
-
257
- # Return the agent's response
258
- return response["output"]
447
+ # Extract the latest AI message from the returned messages
448
+ ai_messages = [m for m in result.get("messages", []) if isinstance(m, AIMessage)]
449
+ ai_content = ai_messages[-1].content if ai_messages else ""
450
+
451
+ # Persist full graph messages for future context
452
+ self._graph_messages = result.get("messages", [])
453
+ if self.debug:
454
+ self._logger.debug(f"graph message count after invoke: {len(self._graph_messages)}")
455
+
456
+ # Add agent response to outward-facing chat history
457
+ self.chat_history.append(AIMessage(content=ai_content))
458
+
459
+ # If the agent didn't execute the code via tool, but we have prior generated_code, execute it
460
+ if self.execution_env and self.execution_env.fig is None and self.generated_code is not None:
461
+ if self.debug:
462
+ self._logger.debug("executing stored generated_code because no fig exists yet")
463
+ exec_result = self.execution_env.execute_code(self.generated_code)
464
+ if self.debug:
465
+ self._logger.debug(f"execution result success={exec_result.get('success')} error={exec_result.get('error')!r}")
466
+
467
+ # If the assistant returned code in the message, execute it to update the figure
468
+ code_executed = False
469
+ if self.execution_env and isinstance(ai_content, str):
470
+ extracted_code = None
471
+ if "```python" in ai_content:
472
+ parts = ai_content.split("```python", 1)
473
+ extracted_code = parts[1].split("```", 1)[0].strip() if len(parts) > 1 else None
474
+ elif "```" in ai_content:
475
+ # Fallback: extract first generic fenced code block
476
+ parts = ai_content.split("```", 1)
477
+ if len(parts) > 1:
478
+ extracted_code = parts[1].split("```", 1)[0].strip()
479
+ if extracted_code:
480
+ if (self.generated_code or "").strip() != extracted_code:
481
+ self.generated_code = extracted_code
482
+ if self.debug:
483
+ self._logger.debug("executing code extracted from AI message")
484
+ exec_result = self.execution_env.execute_code(extracted_code)
485
+ if self.debug:
486
+ self._logger.debug(f"execution result success={exec_result.get('success')} error={exec_result.get('error')!r}")
487
+ code_executed = True
488
+
489
+ # If still no figure and no code was executed, run one guided retry to force tool usage
490
+ if self.execution_env and self.execution_env.fig is None and not code_executed:
491
+ if self.debug:
492
+ self._logger.debug("guided retry: prompting model to use execute_plotly_code tool")
493
+ guided_messages = [
494
+ *self._graph_messages,
495
+ HumanMessage(
496
+ content=(
497
+ "Please use the execute_plotly_code(generated_code) tool with the FULL code to "
498
+ "create a variable named 'fig', then call does_fig_exist(). Return the final "
499
+ "code in a fenced ```python block."
500
+ )
501
+ ),
502
+ ]
503
+ # Build config with optional PostHog callback for retry
504
+ retry_config = {"recursion_limit": max(3, self.max_iterations // 2)}
505
+ if self.posthog_callback_handler:
506
+ retry_config["callbacks"] = [self.posthog_callback_handler]
507
+
508
+ retry_result = self.agent_executor.invoke(
509
+ {"messages": guided_messages},
510
+ config=retry_config,
511
+ )
512
+ self._graph_messages = retry_result.get("messages", [])
513
+ retry_ai_messages = [
514
+ m for m in self._graph_messages if isinstance(m, AIMessage)
515
+ ]
516
+ retry_content = retry_ai_messages[-1].content if retry_ai_messages else ""
517
+ if isinstance(retry_content, str):
518
+ if "```python" in retry_content:
519
+ parts = retry_content.split("```python", 1)
520
+ retry_code = (
521
+ parts[1].split("```", 1)[0].strip() if len(parts) > 1 else None
522
+ )
523
+ elif "```" in retry_content:
524
+ parts = retry_content.split("```", 1)
525
+ retry_code = (
526
+ parts[1].split("```", 1)[0].strip() if len(parts) > 1 else None
527
+ )
528
+ else:
529
+ retry_code = None
530
+ if retry_code:
531
+ if (self.generated_code or "").strip() != retry_code:
532
+ self.generated_code = retry_code
533
+ if self.debug:
534
+ self._logger.debug("executing code extracted from guided retry response")
535
+ exec_result = self.execution_env.execute_code(retry_code)
536
+ if self.debug:
537
+ self._logger.debug(f"execution result success={exec_result.get('success')} error={exec_result.get('error')!r}")
538
+
539
+ # Run verification step with image if enabled and figure exists
540
+ # This sends the plot image to the LLM for verification, which gets captured by PostHog
541
+ if (
542
+ self.include_plot_image
543
+ and self.posthog_callback_handler
544
+ and self.execution_env
545
+ and self.execution_env.fig is not None
546
+ ):
547
+ self._verify_plot_with_image(user_message)
548
+
549
+ return ai_content if isinstance(ai_content, str) else str(ai_content)
259
550
 
260
551
  def get_figure(self):
261
552
  """Return the current figure if one exists."""
@@ -263,7 +554,100 @@ class PlotAgent:
263
554
  return self.execution_env.fig
264
555
  return None
265
556
 
557
+ def get_plot_title(self):
558
+ """Return the current plot title if one exists."""
559
+ if self.execution_env and self.execution_env.plot_title:
560
+ return self.execution_env.plot_title
561
+ return None
562
+
563
+ def get_plot_summary(self):
564
+ """Return the current plot summary if one exists."""
565
+ if self.execution_env and self.execution_env.plot_summary:
566
+ return self.execution_env.plot_summary
567
+ return None
568
+
569
+ def get_plot_image_base64(self) -> Optional[str]:
570
+ """Return the current plot image as base64-encoded data URI."""
571
+ if self.execution_env and self.execution_env.plot_image_base64:
572
+ return self.execution_env.plot_image_base64
573
+ return None
574
+
575
+ def _verify_plot_with_image(self, user_request: str) -> Optional[str]:
576
+ """
577
+ Send the generated plot image to the LLM for verification.
578
+
579
+ This step serves two purposes:
580
+ 1. Verifies the plot matches the user's request
581
+ 2. Captures the plot image in PostHog LLM traces (via multimodal message)
582
+
583
+ Args:
584
+ user_request: The original user request for context.
585
+
586
+ Returns:
587
+ The LLM's verification response, or None if verification fails.
588
+ """
589
+ plot_image = self.get_plot_image_base64()
590
+ if not plot_image:
591
+ return None
592
+
593
+ # Build multimodal message with the plot image
594
+ human_content: List[Union[dict, str]] = [
595
+ {
596
+ "type": "text",
597
+ "text": (
598
+ f"Please verify this generated plot matches the user's request.\n\n"
599
+ f"User request: {user_request}\n\n"
600
+ f"Generated code:\n```python\n{self.generated_code or 'N/A'}\n```\n\n"
601
+ f"Plot title: {self.get_plot_title() or 'N/A'}\n"
602
+ f"Plot summary: {self.get_plot_summary() or 'N/A'}\n\n"
603
+ f"Respond with a brief confirmation that the plot is correct, "
604
+ f"or note any issues you see."
605
+ ),
606
+ },
607
+ {
608
+ "type": "image_url",
609
+ "image_url": {"url": plot_image},
610
+ },
611
+ ]
612
+
613
+ messages = [
614
+ SystemMessage(
615
+ content=(
616
+ "You are a plot verification assistant. "
617
+ "Review the generated plot image and verify it matches the user's request. "
618
+ "Be concise in your response."
619
+ )
620
+ ),
621
+ HumanMessage(content=human_content),
622
+ ]
623
+
624
+ try:
625
+ # Build config with PostHog callback to capture this verification step
626
+ invoke_config = {}
627
+ if self.posthog_callback_handler:
628
+ invoke_config["callbacks"] = [self.posthog_callback_handler]
629
+
630
+ if self.debug:
631
+ self._logger.debug("Running plot verification with image")
632
+
633
+ # Call LLM directly (not through agent) for verification
634
+ response = self.llm.invoke(messages, config=invoke_config)
635
+ verification_result = response.content if hasattr(response, "content") else str(response)
636
+
637
+ if self.debug:
638
+ self._logger.debug(f"Plot verification result: {verification_result[:200]}...")
639
+
640
+ return verification_result
641
+ except Exception as e:
642
+ self._logger.warning(f"Plot verification failed: {e}")
643
+ return None
644
+
266
645
  def reset_conversation(self):
267
646
  """Reset the conversation history."""
268
647
  self.chat_history = []
269
648
  self.generated_code = None
649
+ if self.execution_env:
650
+ self.execution_env.fig = None
651
+ self.execution_env.plot_title = None
652
+ self.execution_env.plot_summary = None
653
+ self.execution_env.plot_image_base64 = None
plot_agent/execution.py CHANGED
@@ -9,11 +9,15 @@ Security features:
9
9
  • Enforce a 60 second timeout via signal.alarm
10
10
  """
11
11
  import ast
12
+ import base64
12
13
  import builtins
14
+ import logging
13
15
  import signal
16
+ import threading
14
17
  import traceback
15
18
  from io import StringIO
16
19
  import contextlib
20
+ from typing import Optional
17
21
 
18
22
  import pandas as pd
19
23
  import numpy as np
@@ -110,11 +114,17 @@ class PlotAgentExecutionEnvironment:
110
114
  "__import__": _safe_import,
111
115
  }
112
116
 
113
- def __init__(self, df: pd.DataFrame):
117
+ def __init__(self, df: pd.DataFrame, include_plot_image: bool = False):
114
118
  """
115
119
  Initialize the execution environment with a dataframe.
120
+
121
+ Args:
122
+ df: The pandas dataframe to use for plotting.
123
+ include_plot_image: If True, generate a PNG image of the plot after execution.
116
124
  """
117
125
  self.df = df
126
+ self.include_plot_image = include_plot_image
127
+ self._logger = logging.getLogger("plot_agent.execution")
118
128
  # Base namespace for both globals & locals
119
129
  self._base_ns = {
120
130
  "__builtins__": self._SAFE_BUILTINS,
@@ -127,6 +137,30 @@ class PlotAgentExecutionEnvironment:
127
137
  "make_subplots": make_subplots,
128
138
  }
129
139
  self.fig = None
140
+ self.plot_title = None
141
+ self.plot_summary = None
142
+ self.plot_image_base64 = None
143
+
144
+ def _generate_plot_png(self, fig, width: int = 800, height: int = 600) -> Optional[str]:
145
+ """
146
+ Generate PNG as base64 data URI from a Plotly figure.
147
+
148
+ Args:
149
+ fig: The Plotly figure to convert.
150
+ width: Image width in pixels.
151
+ height: Image height in pixels.
152
+
153
+ Returns:
154
+ Base64-encoded data URI string, or None if generation fails.
155
+ """
156
+ try:
157
+ import plotly.io as pio
158
+ img_bytes = pio.to_image(fig, format='png', width=width, height=height)
159
+ b64 = base64.b64encode(img_bytes).decode('utf-8')
160
+ return f"data:image/png;base64,{b64}"
161
+ except Exception as e:
162
+ self._logger.warning(f"Failed to generate plot PNG: {e}")
163
+ return None
130
164
 
131
165
  def _validate_ast(self, node: ast.AST):
132
166
  """
@@ -166,8 +200,10 @@ class PlotAgentExecutionEnvironment:
166
200
 
167
201
  # Copy the base namespace
168
202
  ns = self._base_ns.copy()
169
- # Purge any old `fig`
203
+ # Purge any old variables
170
204
  ns.pop("fig", None)
205
+ ns.pop("plot_title", None)
206
+ ns.pop("plot_summary", None)
171
207
 
172
208
  try:
173
209
  # Parse the generated code
@@ -178,14 +214,23 @@ class PlotAgentExecutionEnvironment:
178
214
  # If the code is rejected on safety grounds, return an error
179
215
  return {
180
216
  "fig": None,
217
+ "plot_title": None,
218
+ "plot_summary": None,
181
219
  "output": "",
182
220
  "error": f"Code rejected on safety grounds: {e}",
183
221
  "success": False,
184
222
  }
185
223
 
186
- # Set a timeout
187
- signal.signal(signal.SIGALRM, _timeout_handler)
188
- signal.alarm(self.TIMEOUT_SECONDS)
224
+ # Set a timeout only if running on the main thread; signals are not supported in worker threads
225
+ timeout_set = False
226
+ try:
227
+ if threading.current_thread() is threading.main_thread():
228
+ signal.signal(signal.SIGALRM, _timeout_handler)
229
+ signal.alarm(self.TIMEOUT_SECONDS)
230
+ timeout_set = True
231
+ except Exception:
232
+ # If setting the signal handler fails (e.g., not in main thread), proceed without timeout
233
+ timeout_set = False
189
234
 
190
235
  # Execute the code
191
236
  out_buf, err_buf = StringIO(), StringIO()
@@ -201,6 +246,8 @@ class PlotAgentExecutionEnvironment:
201
246
  tb = traceback.format_exc()
202
247
  return {
203
248
  "fig": None,
249
+ "plot_title": None,
250
+ "plot_summary": None,
204
251
  "output": out_buf.getvalue(),
205
252
  "error": f"Code execution timed out: {te}\n{tb}",
206
253
  "success": False,
@@ -210,29 +257,78 @@ class PlotAgentExecutionEnvironment:
210
257
  tb = traceback.format_exc()
211
258
  return {
212
259
  "fig": None,
260
+ "plot_title": None,
261
+ "plot_summary": None,
213
262
  "output": out_buf.getvalue(),
214
263
  "error": f"Error executing code: {e}\n{tb}",
215
264
  "success": False,
216
265
  }
217
266
  finally:
218
- # Reset the timeout
219
- signal.alarm(0)
267
+ # Reset the timeout if it was set
268
+ if timeout_set:
269
+ try:
270
+ signal.alarm(0)
271
+ except Exception:
272
+ pass
220
273
 
221
- # Get the `fig`
274
+ # Get the variables
222
275
  fig = ns.get("fig")
276
+ plot_title = ns.get("plot_title")
277
+ plot_summary = ns.get("plot_summary")
278
+
279
+ # Store the variables
223
280
  self.fig = fig
281
+ self.plot_title = plot_title
282
+ self.plot_summary = plot_summary
283
+
284
+ # Generate PNG if enabled and figure exists
285
+ self.plot_image_base64 = None
286
+ if self.include_plot_image and fig is not None:
287
+ self.plot_image_base64 = self._generate_plot_png(fig)
288
+
289
+ # Validate required variables
290
+ missing_vars = []
224
291
  if fig is None:
292
+ missing_vars.append("fig")
293
+ if plot_title is None:
294
+ missing_vars.append("plot_title")
295
+ if plot_summary is None:
296
+ missing_vars.append("plot_summary")
297
+
298
+ if missing_vars:
225
299
  return {
226
- "fig": None,
300
+ "fig": fig,
301
+ "plot_title": plot_title,
302
+ "plot_summary": plot_summary,
303
+ "output": out_buf.getvalue(),
304
+ "error": f"Missing required variables: {', '.join(missing_vars)}. Please create variables named: {', '.join(missing_vars)}.",
305
+ "success": False,
306
+ }
307
+
308
+ # Validate that plot_title and plot_summary are strings
309
+ validation_errors = []
310
+ if not isinstance(plot_title, str):
311
+ validation_errors.append("plot_title must be a string")
312
+ if not isinstance(plot_summary, str):
313
+ validation_errors.append("plot_summary must be a string")
314
+
315
+ if validation_errors:
316
+ return {
317
+ "fig": fig,
318
+ "plot_title": plot_title,
319
+ "plot_summary": plot_summary,
227
320
  "output": out_buf.getvalue(),
228
- "error": "No `fig` created. Assign your figure to a variable named `fig`.",
321
+ "error": f"Validation errors: {'; '.join(validation_errors)}.",
229
322
  "success": False,
230
323
  }
231
324
 
232
325
  # Return the result
233
326
  return {
234
327
  "fig": fig,
235
- "output": "Code executed successfully. 'fig' object was created.",
328
+ "plot_title": plot_title,
329
+ "plot_summary": plot_summary,
330
+ "plot_image_base64": self.plot_image_base64,
331
+ "output": "Code executed successfully. 'fig', 'plot_title', and 'plot_summary' objects were created.",
236
332
  "error": "",
237
333
  "success": True,
238
334
  }
plot_agent/models.py CHANGED
@@ -31,3 +31,9 @@ class ViewGeneratedCodeInput(BaseModel):
31
31
  """Model indicating that the view_generated_code function takes no arguments."""
32
32
 
33
33
  pass
34
+
35
+
36
+ class CheckPlotOutputsInput(BaseModel):
37
+ """Model indicating that the check_plot_outputs function takes no arguments."""
38
+
39
+ pass
plot_agent/prompt.py CHANGED
@@ -26,10 +26,12 @@ NOTES:
26
26
  - You must paste the full code, not just a reference to the code.
27
27
  - You must not use fig.show() in your code as it will ultimately be executed elsewhere in a headless environment.
28
28
  - If you need to do any data cleaning or wrangling, do it in the code before generating the plotly code as preprocessing steps assume the data is in the pandas 'df' object.
29
+ - Your code MUST create three variables: 'fig' (Plotly figure), 'plot_title' (string), and 'plot_summary' (string).
29
30
 
30
31
  TOOLS:
31
32
  - execute_plotly_code(generated_code) to execute the generated code.
32
33
  - does_fig_exist() to check that a fig object is available for display. This tool takes no arguments.
34
+ - check_plot_outputs() to check if all required outputs (fig, plot_title, plot_summary) are available. This tool takes no arguments.
33
35
  - view_generated_code() to view the generated code if need to fix it. This tool takes no arguments.
34
36
 
35
37
  IMPORTANT CODE FORMATTING INSTRUCTIONS:
@@ -37,6 +39,9 @@ IMPORTANT CODE FORMATTING INSTRUCTIONS:
37
39
  2. Use descriptive variable names.
38
40
  3. DO NOT include fig.show() in your code - the visualization will be rendered externally.
39
41
  4. Ensure your code creates a variable named 'fig' that contains the Plotly figure object.
42
+ 5. You MUST also create two string variables:
43
+ - 'plot_title': A concise, descriptive title for the plot (string)
44
+ - 'plot_summary': A brief summary explaining what the plot shows and any key insights (string)
40
45
 
41
46
  When a user asks for a visualization:
42
47
  1. YOU MUST ALWAYS use the execute_plotly_code(generated_code) tool to test and run your code.
@@ -48,11 +53,11 @@ IMPORTANT: The code you generate MUST be executed using the execute_plotly_code
48
53
  YOU MUST CALL execute_plotly_code WITH THE FULL CODE, NOT JUST A REFERENCE TO THE CODE.
49
54
 
50
55
  YOUR WORKFLOW MUST BE:
51
- 1. execute_plotly_code(generated_code) to make sure the code is ran and a figure object is created.
52
- 2. check that a figure object is available using does_fig_exist() to make sure the figure object was created.
53
- 3. if there are errors, view the generated code using view_generated_code() to see what went wrong.
54
- 4. fix the code and execute it again with execute_plotly_code(generated_code) to make sure the figure object is created.
55
- 5. repeat until the figure object is available.
56
+ 1. execute_plotly_code(generated_code) to make sure the code is run and a figure object, plot_title, and plot_summary are created.
57
+ 2. use check_plot_outputs() to verify that all required outputs (fig, plot_title, plot_summary) are available.
58
+ 3. if there are errors or missing outputs, view the generated code using view_generated_code() to see what went wrong.
59
+ 4. fix the code and execute it again with execute_plotly_code(generated_code) to make sure all required outputs are created.
60
+ 5. repeat until all outputs (figure object, plot_title, and plot_summary) are available.
56
61
 
57
62
  Always return the final working code (with all the comments) to the user along with an explanation of what the visualization shows.
58
63
  Make sure to follow best practices for data visualization, such as appropriate chart types, labels, and colors.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plot-agent
3
- Version: 0.3.1
3
+ Version: 0.5.0
4
4
  Summary: An AI-powered data visualization assistant using Plotly
5
5
  Author-email: andrewm4894 <andrewm4894@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/andrewm4894/plot-agent
@@ -19,6 +19,8 @@ Dynamic: license-file
19
19
 
20
20
  An AI-powered data visualization assistant that helps users create Plotly visualizations in Python.
21
21
 
22
+ Built on LangGraph with tool-calling to reliably execute generated Plotly code in a sandbox and keep the current `fig` in sync.
23
+
22
24
  ## Installation
23
25
 
24
26
  You can install the package using pip:
@@ -37,7 +39,7 @@ Here's a simple minimal example of how to use Plot Agent:
37
39
  import pandas as pd
38
40
  from plot_agent.agent import PlotAgent
39
41
 
40
- # ensure OPENAI_API_KEY is set and available for langchain
42
+ # ensure OPENAI_API_KEY is set (env or .env); optional debug via PLOT_AGENT_DEBUG=1
41
43
 
42
44
  # Create a sample dataframe
43
45
  df = pd.DataFrame({
@@ -92,19 +94,72 @@ fig.update_layout(
92
94
  )
93
95
  ```
94
96
 
97
+ ## How it works
98
+
99
+ ```mermaid
100
+ flowchart TD
101
+ A[User message] --> B{LangGraph ReAct Agent}
102
+ subgraph Tools
103
+ T1[execute_plotly_code<br/>- runs code in sandbox<br/>- returns success/fig/error]
104
+ T2[does_fig_exist]
105
+ T3[view_generated_code]
106
+ end
107
+ B -- tool call --> T1
108
+ T1 -- result --> B
109
+ B -- optional --> T2
110
+ B -- optional --> T3
111
+ B --> C[AI response]
112
+ C --> D{Agent wrapper}
113
+ D -- persist messages --> B
114
+ D -- extract code blocks --> E[Sandbox execution]
115
+ E --> F[fig]
116
+ F --> G[get_figure]
117
+ ```
118
+
119
+ - The LangGraph agent plans and decides when to call tools.
120
+ - The wrapper persists full graph messages between turns and executes any returned code blocks to keep `fig` updated.
121
+ - A safe execution environment runs code with an allowlist and a main-thread-only timeout.
122
+
95
123
  ## Features
96
124
 
97
125
  - AI-powered visualization generation
98
126
  - Support for various Plotly chart types
99
127
  - Automatic data preprocessing
100
128
  - Interactive visualization capabilities
101
- - Integration with LangChain for advanced AI capabilities
129
+ - LangGraph-based tool calling and control flow
130
+ - Debug logging via `PlotAgent(debug=True)` or `PLOT_AGENT_DEBUG=1`
102
131
 
103
132
  ## Requirements
104
133
 
105
134
  - Python 3.8 or higher
106
135
  - Dependencies are automatically installed with the package
107
136
 
137
+ ## Development
138
+
139
+ - Run unit tests:
140
+
141
+ ```bash
142
+ make test
143
+ ```
144
+
145
+ - Execute all example notebooks:
146
+
147
+ ```bash
148
+ make run-examples
149
+ ```
150
+
151
+ - Execute with debug logs enabled:
152
+
153
+ ```bash
154
+ make run-examples-debug
155
+ ```
156
+
157
+ - Quick CLI repro that prints evolving code each step:
158
+
159
+ ```bash
160
+ make run-example-script
161
+ ```
162
+
108
163
  ## License
109
164
 
110
165
  This project is licensed under the MIT License - see the LICENSE file for details.
@@ -0,0 +1,10 @@
1
+ plot_agent/__init__.py,sha256=hDmKvr4Hqe-NWpSGX8B36eWjBBkxq-JChIep4gN9pXc,123
2
+ plot_agent/agent.py,sha256=VcctFGizXPn5E3l-fD0QK_LHP8rXE6jjpZv65WE6lXY,27063
3
+ plot_agent/execution.py,sha256=oeiYy7dLF5krOK_vt0zoNB5C7mbliAQq9veVUYtBGfo,11350
4
+ plot_agent/models.py,sha256=0ca5absK0L3kOyZhkVm0V-iyVVBEQ7cw_DCililYTzg,996
5
+ plot_agent/prompt.py,sha256=WdXDhoTLdYZSmCay5K-eBLrJZrRHqwyDxfYrpsxUa4k,3684
6
+ plot_agent-0.5.0.dist-info/licenses/LICENSE,sha256=A4DPih7wHrh4VMEG3p1PhorqdhjmGIo8nQdYNQL7daA,1062
7
+ plot_agent-0.5.0.dist-info/METADATA,sha256=1sY052ANI3VRnsY794K7W2P2C1UiBs1RWFDxVDOpem8,4152
8
+ plot_agent-0.5.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
9
+ plot_agent-0.5.0.dist-info/top_level.txt,sha256=KyOjpihUssx26Ra-37vKUQ71pI2qgJsHaRwXHJUhjzQ,11
10
+ plot_agent-0.5.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.1)
2
+ Generator: setuptools (80.10.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,10 +0,0 @@
1
- plot_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- plot_agent/agent.py,sha256=sIG8GMS2A8TP_3kRxbgefn-yZM4K_7niZQR6cJrhl4s,9872
3
- plot_agent/execution.py,sha256=lQNyPzphPIdMQXxQkaf_g6oDZsU3dgF0or0ysKJm6FM,7537
4
- plot_agent/models.py,sha256=THdGGGfGmRZ5rtgXvjPcQxFRRTZVFoADEHI_lsMVha8,860
5
- plot_agent/prompt.py,sha256=5hBlF7jdMrj6MiGEL7YmSDWFUfiCXyIZfZtf3NstKoo,3125
6
- plot_agent-0.3.1.dist-info/licenses/LICENSE,sha256=A4DPih7wHrh4VMEG3p1PhorqdhjmGIo8nQdYNQL7daA,1062
7
- plot_agent-0.3.1.dist-info/METADATA,sha256=zkpeWRWczA_CzH7mahNtEuIvumBOhXaNGTiAcUIOQZQ,2837
8
- plot_agent-0.3.1.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
9
- plot_agent-0.3.1.dist-info/top_level.txt,sha256=KyOjpihUssx26Ra-37vKUQ71pI2qgJsHaRwXHJUhjzQ,11
10
- plot_agent-0.3.1.dist-info/RECORD,,