ai-data-science-team 0.0.0.9005__py3-none-any.whl → 0.0.0.9007__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -17,7 +17,7 @@ import os
17
17
  import io
18
18
  import pandas as pd
19
19
 
20
- from ai_data_science_team.templates.agent_templates import(
20
+ from ai_data_science_team.templates import(
21
21
  node_func_execute_agent_code_on_data,
22
22
  node_func_human_review,
23
23
  node_func_fix_agent_code,
@@ -25,8 +25,8 @@ from ai_data_science_team.templates.agent_templates import(
25
25
  create_coding_agent_graph
26
26
  )
27
27
  from ai_data_science_team.tools.parsers import PythonOutputParser
28
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
29
- from ai_data_science_team.tools.data_analysis import summarize_dataframes
28
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
29
+ from ai_data_science_team.tools.metadata import get_dataframe_summary
30
30
  from ai_data_science_team.tools.logging import log_ai_function
31
31
 
32
32
  # Setup
@@ -35,7 +35,17 @@ LOG_PATH = os.path.join(os.getcwd(), "logs/")
35
35
 
36
36
  # * Feature Engineering Agent
37
37
 
38
- def make_feature_engineering_agent(model, log=False, log_path=None, overwrite = True, human_in_the_loop=False):
38
+ def make_feature_engineering_agent(
39
+ model,
40
+ n_samples=30,
41
+ log=False,
42
+ log_path=None,
43
+ file_name="feature_engineer.py",
44
+ overwrite = True,
45
+ human_in_the_loop=False,
46
+ bypass_recommended_steps=False,
47
+ bypass_explain_code=False,
48
+ ):
39
49
  """
40
50
  Creates a feature engineering agent that can be run on a dataset. The agent applies various feature engineering
41
51
  techniques, such as encoding categorical variables, scaling numeric variables, creating interaction terms,
@@ -61,16 +71,26 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
61
71
  ----------
62
72
  model : langchain.llms.base.LLM
63
73
  The language model to use to generate code.
74
+ n_samples : int, optional
75
+ The number of data samples to use for generating the feature engineering code. Defaults to 30.
76
+ If you get an error due to maximum tokens, try reducing this number.
77
+ > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
64
78
  log : bool, optional
65
79
  Whether or not to log the code generated and any errors that occur.
66
80
  Defaults to False.
67
81
  log_path : str, optional
68
82
  The path to the directory where the log files should be stored. Defaults to "logs/".
83
+ file_name : str, optional
84
+ The name of the file to save the log to. Defaults to "feature_engineer.py".
69
85
  overwrite : bool, optional
70
86
  Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
71
87
  Defaults to True.
72
88
  human_in_the_loop : bool, optional
73
89
  Whether or not to use human in the loop. If True, adds an interrupt and a human-in-the-loop step that asks the user to review the feature engineering instructions. Defaults to False.
90
+ bypass_recommended_steps : bool, optional
91
+ Bypass the recommendation step, by default False
92
+ bypass_explain_code : bool, optional
93
+ Bypass the code explanation step, by default False.
74
94
 
75
95
  Examples
76
96
  -------
@@ -98,7 +118,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
98
118
 
99
119
  Returns
100
120
  -------
101
- app : langchain.graphs.StateGraph
121
+ app : langgraph.graph.state.CompiledStateGraph
102
122
  The feature engineering agent as a state graph.
103
123
  """
104
124
  llm = model
@@ -131,7 +151,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
131
151
  Recommend a series of feature engineering steps based on the input data.
132
152
  These recommended steps will be appended to the user_instructions.
133
153
  """
134
- print("---FEATURE ENGINEERING AGENT----")
154
+ print(format_agent_name(AGENT_NAME))
135
155
  print(" * RECOMMEND FEATURE ENGINEERING STEPS")
136
156
 
137
157
  # Prompt to get recommended steps from the LLM
@@ -178,6 +198,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
178
198
 
179
199
  Avoid these:
180
200
  1. Do not include steps to save files.
201
+ 2. Do not include unrelated user instructions that are not related to the feature engineering.
181
202
  """,
182
203
  input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
183
204
  )
@@ -185,7 +206,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
185
206
  data_raw = state.get("data_raw")
186
207
  df = pd.DataFrame.from_dict(data_raw)
187
208
 
188
- all_datasets_summary = summarize_dataframes([df])
209
+ all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples)
189
210
 
190
211
  all_datasets_summary_str = "\n\n".join(all_datasets_summary)
191
212
 
@@ -212,6 +233,19 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
212
233
  )
213
234
 
214
235
  def create_feature_engineering_code(state: GraphState):
236
+ if bypass_recommended_steps:
237
+ print(format_agent_name(AGENT_NAME))
238
+
239
+ data_raw = state.get("data_raw")
240
+ df = pd.DataFrame.from_dict(data_raw)
241
+
242
+ all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples)
243
+
244
+ all_datasets_summary_str = "\n\n".join(all_datasets_summary)
245
+
246
+ else:
247
+ all_datasets_summary_str = state.get("all_datasets_summary")
248
+
215
249
  print(" * CREATE FEATURE ENGINEERING CODE")
216
250
 
217
251
  feature_engineering_prompt = PromptTemplate(
@@ -266,16 +300,16 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
266
300
  response = feature_engineering_agent.invoke({
267
301
  "recommended_steps": state.get("recommended_steps"),
268
302
  "target_variable": state.get("target_variable"),
269
- "all_datasets_summary": state.get("all_datasets_summary"),
303
+ "all_datasets_summary": all_datasets_summary_str,
270
304
  })
271
305
 
272
306
  response = relocate_imports_inside_function(response)
273
307
  response = add_comments_to_top(response, agent_name=AGENT_NAME)
274
308
 
275
309
  # For logging: store the code generated
276
- file_path, file_name = log_ai_function(
310
+ file_path, file_name_2 = log_ai_function(
277
311
  response=response,
278
- file_name="feature_engineer.py",
312
+ file_name=file_name,
279
313
  log=log,
280
314
  log_path=log_path,
281
315
  overwrite=overwrite
@@ -284,7 +318,8 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
284
318
  return {
285
319
  "feature_engineer_function": response,
286
320
  "feature_engineer_function_path": file_path,
287
- "feature_engineer_function_name": file_name
321
+ "feature_engineer_function_name": file_name_2,
322
+ "all_datasets_summary": all_datasets_summary_str
288
323
  }
289
324
 
290
325
 
@@ -298,7 +333,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
298
333
  code_snippet_key="feature_engineer_function",
299
334
  agent_function_name="feature_engineer",
300
335
  pre_processing=lambda data: pd.DataFrame.from_dict(data),
301
- post_processing=lambda df: df.to_dict(),
336
+ post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
302
337
  error_message_prefix="An error occurred during feature engineering: "
303
338
  )
304
339
 
@@ -362,7 +397,9 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
362
397
  error_key="feature_engineer_error",
363
398
  human_in_the_loop=human_in_the_loop,
364
399
  human_review_node_name="human_review",
365
- checkpointer=MemorySaver() if human_in_the_loop else None
400
+ checkpointer=MemorySaver() if human_in_the_loop else None,
401
+ bypass_recommended_steps=bypass_recommended_steps,
402
+ bypass_explain_code=bypass_explain_code,
366
403
  )
367
404
 
368
405
  return app
@@ -0,0 +1,397 @@
1
+
2
+
3
+ from typing import TypedDict, Annotated, Sequence, Literal
4
+ import operator
5
+
6
+ from langchain.prompts import PromptTemplate
7
+ from langchain_core.messages import BaseMessage
8
+
9
+ from langgraph.types import Command
10
+ from langgraph.checkpoint.memory import MemorySaver
11
+
12
+ import os
13
+ import io
14
+ import pandas as pd
15
+ import sqlalchemy as sql
16
+
17
+ from ai_data_science_team.templates import(
18
+ node_func_execute_agent_from_sql_connection,
19
+ node_func_human_review,
20
+ node_func_fix_agent_code,
21
+ node_func_explain_agent_code,
22
+ create_coding_agent_graph
23
+ )
24
+ from ai_data_science_team.tools.parsers import PythonOutputParser, SQLOutputParser
25
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
26
+ from ai_data_science_team.tools.metadata import get_database_metadata
27
+ from ai_data_science_team.tools.logging import log_ai_function
28
+
29
+ # Setup
30
+ AGENT_NAME = "sql_database_agent"
31
+ LOG_PATH = os.path.join(os.getcwd(), "logs/")
32
+
33
+
34
+ def make_sql_database_agent(
35
+ model, connection,
36
+ n_samples = 10,
37
+ log=False,
38
+ log_path=None,
39
+ file_name="sql_database.py",
40
+ overwrite = True,
41
+ human_in_the_loop=False, bypass_recommended_steps=False,
42
+ bypass_explain_code=False
43
+ ):
44
+ """
45
+ Creates a SQL Database Agent that can recommend SQL steps and generate SQL code to query a database.
46
+
47
+ Parameters
48
+ ----------
49
+ model : ChatOpenAI
50
+ The language model to use for the agent.
51
+ connection : sqlalchemy.engine.base.Engine
52
+ The connection to the SQL database.
53
+ n_samples : int, optional
54
+ The number of samples to retrieve for each column, by default 10.
55
+ If you get an error due to maximum tokens, try reducing this number.
56
+ > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
57
+ log : bool, optional
58
+ Whether to log the generated code, by default False
59
+ log_path : str, optional
60
+ The path to the log directory, by default None
61
+ file_name : str, optional
62
+ The name of the file to save the generated code, by default "sql_database.py"
63
+ overwrite : bool, optional
64
+ Whether to overwrite the existing log file, by default True
65
+ human_in_the_loop : bool, optional
66
+ Whether or not to use human in the loop. If True, adds an interrupt and a human-in-the-loop step that asks the user to review the recommended SQL database querying steps. Defaults to False.
67
+ bypass_recommended_steps : bool, optional
68
+ Bypass the recommendation step, by default False
69
+ bypass_explain_code : bool, optional
70
+ Bypass the code explanation step, by default False.
71
+
72
+ Returns
73
+ -------
74
+ app : langgraph.graph.state.CompiledStateGraph
75
+ The SQL database agent as a compiled state graph.
76
+
77
+ Examples
78
+ --------
79
+ ```python
80
+ from ai_data_science_team.agents import make_sql_database_agent
81
+ import sqlalchemy as sql
82
+ from langchain_openai import ChatOpenAI
83
+
84
+ sql_engine = sql.create_engine("sqlite:///data/leads_scored.db")
85
+
86
+ conn = sql_engine.connect()
87
+
88
+ llm = ChatOpenAI(model="gpt-4o-mini")
89
+
90
+ sql_agent = make_sql_database_agent(
91
+ model=llm,
92
+ connection=conn
93
+ )
94
+
95
+ sql_agent
96
+
97
+ response = sql_agent.invoke({
98
+ "user_instructions": "List the tables in the database",
99
+ "max_retries":3,
100
+ "retry_count":0
101
+ })
102
+ ```
103
+
104
+ """
105
+
106
+ is_engine = isinstance(connection, sql.engine.base.Engine)
107
+ conn = connection.connect() if is_engine else connection
108
+
109
+ llm = model
110
+
111
+ # Setup Log Directory
112
+ if log:
113
+ if log_path is None:
114
+ log_path = LOG_PATH
115
+ if not os.path.exists(log_path):
116
+ os.makedirs(log_path)
117
+
118
+ class GraphState(TypedDict):
119
+ messages: Annotated[Sequence[BaseMessage], operator.add]
120
+ user_instructions: str
121
+ recommended_steps: str
122
+ data_sql: dict
123
+ all_sql_database_summary: str
124
+ sql_query_code: str
125
+ sql_database_function: str
126
+ sql_database_function_path: str
127
+ sql_database_function_name: str
128
+ sql_database_error: str
129
+ max_retries: int
130
+ retry_count: int
131
+
132
+ def recommend_sql_steps(state: GraphState):
133
+
134
+ print(format_agent_name(AGENT_NAME))
135
+ print(" * RECOMMEND STEPS")
136
+
137
+
138
+ # Prompt to get recommended steps from the LLM
139
+ recommend_steps_prompt = PromptTemplate(
140
+ template="""
141
+ You are a SQL Database Instructions Expert. Given the following information about the SQL database,
142
+ recommend a series of numbered steps to take to collect the data and process it according to user instructions.
143
+ The steps should be tailored to the SQL database characteristics and should be helpful
144
+ for a sql database coding agent that will write the SQL code.
145
+
146
+ IMPORTANT INSTRUCTIONS:
147
+ - Take into account the user instructions and the previously recommended steps.
148
+ - If no user instructions are provided, just return the steps needed to understand the database.
149
+ - Take into account the database dialect and the tables and columns in the database.
150
+ - Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
151
+
152
+
153
+ User instructions / Question:
154
+ {user_instructions}
155
+
156
+ Previously Recommended Steps (if any):
157
+ {recommended_steps}
158
+
159
+ Below are summaries of the database metadata and the SQL tables:
160
+ {all_sql_database_summary}
161
+
162
+ Return the steps as a numbered point list (no code, just the steps).
163
+
164
+ Consider these:
165
+
166
+ 1. Consider the database dialect and the tables and columns in the database.
167
+
168
+
169
+ Avoid these:
170
+ 1. Do not include steps to save files.
171
+ 2. Do not include steps to modify existing tables, create new tables or modify the database schema.
172
+ 3. Do not include steps that alter the existing data in the database.
173
+ 4. Make sure not to include unsafe code that could cause data loss or corruption or SQL injections.
174
+ 5. Make sure to not include irrelevant steps that do not help in the SQL agent's data collection and processing. Examples include steps to create new tables, modify the schema, save files, create charts, etc.
175
+
176
+
177
+ """,
178
+ input_variables=["user_instructions", "recommended_steps", "all_sql_database_summary"]
179
+ )
180
+
181
+ # Create a connection if needed
182
+ is_engine = isinstance(connection, sql.engine.base.Engine)
183
+ conn = connection.connect() if is_engine else connection
184
+
185
+ # Get the database metadata
186
+ all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
187
+
188
+ steps_agent = recommend_steps_prompt | llm
189
+
190
+ recommended_steps = steps_agent.invoke({
191
+ "user_instructions": state.get("user_instructions"),
192
+ "recommended_steps": state.get("recommended_steps"),
193
+ "all_sql_database_summary": all_sql_database_summary
194
+ })
195
+
196
+ return {
197
+ "recommended_steps": "\n\n# Recommended SQL Database Steps:\n" + recommended_steps.content.strip(),
198
+ "all_sql_database_summary": all_sql_database_summary
199
+ }
200
+
201
+ def create_sql_query_code(state: GraphState):
202
+ if bypass_recommended_steps:
203
+ print(format_agent_name(AGENT_NAME))
204
+ print(" * CREATE SQL QUERY CODE")
205
+
206
+ # Prompt to get the SQL code from the LLM
207
+ sql_query_code_prompt = PromptTemplate(
208
+ template="""
209
+ You are a SQL Database Coding Expert. Given the following information about the SQL database,
210
+ write the SQL code to collect the data and process it according to user instructions.
211
+ The code should be tailored to the SQL database characteristics and should take into account user instructions, recommended steps, database and table characteristics.
212
+
213
+ IMPORTANT INSTRUCTIONS:
214
+ - Do not use a LIMIT clause unless a user specifies a limit to be returned.
215
+ - Return SQL in ```sql ``` format.
216
+ - Only return a single query if possible.
217
+ - Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
218
+ - Pay attention to the SQL dialect from the database summary metadata. Write the SQL code according to the dialect specified.
219
+
220
+
221
+ User instructions / Question:
222
+ {user_instructions}
223
+
224
+ Recommended Steps:
225
+ {recommended_steps}
226
+
227
+ Below are summaries of the database metadata and the SQL tables:
228
+ {all_sql_database_summary}
229
+
230
+ Return:
231
+ - The SQL code in ```sql ``` format to collect the data and process it according to the user instructions.
232
+
233
+ Avoid these:
234
+ - Do not include steps to save files.
235
+ - Do not include steps to modify existing tables, create new tables or modify the database schema.
236
+ - Make sure not to alter the existing data in the database.
237
+ - Make sure not to include unsafe code that could cause data loss or corruption.
238
+
239
+ """,
240
+ input_variables=["user_instructions", "recommended_steps", "all_sql_database_summary"]
241
+ )
242
+
243
+ # Create a connection if needed
244
+ is_engine = isinstance(connection, sql.engine.base.Engine)
245
+ conn = connection.connect() if is_engine else connection
246
+
247
+ # Get the database metadata
248
+ all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)
249
+
250
+ sql_query_code_agent = sql_query_code_prompt | llm | SQLOutputParser()
251
+
252
+ sql_query_code = sql_query_code_agent.invoke({
253
+ "user_instructions": state.get("user_instructions"),
254
+ "recommended_steps": state.get("recommended_steps"),
255
+ "all_sql_database_summary": all_sql_database_summary
256
+ })
257
+
258
+ print(" * CREATE PYTHON FUNCTION TO RUN SQL CODE")
259
+
260
+ response = f"""
261
+ def sql_database_pipeline(connection):
262
+ import pandas as pd
263
+ import sqlalchemy as sql
264
+
265
+ # Create a connection if needed
266
+ is_engine = isinstance(connection, sql.engine.base.Engine)
267
+ conn = connection.connect() if is_engine else connection
268
+
269
+ sql_query = '''
270
+ {sql_query_code}
271
+ '''
272
+
273
+ return pd.read_sql(sql_query, connection)
274
+ """
275
+
276
+ response = add_comments_to_top(response, AGENT_NAME)
277
+
278
+ # For logging: store the code generated
279
+ file_path, file_name_2 = log_ai_function(
280
+ response=response,
281
+ file_name=file_name,
282
+ log=log,
283
+ log_path=log_path,
284
+ overwrite=overwrite
285
+ )
286
+
287
+ return {
288
+ "sql_query_code": sql_query_code,
289
+ "sql_database_function": response,
290
+ "sql_database_function_path": file_path,
291
+ "sql_database_function_name": file_name_2,
292
+ "all_sql_database_summary": all_sql_database_summary
293
+ }
294
+
295
+ def human_review(state: GraphState) -> Command[Literal["recommend_sql_steps", "create_sql_query_code"]]:
296
+ return node_func_human_review(
297
+ state=state,
298
+ prompt_text="Are the following SQL database querying steps correct? (Answer 'yes' or provide modifications)\n{steps}",
299
+ yes_goto="create_sql_query_code",
300
+ no_goto="recommend_sql_steps",
301
+ user_instructions_key="user_instructions",
302
+ recommended_steps_key="recommended_steps"
303
+ )
304
+
305
+ def execute_sql_database_code(state: GraphState):
306
+
307
+ is_engine = isinstance(connection, sql.engine.base.Engine)
308
+ conn = connection.connect() if is_engine else connection
309
+
310
+ return node_func_execute_agent_from_sql_connection(
311
+ state=state,
312
+ connection=conn,
313
+ result_key="data_sql",
314
+ error_key="sql_database_error",
315
+ code_snippet_key="sql_database_function",
316
+ agent_function_name="sql_database_pipeline",
317
+ post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
318
+ error_message_prefix="An error occurred during executing the sql database pipeline: "
319
+ )
320
+
321
+ def fix_sql_database_code(state: GraphState):
322
+ prompt = """
323
+ You are a SQL Database Agent code fixer. Your job is to create a sql_database_pipeline(connection) function that can be run on a sql connection. The function is currently broken and needs to be fixed.
324
+
325
+ Make sure to only return the function definition for sql_database_pipeline().
326
+
327
+ Return Python code in ```python``` format with a single function definition, sql_database_pipeline(connection), that includes all imports inside the function. The connection object is a SQLAlchemy connection object. Don't specify the class of the connection object, just use it as an argument to the function.
328
+
329
+ This is the broken code (please fix):
330
+ {code_snippet}
331
+
332
+ Last Known Error:
333
+ {error}
334
+ """
335
+
336
+ return node_func_fix_agent_code(
337
+ state=state,
338
+ code_snippet_key="sql_database_function",
339
+ error_key="sql_database_error",
340
+ llm=llm,
341
+ prompt_template=prompt,
342
+ agent_name=AGENT_NAME,
343
+ log=log,
344
+ file_path=state.get("sql_database_function_path", None),
345
+ )
346
+
347
+ def explain_sql_database_code(state: GraphState):
348
+ return node_func_explain_agent_code(
349
+ state=state,
350
+ code_snippet_key="sql_database_function",
351
+ result_key="messages",
352
+ error_key="sql_database_error",
353
+ llm=llm,
354
+ role=AGENT_NAME,
355
+ explanation_prompt_template="""
356
+ Explain the SQL steps that the SQL Database agent performed in this function.
357
+ Keep the summary succinct and to the point.\n\n# SQL Database Agent:\n\n{code}
358
+ """,
359
+ success_prefix="# SQL Database Agent:\n\n",
360
+ error_message="The SQL Database Agent encountered an error during SQL Query Analysis. No SQL function explanation is returned."
361
+ )
362
+
363
+ # Create the graph
364
+ node_functions = {
365
+ "recommend_sql_steps": recommend_sql_steps,
366
+ "human_review": human_review,
367
+ "create_sql_query_code": create_sql_query_code,
368
+ "execute_sql_database_code": execute_sql_database_code,
369
+ "fix_sql_database_code": fix_sql_database_code,
370
+ "explain_sql_database_code": explain_sql_database_code
371
+ }
372
+
373
+ app = create_coding_agent_graph(
374
+ GraphState=GraphState,
375
+ node_functions=node_functions,
376
+ recommended_steps_node_name="recommend_sql_steps",
377
+ create_code_node_name="create_sql_query_code",
378
+ execute_code_node_name="execute_sql_database_code",
379
+ fix_code_node_name="fix_sql_database_code",
380
+ explain_code_node_name="explain_sql_database_code",
381
+ error_key="sql_database_error",
382
+ human_in_the_loop=human_in_the_loop,
383
+ human_review_node_name="human_review",
384
+ checkpointer=MemorySaver() if human_in_the_loop else None,
385
+ bypass_recommended_steps=bypass_recommended_steps,
386
+ bypass_explain_code=bypass_explain_code,
387
+ )
388
+
389
+ return app
390
+
391
+
392
+
393
+
394
+
395
+
396
+
397
+
@@ -0,0 +1,8 @@
1
+ from ai_data_science_team.templates.agent_templates import(
2
+ node_func_execute_agent_code_on_data,
3
+ node_func_human_review,
4
+ node_func_fix_agent_code,
5
+ node_func_explain_agent_code,
6
+ node_func_execute_agent_from_sql_connection,
7
+ create_coding_agent_graph
8
+ )