ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,25 +17,367 @@ from langgraph.types import Command
17
17
  from langgraph.checkpoint.memory import MemorySaver
18
18
 
19
19
  import os
20
- import io
21
20
  import pandas as pd
22
21
 
22
+ from IPython.display import Markdown
23
+
23
24
  from ai_data_science_team.templates import(
24
25
  node_func_execute_agent_code_on_data,
25
26
  node_func_human_review,
26
27
  node_func_fix_agent_code,
27
28
  node_func_explain_agent_code,
28
- create_coding_agent_graph
29
+ create_coding_agent_graph,
30
+ BaseAgent,
29
31
  )
30
32
  from ai_data_science_team.tools.parsers import PythonOutputParser
31
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
33
+ from ai_data_science_team.tools.regex import (
34
+ relocate_imports_inside_function,
35
+ add_comments_to_top,
36
+ format_agent_name,
37
+ format_recommended_steps
38
+ )
32
39
  from ai_data_science_team.tools.metadata import get_dataframe_summary
33
40
  from ai_data_science_team.tools.logging import log_ai_function
41
+ from ai_data_science_team.utils.plotly import plotly_from_dict
34
42
 
35
43
  # Setup
36
44
  AGENT_NAME = "data_visualization_agent"
37
45
  LOG_PATH = os.path.join(os.getcwd(), "logs/")
38
46
 
47
+ # Class
48
+
49
+ class DataVisualizationAgent(BaseAgent):
50
+ """
51
+ Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
52
+ default visualization steps (if any). The agent generates a Python function to produce the visualization,
53
+ executes it, and logs the process, including code and errors. It is designed to facilitate reproducible
54
+ and customizable data visualization workflows.
55
+
56
+ The agent may use default instructions for creating charts unless instructed otherwise, such as:
57
+ - Generating a recommended chart type (bar, scatter, line, etc.)
58
+ - Creating user-friendly titles and axis labels
59
+ - Applying consistent styling (template, font sizes, color themes)
60
+ - Handling theme details (white background, base font size, line size, etc.)
61
+
62
+ User instructions can modify, add, or remove any of these steps to tailor the visualization process.
63
+
64
+ Parameters
65
+ ----------
66
+ model : langchain.llms.base.LLM
67
+ The language model used to generate the data visualization function.
68
+ n_samples : int, optional
69
+ Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
70
+ Reducing this number can help avoid exceeding the model's token limits.
71
+ log : bool, optional
72
+ Whether to log the generated code and errors. Defaults to False.
73
+ log_path : str, optional
74
+ Directory path for storing log files. Defaults to None.
75
+ file_name : str, optional
76
+ Name of the file for saving the generated response. Defaults to "data_visualization.py".
77
+ function_name : str, optional
78
+ Name of the function for data visualization. Defaults to "data_visualization".
79
+ overwrite : bool, optional
80
+ Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
81
+ human_in_the_loop : bool, optional
82
+ Enables user review of data visualization instructions. Defaults to False.
83
+ bypass_recommended_steps : bool, optional
84
+ If True, skips the default recommended visualization steps. Defaults to False.
85
+ bypass_explain_code : bool, optional
86
+ If True, skips the step that provides code explanations. Defaults to False.
87
+
88
+ Methods
89
+ -------
90
+ update_params(**kwargs)
91
+ Updates the agent's parameters and rebuilds the compiled state graph.
92
+ ainvoke_agent(data_raw: pd.DataFrame, user_instructions: str=None, max_retries=3, retry_count=0, **kwargs)
93
+ Asynchronously generates a visualization based on user instructions.
94
+ invoke_agent(data_raw: pd.DataFrame, user_instructions: str=None, max_retries=3, retry_count=0, **kwargs)
95
+ Synchronously generates a visualization based on user instructions.
96
+ explain_visualization_steps()
97
+ Returns an explanation of the visualization steps performed by the agent.
98
+ get_log_summary()
99
+ Retrieves a summary of logged operations if logging is enabled.
100
+ get_plotly_graph()
101
+ Retrieves the Plotly graph (as a dictionary) produced by the agent.
102
+ get_data_raw()
103
+ Retrieves the raw dataset as a pandas DataFrame (based on the last response).
104
+ get_data_visualization_function()
105
+ Retrieves the generated Python function used for data visualization.
106
+ get_recommended_visualization_steps()
107
+ Retrieves the agent's recommended visualization steps.
108
+ get_response()
109
+ Returns the response from the agent as a dictionary.
110
+ show()
111
+ Displays the agent's mermaid diagram.
112
+
113
+ Examples
114
+ --------
115
+ ```python
116
+ import pandas as pd
117
+ from langchain_openai import ChatOpenAI
118
+ from ai_data_science_team.agents import DataVisualizationAgent
119
+
120
+ llm = ChatOpenAI(model="gpt-4o-mini")
121
+
122
+ data_visualization_agent = DataVisualizationAgent(
123
+ model=llm,
124
+ n_samples=30,
125
+ log=True,
126
+ log_path="logs",
127
+ human_in_the_loop=True
128
+ )
129
+
130
+ df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")
131
+
132
+ data_visualization_agent.invoke_agent(
133
+ user_instructions="Generate a scatter plot of age vs. total charges with a trend line.",
134
+ data_raw=df,
135
+ max_retries=3,
136
+ retry_count=0
137
+ )
138
+
139
+ plotly_graph_dict = data_visualization_agent.get_plotly_graph()
140
+ # You can render plotly_graph_dict with plotly.io.from_json or
141
+ # something similar in a Jupyter Notebook.
142
+
143
+ response = data_visualization_agent.get_response()
144
+ ```
145
+
146
+ Returns
147
+ --------
148
+ DataVisualizationAgent : langchain.graphs.CompiledStateGraph
149
+ A data visualization agent implemented as a compiled state graph.
150
+ """
151
+
152
+ def __init__(
153
+ self,
154
+ model,
155
+ n_samples=30,
156
+ log=False,
157
+ log_path=None,
158
+ file_name="data_visualization.py",
159
+ function_name="data_visualization",
160
+ overwrite=True,
161
+ human_in_the_loop=False,
162
+ bypass_recommended_steps=False,
163
+ bypass_explain_code=False
164
+ ):
165
+ self._params = {
166
+ "model": model,
167
+ "n_samples": n_samples,
168
+ "log": log,
169
+ "log_path": log_path,
170
+ "file_name": file_name,
171
+ "function_name": function_name,
172
+ "overwrite": overwrite,
173
+ "human_in_the_loop": human_in_the_loop,
174
+ "bypass_recommended_steps": bypass_recommended_steps,
175
+ "bypass_explain_code": bypass_explain_code,
176
+ }
177
+ self._compiled_graph = self._make_compiled_graph()
178
+ self.response = None
179
+
180
+ def _make_compiled_graph(self):
181
+ """
182
+ Create the compiled graph for the data visualization agent.
183
+ Running this method will reset the response to None.
184
+ """
185
+ self.response = None
186
+ return make_data_visualization_agent(**self._params)
187
+
188
+ def update_params(self, **kwargs):
189
+ """
190
+ Updates the agent's parameters and rebuilds the compiled graph.
191
+ """
192
+ # Update parameters
193
+ for k, v in kwargs.items():
194
+ self._params[k] = v
195
+ # Rebuild the compiled graph
196
+ self._compiled_graph = self._make_compiled_graph()
197
+
198
+ def ainvoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
199
+ """
200
+ Asynchronously invokes the agent to generate a visualization.
201
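+ The value returned by ainvoke() is stored in the 'response' attribute. NOTE(review): this method is a plain `def` and does not await ainvoke(), so 'response' presumably holds an un-awaited coroutine rather than the result dictionary - confirm and consider `async def` + `await`.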
202
+
203
+ Parameters
204
+ ----------
205
+ data_raw : pd.DataFrame
206
+ The raw dataset to be visualized.
207
+ user_instructions : str
208
+ Instructions for data visualization.
209
+ max_retries : int
210
+ Maximum retry attempts.
211
+ retry_count : int
212
+ Current retry attempt count.
213
+ **kwargs : dict
214
+ Additional keyword arguments passed to ainvoke().
215
+
216
+ Returns
217
+ -------
218
+ None
219
+ """
220
+ response = self._compiled_graph.ainvoke({
221
+ "user_instructions": user_instructions,
222
+ "data_raw": data_raw.to_dict(),
223
+ "max_retries": max_retries,
224
+ "retry_count": retry_count,
225
+ }, **kwargs)
226
+ self.response = response
227
+ return None
228
+
229
+ def invoke_agent(self, data_raw: pd.DataFrame, user_instructions: str=None, max_retries:int=3, retry_count:int=0, **kwargs):
230
+ """
231
+ Synchronously invokes the agent to generate a visualization.
232
+ The response is stored in the 'response' attribute.
233
+
234
+ Parameters
235
+ ----------
236
+ data_raw : pd.DataFrame
237
+ The raw dataset to be visualized.
238
+ user_instructions : str
239
+ Instructions for data visualization agent.
240
+ max_retries : int
241
+ Maximum retry attempts.
242
+ retry_count : int
243
+ Current retry attempt count.
244
+ **kwargs : dict
245
+ Additional keyword arguments passed to invoke().
246
+
247
+ Returns
248
+ -------
249
+ None
250
+ """
251
+ response = self._compiled_graph.invoke({
252
+ "user_instructions": user_instructions,
253
+ "data_raw": data_raw.to_dict(),
254
+ "max_retries": max_retries,
255
+ "retry_count": retry_count,
256
+ }, **kwargs)
257
+ self.response = response
258
+ return None
259
+
260
+ def explain_visualization_steps(self):
261
+ """
262
+ Provides an explanation of the visualization steps performed by the agent.
263
+
264
+ Returns
265
+ -------
266
+ list
267
+ The "messages" entry of the last response, or an empty list if no response (or no messages) is available.
268
+ """
269
+ if self.response:
270
+ return self.response.get("messages", [])
271
+ return []
272
+
273
+ def get_log_summary(self, markdown=False):
274
+ """
275
+ Returns a short summary containing the log path of the generated visualization function, if one was recorded in the last response.
276
+
277
+ Parameters
278
+ ----------
279
+ markdown : bool, optional
280
+ If True, returns Markdown-formatted output.
281
+
282
+ Returns
283
+ -------
284
+ str, IPython.display.Markdown, or None
285
+ Summary of logs or None if no logs are available.
286
+ """
287
+ if self.response and self.response.get('data_visualization_function_path'):
288
+ log_details = f"Log Path: {self.response.get('data_visualization_function_path')}"
289
+ if markdown:
290
+ return Markdown(log_details)
291
+ else:
292
+ return log_details
293
+ return None
294
+
295
+ def get_plotly_graph(self):
296
+ """
297
+ Retrieves the Plotly graph (in dictionary form) produced by the agent.
298
+
299
+ Returns
300
+ -------
301
+ object or None
301
+ The value returned by plotly_from_dict() for the stored "plotly_graph" entry (presumably a Plotly figure - confirm) if a response exists, otherwise None.
303
+ """
304
+ if self.response:
305
+ return plotly_from_dict(self.response.get("plotly_graph", None))
306
+ return None
307
+
308
+ def get_data_raw(self):
309
+ """
310
+ Retrieves the raw dataset used in the last invocation.
311
+
312
+ Returns
313
+ -------
314
+ pd.DataFrame or None
315
+ The raw dataset as a DataFrame if available, otherwise None.
316
+ """
317
+ if self.response and self.response.get("data_raw"):
318
+ return pd.DataFrame(self.response.get("data_raw"))
319
+ return None
320
+
321
+ def get_data_visualization_function(self, markdown=False):
322
+ """
323
+ Retrieves the generated Python function used for data visualization.
324
+
325
+ Parameters
326
+ ----------
327
+ markdown : bool, optional
328
+ If True, returns the function in Markdown code block format.
329
+
330
+ Returns
331
+ -------
332
+ str or None
333
+ The Python function code as a string if available, otherwise None.
334
+ """
335
+ if self.response:
336
+ func_code = self.response.get("data_visualization_function", "")
337
+ if markdown:
338
+ return Markdown(f"```python\n{func_code}\n```")
339
+ return func_code
340
+ return None
341
+
342
+ def get_recommended_visualization_steps(self, markdown=False):
343
+ """
344
+ Retrieves the agent's recommended visualization steps.
345
+
346
+ Parameters
347
+ ----------
348
+ markdown : bool, optional
349
+ If True, returns the steps in Markdown format.
350
+
351
+ Returns
352
+ -------
353
+ str or None
354
+ The recommended steps if available, otherwise None.
355
+ """
356
+ if self.response:
357
+ steps = self.response.get("recommended_steps", "")
358
+ if markdown:
359
+ return Markdown(steps)
360
+ return steps
361
+ return None
362
+
363
+ def get_response(self):
364
+ """
365
+ Returns the agent's full response dictionary.
366
+
367
+ Returns
368
+ -------
369
+ dict or None
370
+ The response dictionary if available, otherwise None.
371
+ """
372
+ return self.response
373
+
374
+ def show(self):
375
+ """
376
+ Displays the agent's mermaid diagram for visual inspection of the compiled graph.
377
+ """
378
+ return self._compiled_graph.show()
379
+
380
+
39
381
  # Agent
40
382
 
41
383
  def make_data_visualization_agent(
@@ -44,14 +386,85 @@ def make_data_visualization_agent(
44
386
  log=False,
45
387
  log_path=None,
46
388
  file_name="data_visualization.py",
47
- overwrite = True,
389
+ function_name="data_visualization",
390
+ overwrite=True,
48
391
  human_in_the_loop=False,
49
392
  bypass_recommended_steps=False,
50
393
  bypass_explain_code=False
51
394
  ):
395
+ """
396
+ Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
397
+ default visualization steps. The agent generates a Python function to produce the visualization, executes it,
398
+ and logs the process, including code and errors. It is designed to facilitate reproducible and customizable
399
+ data visualization workflows.
400
+
401
+ The agent can perform the following default visualization steps unless instructed otherwise:
402
+ - Generating a recommended chart type (bar, scatter, line, etc.)
403
+ - Creating user-friendly titles and axis labels
404
+ - Applying consistent styling (template, font sizes, color themes)
405
+ - Handling theme details (white background, base font size, line size, etc.)
406
+
407
+ User instructions can modify, add, or remove any of these steps to tailor the visualization process.
408
+
409
+ Parameters
410
+ ----------
411
+ model : langchain.llms.base.LLM
412
+ The language model used to generate the data visualization function.
413
+ n_samples : int, optional
414
+ Number of samples used when summarizing the dataset for chart instructions. Defaults to 30.
415
+ log : bool, optional
416
+ Whether to log the generated code and errors. Defaults to False.
417
+ log_path : str, optional
418
+ Directory path for storing log files. Defaults to None.
419
+ file_name : str, optional
420
+ Name of the file for saving the generated response. Defaults to "data_visualization.py".
421
+ function_name : str, optional
422
+ Name of the function for data visualization. Defaults to "data_visualization".
423
+ overwrite : bool, optional
424
+ Whether to overwrite the log file if it exists. If False, a unique file name is created. Defaults to True.
425
+ human_in_the_loop : bool, optional
426
+ Enables user review of data visualization instructions. Defaults to False.
427
+ bypass_recommended_steps : bool, optional
428
+ If True, skips the default recommended visualization steps. Defaults to False.
429
+ bypass_explain_code : bool, optional
430
+ If True, skips the step that provides code explanations. Defaults to False.
431
+
432
+ Examples
433
+ --------
434
+ ```python
435
+ import pandas as pd
436
+ from langchain_openai import ChatOpenAI
437
+ from ai_data_science_team.agents import make_data_visualization_agent
438
+
439
+ llm = ChatOpenAI(model="gpt-4o-mini")
440
+
441
+ data_visualization_agent = make_data_visualization_agent(llm)
442
+
443
+ df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")
444
+
445
+ response = data_visualization_agent.invoke({
446
+ "user_instructions": "Generate a scatter plot of tenure vs. total charges with a trend line.",
447
+ "data_raw": df.to_dict(),
448
+ "max_retries": 3,
449
+ "retry_count": 0
450
+ })
451
+
452
+ pd.DataFrame(response['plotly_graph'])
453
+ ```
454
+
455
+ Returns
456
+ -------
457
+ app : langchain.graphs.CompiledStateGraph
458
+ The data visualization agent as a state graph.
459
+ """
52
460
 
53
461
  llm = model
54
462
 
463
+ # Human in the loop requires recommended steps
464
+ if bypass_recommended_steps and human_in_the_loop:
465
+ bypass_recommended_steps = False
466
+ print("Bypass recommended steps set to False to enable human in the loop.")
467
+
55
468
  # Setup Log Directory
56
469
  if log:
57
470
  if log_path is None:
@@ -70,6 +483,7 @@ def make_data_visualization_agent(
70
483
  all_datasets_summary: str
71
484
  data_visualization_function: str
72
485
  data_visualization_function_path: str
486
+ data_visualization_function_file_name: str
73
487
  data_visualization_function_name: str
74
488
  data_visualization_error: str
75
489
  max_retries: int
@@ -140,7 +554,7 @@ def make_data_visualization_agent(
140
554
  })
141
555
 
142
556
  return {
143
- "recommended_steps": "\n\n# Recommended Data Cleaning Steps:\n" + recommended_steps.content.strip(),
557
+ "recommended_steps": format_recommended_steps(recommended_steps.content.strip(), heading="# Recommended Data Cleaning Steps:"),
144
558
  "all_datasets_summary": all_datasets_summary_str
145
559
  }
146
560
 
@@ -169,7 +583,7 @@ def make_data_visualization_agent(
169
583
  template="""
170
584
  You are a chart generator agent that is an expert in generating plotly charts. You must use plotly or plotly.express to produce plots.
171
585
 
172
- Your job is to produce python code to generate visualizations.
586
+ Your job is to produce python code to generate visualizations with a function named {function_name}.
173
587
 
174
588
  You will take instructions from a Chart Instructor and generate a plotly chart from the data provided.
175
589
 
@@ -181,13 +595,13 @@ def make_data_visualization_agent(
181
595
 
182
596
  RETURN:
183
597
 
184
- Return Python code in ```python ``` format with a single function definition, data_visualization(data_raw), that includes all imports inside the function.
598
+ Return Python code in ```python ``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
185
599
 
186
600
  Return the plotly chart as a dictionary.
187
601
 
188
602
  Return code to provide the data visualization function:
189
603
 
190
- def data_visualization(data_raw):
604
+ def {function_name}(data_raw):
191
605
  import pandas as pd
192
606
  import numpy as np
193
607
  import json
@@ -206,14 +620,15 @@ def make_data_visualization_agent(
206
620
  2. Do not include unrelated user instructions that are not related to the chart generation.
207
621
 
208
622
  """,
209
- input_variables=["chart_generator_instructions", "all_datasets_summary"]
623
+ input_variables=["chart_generator_instructions", "all_datasets_summary", "function_name"]
210
624
  )
211
-
625
+
212
626
  data_visualization_agent = prompt_template | llm | PythonOutputParser()
213
627
 
214
628
  response = data_visualization_agent.invoke({
215
629
  "chart_generator_instructions": chart_generator_instructions,
216
- "all_datasets_summary": all_datasets_summary_str
630
+ "all_datasets_summary": all_datasets_summary_str,
631
+ "function_name": function_name
217
632
  })
218
633
 
219
634
  response = relocate_imports_inside_function(response)
@@ -231,19 +646,37 @@ def make_data_visualization_agent(
231
646
  return {
232
647
  "data_visualization_function": response,
233
648
  "data_visualization_function_path": file_path,
234
- "data_visualization_function_name": file_name_2,
649
+ "data_visualization_function_file_name": file_name_2,
650
+ "data_visualization_function_name": function_name,
235
651
  "all_datasets_summary": all_datasets_summary_str
236
652
  }
237
-
238
- def human_review(state: GraphState) -> Command[Literal["chart_instructor", "chart_generator"]]:
239
- return node_func_human_review(
240
- state=state,
241
- prompt_text="Is the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}",
242
- yes_goto="chart_generator",
243
- no_goto="chart_instructor",
244
- user_instructions_key="user_instructions",
245
- recommended_steps_key="recommended_steps"
246
- )
653
+
654
+ # Human Review
655
+
656
+ prompt_text_human_review = "Are the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}"
657
+
658
+ if not bypass_explain_code:
659
+ def human_review(state: GraphState) -> Command[Literal["chart_instructor", "explain_data_visualization_code"]]:
660
+ return node_func_human_review(
661
+ state=state,
662
+ prompt_text=prompt_text_human_review,
663
+ yes_goto= 'explain_data_visualization_code',
664
+ no_goto="chart_instructor",
665
+ user_instructions_key="user_instructions",
666
+ recommended_steps_key="recommended_steps",
667
+ code_snippet_key="data_visualization_function",
668
+ )
669
+ else:
670
+ def human_review(state: GraphState) -> Command[Literal["chart_instructor", "__end__"]]:
671
+ return node_func_human_review(
672
+ state=state,
673
+ prompt_text=prompt_text_human_review,
674
+ yes_goto= '__end__',
675
+ no_goto="chart_instructor",
676
+ user_instructions_key="user_instructions",
677
+ recommended_steps_key="recommended_steps",
678
+ code_snippet_key="data_visualization_function",
679
+ )
247
680
 
248
681
 
249
682
  def execute_data_visualization_code(state):
@@ -253,7 +686,7 @@ def make_data_visualization_agent(
253
686
  result_key="plotly_graph",
254
687
  error_key="data_visualization_error",
255
688
  code_snippet_key="data_visualization_function",
256
- agent_function_name="data_visualization",
689
+ agent_function_name=state.get("data_visualization_function_name"),
257
690
  pre_processing=lambda data: pd.DataFrame.from_dict(data),
258
691
  # post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
259
692
  error_message_prefix="An error occurred during data visualization: "
@@ -261,11 +694,11 @@ def make_data_visualization_agent(
261
694
 
262
695
  def fix_data_visualization_code(state: GraphState):
263
696
  prompt = """
264
- You are a Data Visualization Agent. Your job is to create a data_visualization() function that can be run on the data provided. The function is currently broken and needs to be fixed.
697
+ You are a Data Visualization Agent. Your job is to create a {function_name}() function that can be run on the data provided. The function is currently broken and needs to be fixed.
265
698
 
266
- Make sure to only return the function definition for data_visualization().
699
+ Make sure to only return the function definition for {function_name}().
267
700
 
268
- Return Python code in ```python``` format with a single function definition, data_visualization(data_raw), that includes all imports inside the function.
701
+ Return Python code in ```python``` format with a single function definition, {function_name}(data_raw), that includes all imports inside the function.
269
702
 
270
703
  This is the broken code (please fix):
271
704
  {code_snippet}
@@ -283,6 +716,7 @@ def make_data_visualization_agent(
283
716
  agent_name=AGENT_NAME,
284
717
  log=log,
285
718
  file_path=state.get("data_visualization_function_path"),
719
+ function_name=state.get("data_visualization_function_name"),
286
720
  )
287
721
 
288
722
  def explain_data_visualization_code(state: GraphState):
@@ -328,4 +762,3 @@ def make_data_visualization_agent(
328
762
  )
329
763
 
330
764
  return app
331
-