ai-data-science-team 0.0.0.9005__py3-none-any.whl → 0.0.0.9007__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,331 @@
1
+ # BUSINESS SCIENCE UNIVERSITY
2
+ # AI DATA SCIENCE TEAM
3
+ # ***
4
+ # * Agents: Data Visualization Agent
5
+
6
+
7
+
8
+ # Libraries
9
+ from typing import TypedDict, Annotated, Sequence, Literal
10
+ import operator
11
+
12
+ from langchain.prompts import PromptTemplate
13
+ from langchain_core.output_parsers import StrOutputParser
14
+ from langchain_core.messages import BaseMessage
15
+
16
+ from langgraph.types import Command
17
+ from langgraph.checkpoint.memory import MemorySaver
18
+
19
+ import os
20
+ import io
21
+ import pandas as pd
22
+
23
+ from ai_data_science_team.templates import(
24
+ node_func_execute_agent_code_on_data,
25
+ node_func_human_review,
26
+ node_func_fix_agent_code,
27
+ node_func_explain_agent_code,
28
+ create_coding_agent_graph
29
+ )
30
+ from ai_data_science_team.tools.parsers import PythonOutputParser
31
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
32
+ from ai_data_science_team.tools.metadata import get_dataframe_summary
33
+ from ai_data_science_team.tools.logging import log_ai_function
34
+
35
# Setup
# Identifier used in console headers and in the comment added to generated code.
AGENT_NAME: str = "data_visualization_agent"
# Default log directory, resolved against the working directory at import time.
LOG_PATH: str = os.path.join(os.getcwd(), "logs/")
38
+
39
+ # Agent
40
+
41
def make_data_visualization_agent(
    model,
    n_samples=30,
    log=False,
    log_path=None,
    file_name="data_visualization.py",
    overwrite=True,
    human_in_the_loop=False,
    bypass_recommended_steps=False,
    bypass_explain_code=False
):
    """
    Creates a data visualization agent that writes and runs plotly code for a dataset.

    The agent asks the LLM for chart instructions (chart type, titles, theme), then asks it
    to write a ``data_visualization(data_raw)`` function, executes that function on the data,
    and routes failures to a code-fixing node.

    Parameters
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the chart instructions and the code.
    n_samples : int, optional
        The number of samples to show in the data summary. Defaults to 30.
        If you get an error due to maximum tokens, try reducing this number.
    log : bool, optional
        Whether or not to log the generated code. Defaults to False.
    log_path : str, optional
        Directory where log files are stored. Defaults to "logs/" under the current
        working directory. The directory is created when ``log`` is True.
    file_name : str, optional
        The name of the file to save the generated code to. Defaults to "data_visualization.py".
    overwrite : bool, optional
        Whether or not to overwrite the log file if it already exists. Defaults to True.
    human_in_the_loop : bool, optional
        If True, adds a human review step for the chart instructions (with an in-memory
        checkpointer). Defaults to False.
    bypass_recommended_steps : bool, optional
        Skip the chart-instructor step; the user instructions are passed directly to the
        chart generator. Defaults to False.
    bypass_explain_code : bool, optional
        Skip the code explanation step. Defaults to False.

    Returns
    -------
    app : langchain.graphs.CompiledStateGraph
        The data visualization agent as built by ``create_coding_agent_graph``.
    """
    llm = model

    # Setup Log Directory
    if log:
        if log_path is None:
            log_path = LOG_PATH
        # exist_ok avoids the race between an existence check and the creation call.
        os.makedirs(log_path, exist_ok=True)

    # Define GraphState for the router
    class GraphState(TypedDict):
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        user_instructions_processed: str
        recommended_steps: str
        data_raw: dict
        plotly_graph: dict
        all_datasets_summary: str
        data_visualization_function: str
        data_visualization_function_path: str
        data_visualization_function_name: str
        data_visualization_error: str
        max_retries: int
        retry_count: int

    def _summarize_data(data_raw):
        """Convert ``data_raw`` to a DataFrame and return its text summary for the prompts."""
        df = pd.DataFrame.from_dict(data_raw)
        all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples, skip_stats=False)
        return "\n\n".join(all_datasets_summary)

    def chart_instructor(state: GraphState):
        """Ask the LLM for chart instructions based on the user request and a data summary."""
        print(format_agent_name(AGENT_NAME))
        print(" * CREATE CHART GENERATOR INSTRUCTIONS")

        recommend_steps_prompt = PromptTemplate(
            template="""
            You are a supervisor that is an expert in providing instructions to a chart generator agent for plotting.

            You will take a question that a user has and the data that was generated to answer the question, and create instructions to create a chart from the data that will be passed to a chart generator agent.

            USER QUESTION / INSTRUCTIONS:
            {user_instructions}

            Previously Recommended Instructions (if any):
            {recommended_steps}

            DATA:
            {all_datasets_summary}

            Formulate chart generator instructions by informing the chart generator of what type of plotly plot to use (e.g. bar, line, scatter, etc) to best represent the data.

            Come up with an informative title from the user's question and data provided. Also provide X and Y axis titles.

            Instruct the chart generator to use the following theme colors, sizes, etc:

            - Start with the "plotly_white" template
            - Use a white background
            - Use this color for bars and lines:
                'blue': '#3381ff',
            - Base Font Size: 8.8 (Used for x and y axes tickfont, any annotations, hovertips)
            - Title Font Size: 13.2
            - Line Size: 0.65 (specify these within the xaxis and yaxis dictionaries)
            - Add smoothers or trendlines to scatter plots unless not desired by the user
            - Do not use color_discrete_map (this will result in an error)
            - Hover tip size: 8.8

            Return your instructions in the following format:
            CHART GENERATOR INSTRUCTIONS:
            FILL IN THE INSTRUCTIONS HERE

            Avoid these:
            1. Do not include steps to save files.
            2. Do not include unrelated user instructions that are not related to the chart generation.
            """,
            input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
        )

        all_datasets_summary_str = _summarize_data(state.get("data_raw"))

        # Distinct name so the chain does not shadow this node function.
        instructor_chain = recommend_steps_prompt | llm

        recommended_steps = instructor_chain.invoke({
            "user_instructions": state.get("user_instructions"),
            "recommended_steps": state.get("recommended_steps"),
            "all_datasets_summary": all_datasets_summary_str
        })

        return {
            # Label fixed: this agent produces visualization steps, not data cleaning steps.
            "recommended_steps": "\n\n# Recommended Data Visualization Steps:\n" + recommended_steps.content.strip(),
            "all_datasets_summary": all_datasets_summary_str
        }

    def chart_generator(state: GraphState):
        """Ask the LLM for a ``data_visualization(data_raw)`` function and optionally log it."""
        if bypass_recommended_steps:
            # chart_instructor normally prints the header and builds the summary;
            # when it is bypassed, both happen here from the raw user instructions.
            print(format_agent_name(AGENT_NAME))
            all_datasets_summary_str = _summarize_data(state.get("data_raw"))
            chart_generator_instructions = state.get("user_instructions")
        else:
            all_datasets_summary_str = state.get("all_datasets_summary")
            chart_generator_instructions = state.get("recommended_steps")

        print(" * CREATE DATA VISUALIZATION CODE")

        prompt_template = PromptTemplate(
            template="""
            You are a chart generator agent that is an expert in generating plotly charts. You must use plotly or plotly.express to produce plots.

            Your job is to produce python code to generate visualizations.

            You will take instructions from a Chart Instructor and generate a plotly chart from the data provided.

            CHART INSTRUCTIONS:
            {chart_generator_instructions}

            DATA:
            {all_datasets_summary}

            RETURN:

            Return Python code in ```python ``` format with a single function definition, data_visualization(data_raw), that includes all imports inside the function.

            Return the plotly chart as a dictionary.

            Return code to provide the data visualization function:

            def data_visualization(data_raw):
                import pandas as pd
                import numpy as np
                import json
                import plotly.graph_objects as go
                import plotly.io as pio

                ...

                fig_json = pio.to_json(fig)
                fig_dict = json.loads(fig_json)

                return fig_dict

            Avoid these:
            1. Do not include steps to save files.
            2. Do not include unrelated user instructions that are not related to the chart generation.

            """,
            input_variables=["chart_generator_instructions", "all_datasets_summary"]
        )

        data_visualization_agent = prompt_template | llm | PythonOutputParser()

        response = data_visualization_agent.invoke({
            "chart_generator_instructions": chart_generator_instructions,
            "all_datasets_summary": all_datasets_summary_str
        })

        response = relocate_imports_inside_function(response)
        response = add_comments_to_top(response, agent_name=AGENT_NAME)

        # For logging: store the code generated
        file_path, logged_file_name = log_ai_function(
            response=response,
            file_name=file_name,
            log=log,
            log_path=log_path,
            overwrite=overwrite
        )

        return {
            "data_visualization_function": response,
            "data_visualization_function_path": file_path,
            "data_visualization_function_name": logged_file_name,
            "all_datasets_summary": all_datasets_summary_str
        }

    def human_review(state: GraphState) -> Command[Literal["chart_instructor", "chart_generator"]]:
        """Let a human approve the chart instructions or send modifications back to the instructor."""
        return node_func_human_review(
            state=state,
            prompt_text="Is the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}",
            yes_goto="chart_generator",
            no_goto="chart_instructor",
            user_instructions_key="user_instructions",
            recommended_steps_key="recommended_steps"
        )

    def execute_data_visualization_code(state: GraphState):
        """Run the generated function on ``data_raw`` (as a DataFrame) and store the plot dict."""
        return node_func_execute_agent_code_on_data(
            state=state,
            data_key="data_raw",
            result_key="plotly_graph",
            error_key="data_visualization_error",
            code_snippet_key="data_visualization_function",
            agent_function_name="data_visualization",
            pre_processing=lambda data: pd.DataFrame.from_dict(data),
            error_message_prefix="An error occurred during data visualization: "
        )

    def fix_data_visualization_code(state: GraphState):
        """Ask the LLM to repair the generated function using the last recorded error."""
        prompt = """
        You are a Data Visualization Agent. Your job is to create a data_visualization() function that can be run on the data provided. The function is currently broken and needs to be fixed.

        Make sure to only return the function definition for data_visualization().

        Return Python code in ```python``` format with a single function definition, data_visualization(data_raw), that includes all imports inside the function.

        This is the broken code (please fix):
        {code_snippet}

        Last Known Error:
        {error}
        """

        return node_func_fix_agent_code(
            state=state,
            code_snippet_key="data_visualization_function",
            error_key="data_visualization_error",
            llm=llm,
            prompt_template=prompt,
            agent_name=AGENT_NAME,
            log=log,
            file_path=state.get("data_visualization_function_path"),
        )

    def explain_data_visualization_code(state: GraphState):
        """Ask the LLM for a short explanation of the generated function, appended to messages."""
        return node_func_explain_agent_code(
            state=state,
            code_snippet_key="data_visualization_function",
            result_key="messages",
            error_key="data_visualization_error",
            llm=llm,
            role=AGENT_NAME,
            explanation_prompt_template="""
            Explain the data visualization steps that the data visualization agent performed in this function.
            Keep the summary succinct and to the point.\n\n# Data Visualization Agent:\n\n{code}
            """,
            success_prefix="# Data Visualization Agent:\n\n ",
            error_message="The Data Visualization Agent encountered an error during data visualization. No explanation could be provided."
        )

    # Define the graph
    node_functions = {
        "chart_instructor": chart_instructor,
        "human_review": human_review,
        "chart_generator": chart_generator,
        "execute_data_visualization_code": execute_data_visualization_code,
        "fix_data_visualization_code": fix_data_visualization_code,
        "explain_data_visualization_code": explain_data_visualization_code
    }

    app = create_coding_agent_graph(
        GraphState=GraphState,
        node_functions=node_functions,
        recommended_steps_node_name="chart_instructor",
        create_code_node_name="chart_generator",
        execute_code_node_name="execute_data_visualization_code",
        fix_code_node_name="fix_data_visualization_code",
        explain_code_node_name="explain_data_visualization_code",
        error_key="data_visualization_error",
        human_in_the_loop=human_in_the_loop,
        human_review_node_name="human_review",
        # A checkpointer is only needed when the graph can pause for human review.
        checkpointer=MemorySaver() if human_in_the_loop else None,
        bypass_recommended_steps=bypass_recommended_steps,
        bypass_explain_code=bypass_explain_code,
    )

    return app
331
+
@@ -15,7 +15,7 @@ from langchain_core.messages import BaseMessage
15
15
  from langgraph.types import Command
16
16
  from langgraph.checkpoint.memory import MemorySaver
17
17
 
18
- from ai_data_science_team.templates.agent_templates import(
18
+ from ai_data_science_team.templates import(
19
19
  node_func_execute_agent_code_on_data,
20
20
  node_func_human_review,
21
21
  node_func_fix_agent_code,
@@ -23,15 +23,25 @@ from ai_data_science_team.templates.agent_templates import(
23
23
  create_coding_agent_graph
24
24
  )
25
25
  from ai_data_science_team.tools.parsers import PythonOutputParser
26
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
27
- from ai_data_science_team.tools.data_analysis import summarize_dataframes
26
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
27
+ from ai_data_science_team.tools.metadata import get_dataframe_summary
28
28
  from ai_data_science_team.tools.logging import log_ai_function
29
29
 
30
30
  # Setup Logging Path
31
31
  AGENT_NAME = "data_wrangling_agent"
32
32
  LOG_PATH = os.path.join(os.getcwd(), "logs/")
33
33
 
34
- def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True, human_in_the_loop=False):
34
+ def make_data_wrangling_agent(
35
+ model,
36
+ n_samples=30,
37
+ log=False,
38
+ log_path=None,
39
+ file_name="data_wrangler.py",
40
+ overwrite = True,
41
+ human_in_the_loop=False,
42
+ bypass_recommended_steps=False,
43
+ bypass_explain_code=False
44
+ ):
35
45
  """
36
46
  Creates a data wrangling agent that can be run on one or more datasets. The agent can be
37
47
  instructed to perform common data wrangling steps such as:
@@ -52,17 +62,27 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
52
62
  ----------
53
63
  model : langchain.llms.base.LLM
54
64
  The language model to use to generate code.
65
+ n_samples : int, optional
66
+ The number of samples to show in the data summary. Defaults to 30.
67
+ If you get an error due to maximum tokens, try reducing this number.
68
+ > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
55
69
  log : bool, optional
56
70
  Whether or not to log the code generated and any errors that occur.
57
71
  Defaults to False.
58
72
  log_path : str, optional
59
73
  The path to the directory where the log files should be stored. Defaults to "logs/".
74
+ file_name : str, optional
75
+ The name of the file to save the response to. Defaults to "data_wrangler.py".
60
76
  overwrite : bool, optional
61
77
  Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
62
78
  Defaults to True.
63
79
  human_in_the_loop : bool, optional
64
80
  Whether or not to use human in the loop. If True, adds an interrupt and human-in-the-loop
65
81
  step that asks the user to review the data wrangling instructions. Defaults to False.
82
+ bypass_recommended_steps : bool, optional
83
+ Bypass the recommendation step, by default False
84
+ bypass_explain_code : bool, optional
85
+ Bypass the code explanation step, by default False.
66
86
 
67
87
  Example
68
88
  -------
@@ -90,7 +110,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
90
110
 
91
111
  Returns
92
112
  -------
93
- app : langchain.graphs.StateGraph
113
+ app : langchain.graphs.CompiledStateGraph
94
114
  The data wrangling agent as a state graph.
95
115
  """
96
116
  llm = model
@@ -118,7 +138,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
118
138
  retry_count: int
119
139
 
120
140
  def recommend_wrangling_steps(state: GraphState):
121
- print("---DATA WRANGLING AGENT----")
141
+ print(format_agent_name(AGENT_NAME))
122
142
  print(" * RECOMMEND WRANGLING STEPS")
123
143
 
124
144
  data_raw = state.get("data_raw")
@@ -139,7 +159,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
139
159
 
140
160
  # Create a summary for all datasets
141
161
  # We'll include a short sample and info for each dataset
142
- all_datasets_summary = summarize_dataframes(dataframes)
162
+ all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
143
163
 
144
164
  # Join all datasets summaries into one big text block
145
165
  all_datasets_summary_str = "\n\n".join(all_datasets_summary)
@@ -172,6 +192,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
172
192
 
173
193
  Avoid these:
174
194
  1. Do not include steps to save files.
195
+ 2. Do not include unrelated user instructions that are not related to the data wrangling.
175
196
  """,
176
197
  input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
177
198
  )
@@ -190,6 +211,35 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
190
211
 
191
212
 
192
213
  def create_data_wrangler_code(state: GraphState):
214
+ if bypass_recommended_steps:
215
+ print(format_agent_name(AGENT_NAME))
216
+
217
+ data_raw = state.get("data_raw")
218
+
219
+ if isinstance(data_raw, dict):
220
+ # Single dataset scenario
221
+ primary_dataset_name = "main"
222
+ datasets = {primary_dataset_name: data_raw}
223
+ elif isinstance(data_raw, list) and all(isinstance(item, dict) for item in data_raw):
224
+ # Multiple datasets scenario
225
+ datasets = {f"dataset_{i}": d for i, d in enumerate(data_raw, start=1)}
226
+ primary_dataset_name = "dataset_1"
227
+ else:
228
+ raise ValueError("data_raw must be a dict or a list of dicts.")
229
+
230
+ # Convert all datasets to DataFrames for inspection
231
+ dataframes = {name: pd.DataFrame.from_dict(d) for name, d in datasets.items()}
232
+
233
+ # Create a summary for all datasets
234
+ # We'll include a short sample and info for each dataset
235
+ all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
236
+
237
+ # Join all datasets summaries into one big text block
238
+ all_datasets_summary_str = "\n\n".join(all_datasets_summary)
239
+
240
+ else:
241
+ all_datasets_summary_str = state.get("all_datasets_summary")
242
+
193
243
  print(" * CREATE DATA WRANGLER CODE")
194
244
 
195
245
  data_wrangling_prompt = PromptTemplate(
@@ -236,16 +286,16 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
236
286
 
237
287
  response = data_wrangling_agent.invoke({
238
288
  "recommended_steps": state.get("recommended_steps"),
239
- "all_datasets_summary": state.get("all_datasets_summary")
289
+ "all_datasets_summary": all_datasets_summary_str
240
290
  })
241
291
 
242
292
  response = relocate_imports_inside_function(response)
243
293
  response = add_comments_to_top(response, agent_name=AGENT_NAME)
244
294
 
245
295
  # For logging: store the code generated
246
- file_path, file_name = log_ai_function(
296
+ file_path, file_name_2 = log_ai_function(
247
297
  response=response,
248
- file_name="data_wrangler.py",
298
+ file_name=file_name,
249
299
  log=log,
250
300
  log_path=log_path,
251
301
  overwrite=overwrite
@@ -254,7 +304,8 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
254
304
  return {
255
305
  "data_wrangler_function" : response,
256
306
  "data_wrangler_function_path": file_path,
257
- "data_wrangler_function_name": file_name
307
+ "data_wrangler_function_name": file_name_2,
308
+ "all_datasets_summary": all_datasets_summary_str
258
309
  }
259
310
 
260
311
 
@@ -269,17 +320,6 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
269
320
  )
270
321
 
271
322
  def execute_data_wrangler_code(state: GraphState):
272
-
273
- # Handle multiple datasets as lists
274
- # def pre_processing(data):
275
- # df = []
276
- # for i in range(len(data)):
277
- # df[i] = pd.DataFrame.from_dict(data[i])
278
- # return df
279
-
280
- # def post_processing(df):
281
- # return df.to_dict()
282
-
283
323
  return node_func_execute_agent_code_on_data(
284
324
  state=state,
285
325
  data_key="data_raw",
@@ -288,7 +328,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
288
328
  code_snippet_key="data_wrangler_function",
289
329
  agent_function_name="data_wrangler",
290
330
  # pre_processing=pre_processing,
291
- # post_processing=post_processing,
331
+ post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
292
332
  error_message_prefix="An error occurred during data wrangling: "
293
333
  )
294
334
 
@@ -355,7 +395,9 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
355
395
  error_key="data_wrangler_error",
356
396
  human_in_the_loop=human_in_the_loop,
357
397
  human_review_node_name="human_review",
358
- checkpointer=MemorySaver() if human_in_the_loop else None
398
+ checkpointer=MemorySaver() if human_in_the_loop else None,
399
+ bypass_recommended_steps=bypass_recommended_steps,
400
+ bypass_explain_code=bypass_explain_code,
359
401
  )
360
402
 
361
403
  return app