ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -17,25 +17,367 @@ from langgraph.types import Command
17
17
  from langgraph.checkpoint.memory import MemorySaver
18
18
 
19
19
  import os
20
- import io
21
20
  import pandas as pd
22
21
 
22
+ from IPython.display import Markdown
23
+
23
24
  from ai_data_science_team.templates import(
24
25
  node_func_execute_agent_code_on_data,
25
26
  node_func_human_review,
26
27
  node_func_fix_agent_code,
27
28
  node_func_explain_agent_code,
28
- create_coding_agent_graph
29
+ create_coding_agent_graph,
30
+ BaseAgent,
29
31
  )
30
32
  from ai_data_science_team.tools.parsers import PythonOutputParser
31
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
33
+ from ai_data_science_team.tools.regex import (
34
+ relocate_imports_inside_function,
35
+ add_comments_to_top,
36
+ format_agent_name,
37
+ format_recommended_steps
38
+ )
32
39
  from ai_data_science_team.tools.metadata import get_dataframe_summary
33
40
  from ai_data_science_team.tools.logging import log_ai_function
41
+ from ai_data_science_team.utils.plotly import plotly_from_dict
34
42
 
35
43
  # Setup
36
44
  AGENT_NAME = "data_visualization_agent"
37
45
  LOG_PATH = os.path.join(os.getcwd(), "logs/")
38
46
 
47
+ # Class
48
+
49
class DataVisualizationAgent(BaseAgent):
    """
    Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
    default visualization steps (if any). The agent generates a Python function to produce the visualization,
    executes it, and logs the process, including code and errors. It is designed to facilitate reproducible
    and customizable data visualization workflows.

    The agent may use default instructions for creating charts unless instructed otherwise, such as:
    - Generating a recommended chart type (bar, scatter, line, etc.)
    - Creating user-friendly titles and axis labels
    - Applying consistent styling (template, font sizes, color themes)
    - Handling theme details (white background, base font size, line size, etc.)

    User instructions can modify, add, or remove any of these steps to tailor the visualization process.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the data visualization function.
    n_samples : int, optional
        Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
        Reducing this number can help avoid exceeding the model's token limits.
    log : bool, optional
        Whether to log the generated code and errors. Defaults to False.
    log_path : str, optional
        Directory path for storing log files. Defaults to None.
    file_name : str, optional
        Name of the file for saving the generated response. Defaults to "data_visualization.py".
    function_name : str, optional
        Name of the function for data visualization. Defaults to "data_visualization".
    overwrite : bool, optional
        Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
    human_in_the_loop : bool, optional
        Enables user review of data visualization instructions. Defaults to False.
    bypass_recommended_steps : bool, optional
        If True, skips the default recommended visualization steps. Defaults to False.
    bypass_explain_code : bool, optional
        If True, skips the step that provides code explanations. Defaults to False.

    Attributes
    ----------
    response : dict or None
        The final graph state returned by the most recent invocation, or None if the agent has not
        been invoked since it was created or since its parameters were last updated.

    Methods
    -------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled state graph.
    ainvoke_agent(data_raw: pd.DataFrame, user_instructions: str = None, max_retries=3, retry_count=0)
        Coroutine that generates a visualization based on user instructions; must be awaited.
    invoke_agent(data_raw: pd.DataFrame, user_instructions: str = None, max_retries=3, retry_count=0)
        Synchronously generates a visualization based on user instructions.
    explain_visualization_steps()
        Returns the messages recorded in the last response.
    get_log_summary(markdown=False)
        Retrieves a summary of logged operations if logging is enabled.
    get_plotly_graph()
        Retrieves the Plotly graph produced by the agent, converted with plotly_from_dict.
    get_data_raw()
        Retrieves the raw dataset as a pandas DataFrame (based on the last response).
    get_data_visualization_function(markdown=False)
        Retrieves the generated Python function used for data visualization.
    get_recommended_visualization_steps(markdown=False)
        Retrieves the agent's recommended visualization steps.
    get_response()
        Returns the response from the agent as a dictionary.
    show()
        Displays the agent's mermaid diagram.

    Examples
    --------
    ```python
    import pandas as pd
    from langchain_openai import ChatOpenAI
    from ai_data_science_team.agents import DataVisualizationAgent

    llm = ChatOpenAI(model="gpt-4o-mini")

    data_visualization_agent = DataVisualizationAgent(
        model=llm,
        n_samples=30,
        log=True,
        log_path="logs",
        human_in_the_loop=True
    )

    df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")

    data_visualization_agent.invoke_agent(
        user_instructions="Generate a scatter plot of age vs. total charges with a trend line.",
        data_raw=df,
        max_retries=3,
        retry_count=0
    )

    # get_plotly_graph() converts the agent's "plotly_graph" dictionary with plotly_from_dict();
    # the raw dictionary is available as get_response()["plotly_graph"].
    plotly_graph = data_visualization_agent.get_plotly_graph()

    response = data_visualization_agent.get_response()

    # Async variant (inside a coroutine, or a notebook with top-level await):
    # await data_visualization_agent.ainvoke_agent(data_raw=df, user_instructions="...")
    ```
    """

    def __init__(
        self,
        model,
        n_samples=30,
        log=False,
        log_path=None,
        file_name="data_visualization.py",
        function_name="data_visualization",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        # Keep every constructor argument so the graph can be rebuilt by update_params().
        self._params = {
            "model": model,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "function_name": function_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create the compiled graph for the data visualization agent.
        Running this method will reset the response to None.
        """
        self.response = None
        return make_data_visualization_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.

        Any previous response is discarded. If the graph cannot be rebuilt with the new
        parameters (e.g. an unknown keyword that make_data_visualization_agent does not
        accept), the previous parameters are restored and the exception is re-raised.

        Parameters
        ----------
        **kwargs : dict
            Parameter names/values to override (same names as the constructor arguments).
        """
        previous_params = dict(self._params)
        self._params.update(kwargs)
        try:
            self._compiled_graph = self._make_compiled_graph()
        except Exception:
            # Roll back so _params keeps describing the graph that is still in use.
            self._params = previous_params
            raise

    def _build_invoke_input(self, data_raw, user_instructions, max_retries, retry_count):
        """Build the initial graph state shared by invoke_agent() and ainvoke_agent()."""
        return {
            "user_instructions": user_instructions,
            # The graph state stores the dataset as a dict, not a DataFrame.
            "data_raw": data_raw.to_dict(),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }

    async def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
        """
        Asynchronously invokes the agent to generate a visualization.
        This is a coroutine and must be awaited. The response is stored in the 'response' attribute.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to ainvoke().

        Returns
        -------
        None
        """
        # ainvoke() returns an awaitable; awaiting it is required to get the final state dict.
        response = await self._compiled_graph.ainvoke(
            self._build_invoke_input(data_raw, user_instructions, max_retries, retry_count),
            **kwargs,
        )
        self.response = response
        return None

    def invoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
        """
        Synchronously invokes the agent to generate a visualization.
        The response is stored in the 'response' attribute.

        Parameters
        ----------
        data_raw : pd.DataFrame
            The raw dataset to be visualized.
        user_instructions : str
            Instructions for data visualization agent.
        max_retries : int
            Maximum retry attempts.
        retry_count : int
            Current retry attempt count.
        **kwargs : dict
            Additional keyword arguments passed to invoke().

        Returns
        -------
        None
        """
        response = self._compiled_graph.invoke(
            self._build_invoke_input(data_raw, user_instructions, max_retries, retry_count),
            **kwargs,
        )
        self.response = response
        return None

    def explain_visualization_steps(self):
        """
        Provides the messages recorded by the agent in its last response.

        Returns
        -------
        list
            The "messages" entry of the last response, or an empty list if there is no response.
        """
        if self.response:
            return self.response.get("messages", [])
        return []

    def get_log_summary(self, markdown=False):
        """
        Returns a summary of the agent's logging, if logging is enabled.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns Markdown-formatted output.

        Returns
        -------
        str, Markdown or None
            The path of the logged function, or None if no log path is available.
        """
        if self.response and self.response.get('data_visualization_function_path'):
            log_details = f"Log Path: {self.response.get('data_visualization_function_path')}"
            if markdown:
                return Markdown(log_details)
            else:
                return log_details
        return None

    def get_plotly_graph(self):
        """
        Retrieves the Plotly graph produced by the agent.

        Returns
        -------
        object or None
            The agent's "plotly_graph" dictionary converted with plotly_from_dict(), or None
            if there is no response or the response contains no graph.
        """
        if not self.response:
            return None
        plotly_graph = self.response.get("plotly_graph")
        if plotly_graph is None:
            # No graph in the state: nothing to convert, so don't hand None to plotly_from_dict.
            return None
        return plotly_from_dict(plotly_graph)

    def get_data_raw(self):
        """
        Retrieves the raw dataset used in the last invocation.

        Returns
        -------
        pd.DataFrame or None
            The raw dataset as a DataFrame if available, otherwise None.
        """
        if self.response and self.response.get("data_raw"):
            return pd.DataFrame(self.response.get("data_raw"))
        return None

    def get_data_visualization_function(self, markdown=False):
        """
        Retrieves the generated Python function used for data visualization.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the function in Markdown code block format.

        Returns
        -------
        str, Markdown or None
            The Python function code (empty string if missing from the response),
            or None if there is no response.
        """
        if self.response:
            func_code = self.response.get("data_visualization_function", "")
            if markdown:
                return Markdown(f"```python\n{func_code}\n```")
            return func_code
        return None

    def get_recommended_visualization_steps(self, markdown=False):
        """
        Retrieves the agent's recommended visualization steps.

        Parameters
        ----------
        markdown : bool, optional
            If True, returns the steps in Markdown format.

        Returns
        -------
        str, Markdown or None
            The recommended steps (empty string if missing from the response),
            or None if there is no response.
        """
        if self.response:
            steps = self.response.get("recommended_steps", "")
            if markdown:
                return Markdown(steps)
            return steps
        return None

    def get_response(self):
        """
        Returns the agent's full response dictionary.

        Returns
        -------
        dict or None
            The response dictionary if available, otherwise None.
        """
        return self.response

    def show(self):
        """
        Displays the agent's mermaid diagram for visual inspection of the compiled graph.
        """
        return self._compiled_graph.show()
379
+
380
+
39
381
  # Agent
40
382
 
41
383
  def make_data_visualization_agent(
@@ -44,14 +386,85 @@ def make_data_visualization_agent(
44
386
  log=False,
45
387
  log_path=None,
46
388
  file_name="data_visualization.py",
47
- overwrite = True,
389
+ function_name="data_visualization",
390
+ overwrite=True,
48
391
  human_in_the_loop=False,
49
392
  bypass_recommended_steps=False,
50
393
  bypass_explain_code=False
51
394
  ):
395
+ """
396
+ Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
397
+ default visualization steps. The agent generates a Python function to produce the visualization, executes it,
398
+ and logs the process, including code and errors. It is designed to facilitate reproducible and customizable
399
+ data visualization workflows.
400
+
401
+ The agent can perform the following default visualization steps unless instructed otherwise:
402
+ - Generating a recommended chart type (bar, scatter, line, etc.)
403
+ - Creating user-friendly titles and axis labels
404
+ - Applying consistent styling (template, font sizes, color themes)
405
+ - Handling theme details (white background, base font size, line size, etc.)
406
+
407
+ User instructions can modify, add, or remove any of these steps to tailor the visualization process.
408
+
409
+ Parameters
410
+ ----------
411
+ model : langchain.llms.base.LLM
412
+ The language model used to generate the data visualization function.
413
+ n_samples : int, optional
414
+ Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
415
+ log : bool, optional
416
+ Whether to log the generated code and errors. Defaults to False.
417
+ log_path : str, optional
418
+ Directory path for storing log files. Defaults to None.
419
+ file_name : str, optional
420
+ Name of the file for saving the generated response. Defaults to "data_visualization.py".
421
+ function_name : str, optional
422
+ Name of the function for data visualization. Defaults to "data_visualization".
423
+ overwrite : bool, optional
424
+ Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
425
+ human_in_the_loop : bool, optional
426
+ Enables user review of data visualization instructions. Defaults to False.
427
+ bypass_recommended_steps : bool, optional
428
+ If True, skips the default recommended visualization steps. Defaults to False.
429
+ bypass_explain_code : bool, optional
430
+ If True, skips the step that provides code explanations. Defaults to False.
431
+
432
+ Examples
433
+ --------
434
+ ``` python
435
+ import pandas as pd
436
+ from langchain_openai import ChatOpenAI
437
+ from ai_data_science_team.agents import make_data_visualization_agent
438
+
439
+ llm = ChatOpenAI(model="gpt-4o-mini")
440
+
441
+ data_visualization_agent = make_data_visualization_agent(llm)
442
+
443
+ df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")
444
+
445
+ response = data_visualization_agent.invoke({
446
+ "user_instructions": "Generate a scatter plot of tenure vs. total charges with a trend line.",
447
+ "data_raw": df.to_dict(),
448
+ "max_retries": 3,
449
+ "retry_count": 0
450
+ })
451
+
452
+ response['plotly_graph']  # the Plotly chart as a dictionary
453
+ ```
454
+
455
+ Returns
456
+ -------
457
+ app : langchain.graphs.CompiledStateGraph
458
+ The data visualization agent as a state graph.
459
+ """
52
460
 
53
461
  llm = model
54
462
 
463
+ # Human in the loop requires recommended steps
464
+ if bypass_recommended_steps and human_in_the_loop:
465
+ bypass_recommended_steps = False
466
+ print("Bypass recommended steps set to False to enable human in the loop.")
467
+
55
468
  # Setup Log Directory
56
469
  if log:
57
470
  if log_path is None:
@@ -70,6 +483,7 @@ def make_data_visualization_agent(
70
483
  all_datasets_summary: str
71
484
  data_visualization_function: str
72
485
  data_visualization_function_path: str
486
+ data_visualization_function_file_name: str
73
487
  data_visualization_function_name: str
74
488
  data_visualization_error: str
75
489
  max_retries: int
@@ -140,7 +554,7 @@ def make_data_visualization_agent(
140
554
  })
141
555
 
142
556
  return {
143
- "recommended_steps": "\n\n# Recommended Data Cleaning Steps:\n" + recommended_steps.content.strip(),
557
+ "recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Cleaning Steps:"),
144
558
  "all_datasets_summary": all_datasets_summary_str
145
559
  }
146
560
 
@@ -169,7 +583,7 @@ def make_data_visualization_agent(
169
583
  template="""
170
584
  You are a chart generator agent that is an expert in generating plotly charts. You must use plotly or plotly.express to produce plots.
171
585
 
172
- Your job is to produce python code to generate visualizations.
586
+ Your job is to produce python code to generate visualizations with a function named {function_name}.
173
587
 
174
588
  You will take instructions from a Chart Instructor and generate a plotly chart from the data provided.
175
589
 
@@ -181,13 +595,13 @@ def make_data_visualization_agent(
181
595
 
182
596
  RETURN:
183
597
 
184
- Return Python code in ```python ``` format with a single function definition, data_visualization(data_raw), that includes all imports inside the function.
598
+ Return Python code in ```python ``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
185
599
 
186
600
  Return the plotly chart as a dictionary.
187
601
 
188
602
  Return code to provide the data visualization function:
189
603
 
190
- def data_visualization(data_raw):
604
+ def {function_name}(data_raw):
191
605
  import pandas as pd
192
606
  import numpy as np
193
607
  import json
@@ -206,14 +620,15 @@ def make_data_visualization_agent(
206
620
  2. Do not include unrelated user instructions that are not related to the chart generation.
207
621
 
208
622
  """,
209
- input_variables=["chart_generator_instructions", "all_datasets_summary"]
623
+ input_variables=["chart_generator_instructions", "all_datasets_summary", "function_name"]
210
624
  )
211
-
625
+
212
626
  data_visualization_agent = prompt_template | llm | PythonOutputParser()
213
627
 
214
628
  response = data_visualization_agent.invoke({
215
629
  "chart_generator_instructions": chart_generator_instructions,
216
- "all_datasets_summary": all_datasets_summary_str
630
+ "all_datasets_summary": all_datasets_summary_str,
631
+ "function_name": function_name
217
632
  })
218
633
 
219
634
  response = relocate_imports_inside_function(response)
@@ -231,19 +646,37 @@ def make_data_visualization_agent(
231
646
  return {
232
647
  "data_visualization_function": response,
233
648
  "data_visualization_function_path": file_path,
234
- "data_visualization_function_name": file_name_2,
649
+ "data_visualization_function_file_name": file_name_2,
650
+ "data_visualization_function_name": function_name,
235
651
  "all_datasets_summary": all_datasets_summary_str
236
652
  }
237
-
238
- def human_review(state: GraphState) -> Command[Literal["chart_instructor", "chart_generator"]]:
239
- return node_func_human_review(
240
- state=state,
241
- prompt_text="Is the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}",
242
- yes_goto="chart_generator",
243
- no_goto="chart_instructor",
244
- user_instructions_key="user_instructions",
245
- recommended_steps_key="recommended_steps"
246
- )
653
+
654
+ # Human Review
655
+
656
+ prompt_text_human_review = "Are the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
657
+
658
+ if not bypass_explain_code:
659
+ def human_review(state: GraphState) -> Command[Literal["chart_instructor", "explain_data_visualization_code"]]:
660
+ return node_func_human_review(
661
+ state=state,
662
+ prompt_text=prompt_text_human_review,
663
+ yes_goto= 'explain_data_visualization_code',
664
+ no_goto="chart_instructor",
665
+ user_instructions_key="user_instructions",
666
+ recommended_steps_key="recommended_steps",
667
+ code_snippet_key="data_visualization_function",
668
+ )
669
+ else:
670
+ def human_review(state: GraphState) -> Command[Literal["chart_instructor", "__end__"]]:
671
+ return node_func_human_review(
672
+ state=state,
673
+ prompt_text=prompt_text_human_review,
674
+ yes_goto= '__end__',
675
+ no_goto="chart_instructor",
676
+ user_instructions_key="user_instructions",
677
+ recommended_steps_key="recommended_steps",
678
+ code_snippet_key="data_visualization_function",
679
+ )
247
680
 
248
681
 
249
682
  def execute_data_visualization_code(state):
@@ -253,7 +686,7 @@ def make_data_visualization_agent(
253
686
  result_key="plotly_graph",
254
687
  error_key="data_visualization_error",
255
688
  code_snippet_key="data_visualization_function",
256
- agent_function_name="data_visualization",
689
+ agent_function_name=state.get("data_visualization_function_name"),
257
690
  pre_processing=lambda data: pd.DataFrame.from_dict(data),
258
691
  # post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
259
692
  error_message_prefix="An error occurred during data visualization: "
@@ -261,11 +694,11 @@ def make_data_visualization_agent(
261
694
 
262
695
  def fix_data_visualization_code(state: GraphState):
263
696
  prompt = """
264
- You are a Data Visualization Agent. Your job is to create a data_visualization() function that can be run on the data provided. The function is currently broken and needs to be fixed.
697
+ You are a Data Visualization Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
265
698
 
266
- Make sure to only return the function definition for data_visualization().
699
+ Make sure to only return the function definition for {function_name}().
267
700
 
268
- Return Python code in ```python``` format with a single function definition, data_visualization(data_raw), that includes all imports inside the function.
701
+ Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
269
702
 
270
703
  This is the broken code (please fix):
271
704
  {code_snippet}
@@ -283,6 +716,7 @@ def make_data_visualization_agent(
283
716
  agent_name=AGENT_NAME,
284
717
  log=log,
285
718
  file_path=state.get("data_visualization_function_path"),
719
+ function_name=state.get("data_visualization_function_name"),
286
720
  )
287
721
 
288
722
  def explain_data_visualization_code(state: GraphState):
@@ -328,4 +762,3 @@ def make_data_visualization_agent(
328
762
  )
329
763
 
330
764
  return app
331
-