ai-data-science-team 0.0.0.9006__py3-none-any.whl → 0.0.0.9007__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +2 -1
- ai_data_science_team/agents/data_cleaning_agent.py +204 -19
- ai_data_science_team/agents/data_visualization_agent.py +331 -0
- ai_data_science_team/agents/data_wrangling_agent.py +56 -11
- ai_data_science_team/agents/feature_engineering_agent.py +40 -11
- ai_data_science_team/agents/sql_database_agent.py +30 -12
- ai_data_science_team/templates/__init__.py +8 -0
- ai_data_science_team/tools/metadata.py +110 -47
- ai_data_science_team/tools/regex.py +6 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9007.dist-info}/METADATA +41 -23
- ai_data_science_team-0.0.0.9007.dist-info/RECORD +21 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9007.dist-info}/WHEEL +1 -1
- ai_data_science_team-0.0.0.9006.dist-info/RECORD +0 -20
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9007.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9007.dist-info}/top_level.txt +0 -0
ai_data_science_team/_version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.0.0.
|
1
|
+
__version__ = "0.0.0.9007"
|
@@ -1,5 +1,6 @@
|
|
1
|
-
from ai_data_science_team.agents.data_cleaning_agent import make_data_cleaning_agent
|
1
|
+
from ai_data_science_team.agents.data_cleaning_agent import make_data_cleaning_agent, DataCleaningAgent
|
2
2
|
from ai_data_science_team.agents.feature_engineering_agent import make_feature_engineering_agent
|
3
3
|
from ai_data_science_team.agents.data_wrangling_agent import make_data_wrangling_agent
|
4
4
|
from ai_data_science_team.agents.sql_database_agent import make_sql_database_agent
|
5
|
+
from ai_data_science_team.agents.data_visualization_agent import make_data_visualization_agent
|
5
6
|
|
@@ -13,11 +13,13 @@ from langchain_core.messages import BaseMessage
|
|
13
13
|
from langgraph.types import Command
|
14
14
|
from langgraph.checkpoint.memory import MemorySaver
|
15
15
|
|
16
|
+
from langgraph.graph.state import CompiledStateGraph
|
17
|
+
|
16
18
|
import os
|
17
19
|
import io
|
18
20
|
import pandas as pd
|
19
21
|
|
20
|
-
from ai_data_science_team.templates
|
22
|
+
from ai_data_science_team.templates import(
|
21
23
|
node_func_execute_agent_code_on_data,
|
22
24
|
node_func_human_review,
|
23
25
|
node_func_fix_agent_code,
|
@@ -25,7 +27,7 @@ from ai_data_science_team.templates.agent_templates import(
|
|
25
27
|
create_coding_agent_graph
|
26
28
|
)
|
27
29
|
from ai_data_science_team.tools.parsers import PythonOutputParser
|
28
|
-
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
|
30
|
+
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
|
29
31
|
from ai_data_science_team.tools.metadata import get_dataframe_summary
|
30
32
|
from ai_data_science_team.tools.logging import log_ai_function
|
31
33
|
|
@@ -33,9 +35,170 @@ from ai_data_science_team.tools.logging import log_ai_function
|
|
33
35
|
AGENT_NAME = "data_cleaning_agent"
|
34
36
|
LOG_PATH = os.path.join(os.getcwd(), "logs/")
|
35
37
|
|
38
|
+
|
39
|
+
|
40
|
+
# Class
|
41
|
+
class DataCleaningAgent(CompiledStateGraph):
    """
    Object-oriented wrapper around the graph built by `make_data_cleaning_agent`.

    The wrapper stores the construction parameters, builds the compiled graph
    from them, and keeps the state returned by the most recent run in
    `self.response` so the `get_*` helpers can read from it.

    The `CompiledStateGraph` base class is kept so existing `isinstance` checks
    continue to work, but its initializer is never called. All graph behaviour
    is therefore delegated explicitly to `self._compiled_graph`.

    Parameters
    ----------
    model, n_samples, log, log_path, file_name, overwrite, human_in_the_loop,
    bypass_recommended_steps, bypass_explain_code
        Forwarded unchanged, by keyword, to `make_data_cleaning_agent`.

    Attributes
    ----------
    response : dict or None
        State returned by the last `invoke` / `ainvoke` call. Reset to `None`
        whenever the compiled graph is rebuilt.
    """

    def __init__(
        self,
        model,
        n_samples=30,
        log=False,
        log_path=None,
        file_name="data_cleaner.py",
        overwrite=True,
        human_in_the_loop=False,
        bypass_recommended_steps=False,
        bypass_explain_code=False
    ):
        self._params = {
            "model": model,
            "n_samples": n_samples,
            "log": log,
            "log_path": log_path,
            "file_name": file_name,
            "overwrite": overwrite,
            "human_in_the_loop": human_in_the_loop,
            "bypass_recommended_steps": bypass_recommended_steps,
            "bypass_explain_code": bypass_explain_code,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Build the compiled graph from the stored parameters.

        Also clears `self.response`, because a response produced by a previous
        graph no longer corresponds to the current parameters.
        """
        self.response = None
        return make_data_cleaning_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Update one or more parameters at once, then rebuild the compiled graph.
        e.g. agent.update_params(model=new_llm, n_samples=100)
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    def __getattr__(self, name: str):
        """
        Delegate attribute access to `_compiled_graph` if `name` is not
        found in this instance. This 'inherits' methods from the compiled graph.
        """
        # __getattr__ only runs after normal lookup has failed. If the wrapper's
        # own storage is what is missing (e.g. __init__ has not finished),
        # reading `self._compiled_graph` below would re-enter __getattr__ and
        # recurse without end, so fail fast with a normal AttributeError.
        if name in ("_compiled_graph", "_params"):
            raise AttributeError(name)
        return getattr(self._compiled_graph, name)

    @staticmethod
    def _build_inputs(user_instructions, data_raw, max_retries, retry_count):
        """Assemble the initial graph state shared by `invoke` and `ainvoke`."""
        return {
            "user_instructions": user_instructions,
            "data_raw": data_raw.to_dict(),
            "max_retries": max_retries,
            "retry_count": retry_count,
        }

    async def ainvoke(self, user_instructions: str, data_raw: pd.DataFrame, max_retries=3, retry_count=0):
        """
        Asynchronously cleans the provided dataset based on user instructions.

        Parameters:
            user_instructions (str): Instructions for data cleaning.
            data_raw (pd.DataFrame): The raw dataset to be cleaned.
            max_retries (int): Maximum retry attempts for cleaning.
            retry_count (int): Current retry attempt.

        Returns:
            None. The response is stored in the response attribute.
        """
        # Call the wrapped graph, not `self.ainvoke`: the latter is this very
        # method and would be re-entered with a single dict argument.
        self.response = await self._compiled_graph.ainvoke(
            self._build_inputs(user_instructions, data_raw, max_retries, retry_count)
        )
        return None

    def invoke(self, user_instructions: str, data_raw: pd.DataFrame, max_retries=3, retry_count=0):
        """
        Cleans the provided dataset based on user instructions.

        Parameters:
            user_instructions (str): Instructions for data cleaning.
            data_raw (pd.DataFrame): The raw dataset to be cleaned.
            max_retries (int): Maximum retry attempts for cleaning.
            retry_count (int): Current retry attempt.

        Returns:
            None. The response is stored in the response attribute.
        """
        # Call the wrapped graph, not `self.invoke`: the latter is this very
        # method and would be re-entered with a single dict argument.
        self.response = self._compiled_graph.invoke(
            self._build_inputs(user_instructions, data_raw, max_retries, retry_count)
        )
        return None

    def explain_cleaning_steps(self):
        """
        Returns the "messages" entry of the last response.

        Returns:
            The stored messages, or an empty list when the agent has not been
            run yet or the response has no "messages" entry.
        """
        if not self.response:
            return []
        return self.response.get("messages", [])

    def get_log_summary(self):
        """
        Returns a one-line summary of where the generated function was logged.

        Returns:
            str: "Log Path: <path>" when a response exists and logging was
            enabled at construction time; otherwise None.
        """
        # The `log` flag is only stored in `_params`; it is not an attribute of
        # the wrapper or of the compiled graph.
        if self.response and self._params.get("log"):
            return f"Log Path: {self.response.get('data_cleaner_function_path')}"
        return None

    def get_state_keys(self):
        """
        Returns a list of keys that the state graph returns in a response.
        """
        return list(self.get_state_properties().keys())

    def get_state_properties(self):
        """
        Returns the "properties" mapping of the compiled graph's output JSON schema.
        """
        # Ask the compiled graph directly. The wrapper never runs the base
        # class initializer, so calling the inherited method on `self` would
        # operate on an uninitialized graph object.
        return self._compiled_graph.get_output_jsonschema()['properties']

    def get_data_cleaned(self):
        """
        Retrieves the cleaned data stored after running invoke or ainvoke.

        Returns a DataFrame built from the "data_cleaned" entry of the last
        response, or None when there is no response.
        """
        if self.response:
            return pd.DataFrame(self.response.get("data_cleaned"))
        return None

    def get_data_raw(self):
        """
        Retrieves the raw data.

        Returns a DataFrame built from the "data_raw" entry of the last
        response, or None when there is no response.
        """
        if self.response:
            return pd.DataFrame(self.response.get("data_raw"))
        return None

    def get_data_cleaner_function(self):
        """
        Retrieves the agent's pipeline function.

        Returns the "data_cleaner_function" entry of the last response, or
        None when there is no response.
        """
        if self.response:
            return self.response.get("data_cleaner_function")
        return None
|
183
|
+
|
184
|
+
|
185
|
+
|
186
|
+
|
187
|
+
|
188
|
+
|
36
189
|
# Agent
|
37
190
|
|
38
|
-
def make_data_cleaning_agent(
|
191
|
+
def make_data_cleaning_agent(
|
192
|
+
model,
|
193
|
+
n_samples = 30,
|
194
|
+
log=False,
|
195
|
+
log_path=None,
|
196
|
+
file_name="data_cleaner.py",
|
197
|
+
overwrite = True,
|
198
|
+
human_in_the_loop=False,
|
199
|
+
bypass_recommended_steps=False,
|
200
|
+
bypass_explain_code=False
|
201
|
+
):
|
39
202
|
"""
|
40
203
|
Creates a data cleaning agent that can be run on a dataset. The agent can be used to clean a dataset in a variety of
|
41
204
|
ways, such as removing columns with more than 40% missing values, imputing missing
|
@@ -44,9 +207,9 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
|
|
44
207
|
The agent takes in a dataset and some user instructions, and outputs a python
|
45
208
|
function that can be used to clean the dataset. The agent also logs the code
|
46
209
|
generated and any errors that occur.
|
47
|
-
|
210
|
+
|
48
211
|
The agent is instructed to perform the following data cleaning steps:
|
49
|
-
|
212
|
+
|
50
213
|
- Removing columns if more than 40 percent of the data is missing
|
51
214
|
- Imputing missing values with the mean of the column if the column is numeric
|
52
215
|
- Imputing missing values with the mode of the column if the column is categorical
|
@@ -60,12 +223,18 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
|
|
60
223
|
----------
|
61
224
|
model : langchain.llms.base.LLM
|
62
225
|
The language model to use to generate code.
|
226
|
+
n_samples : int, optional
|
227
|
+
The number of samples to use when summarizing the dataset. Defaults to 30.
|
228
|
+
If you get an error due to maximum tokens, try reducing this number.
|
229
|
+
> "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
|
63
230
|
log : bool, optional
|
64
231
|
Whether or not to log the code generated and any errors that occur.
|
65
232
|
Defaults to False.
|
66
233
|
log_path : str, optional
|
67
234
|
The path to the directory where the log files should be stored. Defaults to
|
68
235
|
"logs/".
|
236
|
+
file_name : str, optional
|
237
|
+
The name of the file to save the response to. Defaults to "data_cleaner.py".
|
69
238
|
overwrite : bool, optional
|
70
239
|
Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
|
71
240
|
Defaults to True.
|
@@ -82,26 +251,26 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
|
|
82
251
|
import pandas as pd
|
83
252
|
from langchain_openai import ChatOpenAI
|
84
253
|
from ai_data_science_team.agents import make_data_cleaning_agent
|
85
|
-
|
254
|
+
|
86
255
|
llm = ChatOpenAI(model = "gpt-4o-mini")
|
87
256
|
|
88
257
|
data_cleaning_agent = make_data_cleaning_agent(llm)
|
89
|
-
|
258
|
+
|
90
259
|
df = pd.read_csv("https://raw.githubusercontent.com/business-science/ai-data-science-team/refs/heads/master/data/churn_data.csv")
|
91
|
-
|
260
|
+
|
92
261
|
response = data_cleaning_agent.invoke({
|
93
262
|
"user_instructions": "Don't remove outliers when cleaning the data.",
|
94
263
|
"data_raw": df.to_dict(),
|
95
264
|
"max_retries":3,
|
96
265
|
"retry_count":0
|
97
266
|
})
|
98
|
-
|
267
|
+
|
99
268
|
pd.DataFrame(response['data_cleaned'])
|
100
269
|
```
|
101
270
|
|
102
271
|
Returns
|
103
272
|
-------
|
104
|
-
app : langchain.graphs.
|
273
|
+
app : langgraph.graph.state.CompiledStateGraph
|
105
274
|
The data cleaning agent as a state graph.
|
106
275
|
"""
|
107
276
|
llm = model
|
@@ -134,7 +303,7 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
|
|
134
303
|
Recommend a series of data cleaning steps based on the input data.
|
135
304
|
These recommended steps will be appended to the user_instructions.
|
136
305
|
"""
|
137
|
-
print(
|
306
|
+
print(format_agent_name(AGENT_NAME))
|
138
307
|
print(" * RECOMMEND CLEANING STEPS")
|
139
308
|
|
140
309
|
# Prompt to get recommended steps from the LLM
|
@@ -177,6 +346,7 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
|
|
177
346
|
|
178
347
|
Avoid these:
|
179
348
|
1. Do not include steps to save files.
|
349
|
+
2. Do not include unrelated user instructions that are not related to the data cleaning.
|
180
350
|
""",
|
181
351
|
input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
|
182
352
|
)
|
@@ -184,7 +354,7 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
|
|
184
354
|
data_raw = state.get("data_raw")
|
185
355
|
df = pd.DataFrame.from_dict(data_raw)
|
186
356
|
|
187
|
-
all_datasets_summary = get_dataframe_summary([df])
|
357
|
+
all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples)
|
188
358
|
|
189
359
|
all_datasets_summary_str = "\n\n".join(all_datasets_summary)
|
190
360
|
|
@@ -201,10 +371,21 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
|
|
201
371
|
}
|
202
372
|
|
203
373
|
def create_data_cleaner_code(state: GraphState):
|
204
|
-
|
205
|
-
print("---DATA CLEANING AGENT----")
|
374
|
+
|
206
375
|
print(" * CREATE DATA CLEANER CODE")
|
207
376
|
|
377
|
+
if bypass_recommended_steps:
|
378
|
+
print(format_agent_name(AGENT_NAME))
|
379
|
+
|
380
|
+
data_raw = state.get("data_raw")
|
381
|
+
df = pd.DataFrame.from_dict(data_raw)
|
382
|
+
|
383
|
+
all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples)
|
384
|
+
|
385
|
+
all_datasets_summary_str = "\n\n".join(all_datasets_summary)
|
386
|
+
else:
|
387
|
+
all_datasets_summary_str = state.get("all_datasets_summary")
|
388
|
+
|
208
389
|
data_cleaning_prompt = PromptTemplate(
|
209
390
|
template="""
|
210
391
|
You are a Data Cleaning Agent. Your job is to create a data_cleaner() function that can be run on the data provided using the following recommended steps.
|
@@ -218,7 +399,7 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
|
|
218
399
|
|
219
400
|
{all_datasets_summary}
|
220
401
|
|
221
|
-
Return Python code in ```python ``` format with a single function definition, data_cleaner(data_raw), that
|
402
|
+
Return Python code in ```python ``` format with a single function definition, data_cleaner(data_raw), that includes all imports inside the function.
|
222
403
|
|
223
404
|
Return code to provide the data cleaning function:
|
224
405
|
|
@@ -240,16 +421,16 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
|
|
240
421
|
|
241
422
|
response = data_cleaning_agent.invoke({
|
242
423
|
"recommended_steps": state.get("recommended_steps"),
|
243
|
-
"all_datasets_summary":
|
424
|
+
"all_datasets_summary": all_datasets_summary_str
|
244
425
|
})
|
245
426
|
|
246
427
|
response = relocate_imports_inside_function(response)
|
247
428
|
response = add_comments_to_top(response, agent_name=AGENT_NAME)
|
248
429
|
|
249
430
|
# For logging: store the code generated:
|
250
|
-
file_path,
|
431
|
+
file_path, file_name_2 = log_ai_function(
|
251
432
|
response=response,
|
252
|
-
file_name=
|
433
|
+
file_name=file_name,
|
253
434
|
log=log,
|
254
435
|
log_path=log_path,
|
255
436
|
overwrite=overwrite
|
@@ -258,7 +439,8 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
|
|
258
439
|
return {
|
259
440
|
"data_cleaner_function" : response,
|
260
441
|
"data_cleaner_function_path": file_path,
|
261
|
-
"data_cleaner_function_name":
|
442
|
+
"data_cleaner_function_name": file_name_2,
|
443
|
+
"all_datasets_summary": all_datasets_summary_str
|
262
444
|
}
|
263
445
|
|
264
446
|
def human_review(state: GraphState) -> Command[Literal["recommend_cleaning_steps", "create_data_cleaner_code"]]:
|
@@ -353,3 +535,6 @@ def make_data_cleaning_agent(model, log=False, log_path=None, overwrite = True,
|
|
353
535
|
)
|
354
536
|
|
355
537
|
return app
|
538
|
+
|
539
|
+
|
540
|
+
|
@@ -0,0 +1,331 @@
|
|
1
|
+
# BUSINESS SCIENCE UNIVERSITY
|
2
|
+
# AI DATA SCIENCE TEAM
|
3
|
+
# ***
|
4
|
+
# * Agents: Data Visualization Agent
|
5
|
+
|
6
|
+
|
7
|
+
|
8
|
+
# Libraries
|
9
|
+
from typing import TypedDict, Annotated, Sequence, Literal
|
10
|
+
import operator
|
11
|
+
|
12
|
+
from langchain.prompts import PromptTemplate
|
13
|
+
from langchain_core.output_parsers import StrOutputParser
|
14
|
+
from langchain_core.messages import BaseMessage
|
15
|
+
|
16
|
+
from langgraph.types import Command
|
17
|
+
from langgraph.checkpoint.memory import MemorySaver
|
18
|
+
|
19
|
+
import os
|
20
|
+
import io
|
21
|
+
import pandas as pd
|
22
|
+
|
23
|
+
from ai_data_science_team.templates import(
|
24
|
+
node_func_execute_agent_code_on_data,
|
25
|
+
node_func_human_review,
|
26
|
+
node_func_fix_agent_code,
|
27
|
+
node_func_explain_agent_code,
|
28
|
+
create_coding_agent_graph
|
29
|
+
)
|
30
|
+
from ai_data_science_team.tools.parsers import PythonOutputParser
|
31
|
+
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
|
32
|
+
from ai_data_science_team.tools.metadata import get_dataframe_summary
|
33
|
+
from ai_data_science_team.tools.logging import log_ai_function
|
34
|
+
|
35
|
+
# Setup
|
36
|
+
AGENT_NAME = "data_visualization_agent"
|
37
|
+
LOG_PATH = os.path.join(os.getcwd(), "logs/")
|
38
|
+
|
39
|
+
# Agent
|
40
|
+
|
41
|
+
def make_data_visualization_agent(
    model,
    n_samples=30,
    log=False,
    log_path=None,
    file_name="data_visualization.py",
    overwrite = True,
    human_in_the_loop=False,
    bypass_recommended_steps=False,
    bypass_explain_code=False
):
    """
    Creates a data visualization agent that generates plotly chart code for a dataset.

    The graph has a "chart_instructor" node that turns the user's question into
    chart instructions, a "chart_generator" node that asks the model for a
    `data_visualization(data_raw)` function, and nodes that execute, fix and
    explain that function. The nodes are wired together by
    `create_coding_agent_graph`.

    Parameters
    ----------
    model :
        The language model used in every prompt chain.
    n_samples : int, optional
        Passed to `get_dataframe_summary` as `n_sample`. Defaults to 30.
    log : bool, optional
        When True, the log directory is created if missing. The flag is also
        passed to `log_ai_function` and to the fix-code node. Defaults to False.
    log_path : str, optional
        Directory for logged code. When `log` is True and this is None, the
        module-level `LOG_PATH` is used.
    file_name : str, optional
        File name passed to `log_ai_function`. Defaults to "data_visualization.py".
    overwrite : bool, optional
        Passed to `log_ai_function`. Defaults to True.
    human_in_the_loop : bool, optional
        Passed to the graph builder. When True, a `MemorySaver` checkpointer is
        attached. Defaults to False.
    bypass_recommended_steps : bool, optional
        Passed to the graph builder. When True, the chart generator summarizes
        the data itself and uses the raw user instructions as chart
        instructions. Defaults to False.
    bypass_explain_code : bool, optional
        Passed to the graph builder. Defaults to False.

    Returns
    -------
    app
        The object returned by `create_coding_agent_graph`
        (presumably a compiled state graph - confirm against that helper).
    """

    llm = model

    # Setup Log Directory
    if log:
        if log_path is None:
            log_path = LOG_PATH
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(log_path, exist_ok=True)

    # Define GraphState for the router
    class GraphState(TypedDict):
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        user_instructions_processed: str
        recommended_steps: str
        data_raw: dict
        plotly_graph: dict
        all_datasets_summary: str
        data_visualization_function: str
        data_visualization_function_path: str
        data_visualization_function_name: str
        data_visualization_error: str
        max_retries: int
        retry_count: int

    def summarize_data_raw(state):
        # Shared by both prompt-building nodes: rebuild the dataframe from the
        # state's "data_raw" dict and join its summaries into one prompt string.
        df = pd.DataFrame.from_dict(state.get("data_raw"))
        summaries = get_dataframe_summary([df], n_sample=n_samples, skip_stats=False)
        return "\n\n".join(summaries)

    def chart_instructor(state: GraphState):
        """Ask the model to turn the user's question into chart generator instructions."""

        print(format_agent_name(AGENT_NAME))
        print(" * CREATE CHART GENERATOR INSTRUCTIONS")

        recommend_steps_prompt = PromptTemplate(
            template="""
            You are a supervisor that is an expert in providing instructions to a chart generator agent for plotting.

            You will take a question that a user has and the data that was generated to answer the question, and create instructions to create a chart from the data that will be passed to a chart generator agent.

            USER QUESTION / INSTRUCTIONS:
            {user_instructions}

            Previously Recommended Instructions (if any):
            {recommended_steps}

            DATA:
            {all_datasets_summary}

            Formulate chart generator instructions by informing the chart generator of what type of plotly plot to use (e.g. bar, line, scatter, etc) to best represent the data.

            Come up with an informative title from the user's question and data provided. Also provide X and Y axis titles.

            Instruct the chart generator to use the following theme colors, sizes, etc:

            - Start with the "plotly_white" template
            - Use a white background
            - Use this color for bars and lines:
                'blue': '#3381ff',
            - Base Font Size: 8.8 (Used for x and y axes tickfont, any annotations, hovertips)
            - Title Font Size: 13.2
            - Line Size: 0.65 (specify these within the xaxis and yaxis dictionaries)
            - Add smoothers or trendlines to scatter plots unless not desired by the user
            - Do not use color_discrete_map (this will result in an error)
            - Hover tip size: 8.8

            Return your instructions in the following format:
            CHART GENERATOR INSTRUCTIONS:
            FILL IN THE INSTRUCTIONS HERE

            Avoid these:
            1. Do not include steps to save files.
            2. Do not include unrelated user instructions that are not related to the chart generation.
            """,
            input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
        )

        all_datasets_summary_str = summarize_data_raw(state)

        # Named differently from this node function so the chain does not
        # shadow `chart_instructor` inside its own body.
        instructor_chain = recommend_steps_prompt | llm

        recommended_steps = instructor_chain.invoke({
            "user_instructions": state.get("user_instructions"),
            "recommended_steps": state.get("recommended_steps"),
            "all_datasets_summary": all_datasets_summary_str
        })

        return {
            # Heading describes chart instructions; the previous "Data Cleaning
            # Steps" heading was carried over from the data cleaning agent.
            "recommended_steps": "\n\n# Recommended Chart Generator Instructions:\n" + recommended_steps.content.strip(),
            "all_datasets_summary": all_datasets_summary_str
        }

    def chart_generator(state: GraphState):
        """Ask the model for the `data_visualization(data_raw)` function and log it."""

        print(" * CREATE DATA VISUALIZATION CODE")

        if bypass_recommended_steps:
            # The instructor node did not run: print the agent banner here,
            # build the data summary now, and use the user's own instructions.
            print(format_agent_name(AGENT_NAME))
            all_datasets_summary_str = summarize_data_raw(state)
            chart_generator_instructions = state.get("user_instructions")
        else:
            all_datasets_summary_str = state.get("all_datasets_summary")
            chart_generator_instructions = state.get("recommended_steps")

        prompt_template = PromptTemplate(
            template="""
            You are a chart generator agent that is an expert in generating plotly charts. You must use plotly or plotly.express to produce plots.

            Your job is to produce python code to generate visualizations.

            You will take instructions from a Chart Instructor and generate a plotly chart from the data provided.

            CHART INSTRUCTIONS:
            {chart_generator_instructions}

            DATA:
            {all_datasets_summary}

            RETURN:

            Return Python code in ```python ``` format with a single function definition, data_visualization(data_raw), that includes all imports inside the function.

            Return the plotly chart as a dictionary.

            Return code to provide the data visualization function:

            def data_visualization(data_raw):
                import pandas as pd
                import numpy as np
                import json
                import plotly.graph_objects as go
                import plotly.io as pio

                ...

                fig_json = pio.to_json(fig)
                fig_dict = json.loads(fig_json)

                return fig_dict

            Avoid these:
            1. Do not include steps to save files.
            2. Do not include unrelated user instructions that are not related to the chart generation.

            """,
            input_variables=["chart_generator_instructions", "all_datasets_summary"]
        )

        data_visualization_agent = prompt_template | llm | PythonOutputParser()

        response = data_visualization_agent.invoke({
            "chart_generator_instructions": chart_generator_instructions,
            "all_datasets_summary": all_datasets_summary_str
        })

        response = relocate_imports_inside_function(response)
        response = add_comments_to_top(response, agent_name=AGENT_NAME)

        # For logging: store the code generated:
        file_path, file_name_2 = log_ai_function(
            response=response,
            file_name=file_name,
            log=log,
            log_path=log_path,
            overwrite=overwrite
        )

        return {
            "data_visualization_function": response,
            "data_visualization_function_path": file_path,
            "data_visualization_function_name": file_name_2,
            "all_datasets_summary": all_datasets_summary_str
        }

    def human_review(state: GraphState) -> Command[Literal["chart_instructor", "chart_generator"]]:
        """Route to the generator on approval, or back to the instructor otherwise."""
        return node_func_human_review(
            state=state,
            prompt_text="Is the following data visualization instructions correct? (Answer 'yes' or provide modifications)\n{steps}",
            yes_goto="chart_generator",
            no_goto="chart_instructor",
            user_instructions_key="user_instructions",
            recommended_steps_key="recommended_steps"
        )

    def execute_data_visualization_code(state):
        """Run the generated function on "data_raw" and store the result as "plotly_graph"."""
        # No post_processing is passed: the generator prompt asks for the chart
        # to be returned as a dictionary already.
        return node_func_execute_agent_code_on_data(
            state=state,
            data_key="data_raw",
            result_key="plotly_graph",
            error_key="data_visualization_error",
            code_snippet_key="data_visualization_function",
            agent_function_name="data_visualization",
            pre_processing=lambda data: pd.DataFrame.from_dict(data),
            error_message_prefix="An error occurred during data visualization: "
        )

    def fix_data_visualization_code(state: GraphState):
        """Ask the model to repair the generated function using the last recorded error."""
        prompt = """
        You are a Data Visualization Agent. Your job is to create a data_visualization() function that can be run on the data provided. The function is currently broken and needs to be fixed.

        Make sure to only return the function definition for data_visualization().

        Return Python code in ```python``` format with a single function definition, data_visualization(data_raw), that includes all imports inside the function.

        This is the broken code (please fix):
        {code_snippet}

        Last Known Error:
        {error}
        """

        return node_func_fix_agent_code(
            state=state,
            code_snippet_key="data_visualization_function",
            error_key="data_visualization_error",
            llm=llm,
            prompt_template=prompt,
            agent_name=AGENT_NAME,
            log=log,
            file_path=state.get("data_visualization_function_path"),
        )

    def explain_data_visualization_code(state: GraphState):
        """Ask the model for a short explanation of the generated function."""
        return node_func_explain_agent_code(
            state=state,
            code_snippet_key="data_visualization_function",
            result_key="messages",
            error_key="data_visualization_error",
            llm=llm,
            role=AGENT_NAME,
            explanation_prompt_template="""
            Explain the data visualization steps that the data visualization agent performed in this function.
            Keep the summary succinct and to the point.\n\n# Data Visualization Agent:\n\n{code}
            """,
            success_prefix="# Data Visualization Agent:\n\n ",
            error_message="The Data Visualization Agent encountered an error during data visualization. No explanation could be provided."
        )

    # Define the graph
    node_functions = {
        "chart_instructor": chart_instructor,
        "human_review": human_review,
        "chart_generator": chart_generator,
        "execute_data_visualization_code": execute_data_visualization_code,
        "fix_data_visualization_code": fix_data_visualization_code,
        "explain_data_visualization_code": explain_data_visualization_code
    }

    app = create_coding_agent_graph(
        GraphState=GraphState,
        node_functions=node_functions,
        recommended_steps_node_name="chart_instructor",
        create_code_node_name="chart_generator",
        execute_code_node_name="execute_data_visualization_code",
        fix_code_node_name="fix_data_visualization_code",
        explain_code_node_name="explain_data_visualization_code",
        error_key="data_visualization_error",
        human_in_the_loop=human_in_the_loop,
        human_review_node_name="human_review",
        # A checkpointer is only attached when the human review step is enabled.
        checkpointer=MemorySaver() if human_in_the_loop else None,
        bypass_recommended_steps=bypass_recommended_steps,
        bypass_explain_code=bypass_explain_code,
    )

    return app
|
331
|
+
|