ai-data-science-team 0.0.0.9006__py3-none-any.whl → 0.0.0.9007__py3-none-any.whl

@@ -15,7 +15,7 @@ from langchain_core.messages import BaseMessage
  from langgraph.types import Command
  from langgraph.checkpoint.memory import MemorySaver

- from ai_data_science_team.templates.agent_templates import(
+ from ai_data_science_team.templates import(
      node_func_execute_agent_code_on_data,
      node_func_human_review,
      node_func_fix_agent_code,
@@ -23,7 +23,7 @@ from ai_data_science_team.templates.agent_templates import(
      create_coding_agent_graph
  )
  from ai_data_science_team.tools.parsers import PythonOutputParser
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
  from ai_data_science_team.tools.metadata import get_dataframe_summary
  from ai_data_science_team.tools.logging import log_ai_function

@@ -31,7 +31,17 @@ from ai_data_science_team.tools.logging import log_ai_function
  AGENT_NAME = "data_wrangling_agent"
  LOG_PATH = os.path.join(os.getcwd(), "logs/")

- def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True, human_in_the_loop=False, bypass_recommended_steps=False, bypass_explain_code=False):
+ def make_data_wrangling_agent(
+     model,
+     n_samples=30,
+     log=False,
+     log_path=None,
+     file_name="data_wrangler.py",
+     overwrite = True,
+     human_in_the_loop=False,
+     bypass_recommended_steps=False,
+     bypass_explain_code=False
+ ):
      """
      Creates a data wrangling agent that can be run on one or more datasets. The agent can be
      instructed to perform common data wrangling steps such as:
@@ -52,11 +62,17 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
      ----------
      model : langchain.llms.base.LLM
          The language model to use to generate code.
+     n_samples : int, optional
+         The number of samples to show in the data summary. Defaults to 30.
+         If you get an error due to maximum tokens, try reducing this number.
+         > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
      log : bool, optional
          Whether or not to log the code generated and any errors that occur.
          Defaults to False.
      log_path : str, optional
          The path to the directory where the log files should be stored. Defaults to "logs/".
+     file_name : str, optional
+         The name of the file to save the response to. Defaults to "data_wrangler.py".
      overwrite : bool, optional
          Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
          Defaults to True.
@@ -94,7 +110,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,

      Returns
      -------
-     app : langchain.graphs.StateGraph
+     app : langchain.graphs.CompiledStateGraph
          The data wrangling agent as a state graph.
      """
      llm = model
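
A minimal sketch of the updated `make_data_wrangling_agent` signature in use. The chat model, the toy data, and the exact state keys passed to `invoke()` are illustrative assumptions, not part of this diff:

```python
# Sketch only: exercises the new n_samples / file_name parameters.
# ChatOpenAI, the toy DataFrame, and the state keys are assumptions.
from langchain_openai import ChatOpenAI
import pandas as pd

llm = ChatOpenAI(model="gpt-4o-mini")

app = make_data_wrangling_agent(
    model=llm,
    n_samples=10,                  # new: shrink the data summary if you hit token limits
    log=True,
    file_name="data_wrangler.py",  # new: control the name of the logged code file
)

df = pd.DataFrame({"id": [1, 2, 3], "value": [10.0, None, 30.0]})
result = app.invoke({
    "user_instructions": "Fill missing values with the column mean.",
    "data_raw": df.to_dict(),
    "max_retries": 3,
    "retry_count": 0,
})
```
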
@@ -122,7 +138,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
          retry_count: int

      def recommend_wrangling_steps(state: GraphState):
-         print("---DATA WRANGLING AGENT----")
+         print(format_agent_name(AGENT_NAME))
          print(" * RECOMMEND WRANGLING STEPS")

          data_raw = state.get("data_raw")
@@ -143,7 +159,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,

          # Create a summary for all datasets
          # We'll include a short sample and info for each dataset
-         all_datasets_summary = get_dataframe_summary(dataframes)
+         all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)

          # Join all datasets summaries into one big text block
          all_datasets_summary_str = "\n\n".join(all_datasets_summary)
@@ -176,6 +192,7 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,

              Avoid these:
              1. Do not include steps to save files.
+             2. Do not include unrelated user instructions that are not related to the data wrangling.
              """,
              input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
          )
@@ -195,7 +212,34 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,

      def create_data_wrangler_code(state: GraphState):
          if bypass_recommended_steps:
-             print("---DATA WRANGLING AGENT----")
+             print(format_agent_name(AGENT_NAME))
+
+             data_raw = state.get("data_raw")
+
+             if isinstance(data_raw, dict):
+                 # Single dataset scenario
+                 primary_dataset_name = "main"
+                 datasets = {primary_dataset_name: data_raw}
+             elif isinstance(data_raw, list) and all(isinstance(item, dict) for item in data_raw):
+                 # Multiple datasets scenario
+                 datasets = {f"dataset_{i}": d for i, d in enumerate(data_raw, start=1)}
+                 primary_dataset_name = "dataset_1"
+             else:
+                 raise ValueError("data_raw must be a dict or a list of dicts.")
+
+             # Convert all datasets to DataFrames for inspection
+             dataframes = {name: pd.DataFrame.from_dict(d) for name, d in datasets.items()}
+
+             # Create a summary for all datasets
+             # We'll include a short sample and info for each dataset
+             all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
+
+             # Join all datasets summaries into one big text block
+             all_datasets_summary_str = "\n\n".join(all_datasets_summary)
+
+         else:
+             all_datasets_summary_str = state.get("all_datasets_summary")
+
          print(" * CREATE DATA WRANGLER CODE")

          data_wrangling_prompt = PromptTemplate(
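
The new bypass branch above accepts the same two shapes of `data_raw` that the recommend step handles; a short sketch of both (values are illustrative):

```python
import pandas as pd

df1 = pd.DataFrame({"a": [1, 2]})
df2 = pd.DataFrame({"b": [3, 4]})

single = df1.to_dict()                     # dict -> summarized under the name "main"
multiple = [df1.to_dict(), df2.to_dict()]  # list of dicts -> "dataset_1", "dataset_2", ...
# Anything else raises: ValueError("data_raw must be a dict or a list of dicts.")
```
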
@@ -242,16 +286,16 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,

          response = data_wrangling_agent.invoke({
              "recommended_steps": state.get("recommended_steps"),
-             "all_datasets_summary": state.get("all_datasets_summary")
+             "all_datasets_summary": all_datasets_summary_str
          })

          response = relocate_imports_inside_function(response)
          response = add_comments_to_top(response, agent_name=AGENT_NAME)

          # For logging: store the code generated
-         file_path, file_name = log_ai_function(
+         file_path, file_name_2 = log_ai_function(
              response=response,
-             file_name="data_wrangler.py",
+             file_name=file_name,
              log=log,
              log_path=log_path,
              overwrite=overwrite
@@ -260,7 +304,8 @@ def make_data_wrangling_agent(model, log=False, log_path=None, overwrite = True,
          return {
              "data_wrangler_function" : response,
              "data_wrangler_function_path": file_path,
-             "data_wrangler_function_name": file_name
+             "data_wrangler_function_name": file_name_2,
+             "all_datasets_summary": all_datasets_summary_str
          }


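The node now returns the logged file name derived from the caller's `file_name` and writes the dataset summary back into state. A sketch of reading the new keys from a finished run (the initial state is the one sketched earlier, so it is illustrative):

```python
result = app.invoke(initial_state)            # initial_state: illustrative, as sketched above
print(result["data_wrangler_function_name"])  # now reflects the file_name argument
print(result["all_datasets_summary"])         # newly propagated through state
```
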
@@ -17,7 +17,7 @@ import os
  import io
  import pandas as pd

- from ai_data_science_team.templates.agent_templates import(
+ from ai_data_science_team.templates import(
      node_func_execute_agent_code_on_data,
      node_func_human_review,
      node_func_fix_agent_code,
@@ -25,7 +25,7 @@ from ai_data_science_team.templates.agent_templates import(
      create_coding_agent_graph
  )
  from ai_data_science_team.tools.parsers import PythonOutputParser
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
  from ai_data_science_team.tools.metadata import get_dataframe_summary
  from ai_data_science_team.tools.logging import log_ai_function

@@ -35,7 +35,17 @@ LOG_PATH = os.path.join(os.getcwd(), "logs/")

  # * Feature Engineering Agent

- def make_feature_engineering_agent(model, log=False, log_path=None, overwrite = True, human_in_the_loop=False, bypass_recommended_steps=False, bypass_explain_code=False):
+ def make_feature_engineering_agent(
+     model,
+     n_samples=30,
+     log=False,
+     log_path=None,
+     file_name="feature_engineer.py",
+     overwrite = True,
+     human_in_the_loop=False,
+     bypass_recommended_steps=False,
+     bypass_explain_code=False,
+ ):
      """
      Creates a feature engineering agent that can be run on a dataset. The agent applies various feature engineering
      techniques, such as encoding categorical variables, scaling numeric variables, creating interaction terms,
@@ -61,11 +71,17 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
      ----------
      model : langchain.llms.base.LLM
          The language model to use to generate code.
+     n_samples : int, optional
+         The number of data samples to use for generating the feature engineering code. Defaults to 30.
+         If you get an error due to maximum tokens, try reducing this number.
+         > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
      log : bool, optional
          Whether or not to log the code generated and any errors that occur.
          Defaults to False.
      log_path : str, optional
          The path to the directory where the log files should be stored. Defaults to "logs/".
+     file_name : str, optional
+         The name of the file to save the log to. Defaults to "feature_engineer.py".
      overwrite : bool, optional
          Whether or not to overwrite the log file if it already exists. If False, a unique file name will be created.
          Defaults to True.
@@ -102,7 +118,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =

      Returns
      -------
-     app : langchain.graphs.StateGraph
+     app : langchain.graphs.CompiledStateGraph
          The feature engineering agent as a state graph.
      """
      llm = model
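
A minimal sketch of the updated `make_feature_engineering_agent` signature, under the same assumptions as the data wrangling sketch above (model, data, and state keys are illustrative):

```python
from langchain_openai import ChatOpenAI
import pandas as pd

app = make_feature_engineering_agent(
    model=ChatOpenAI(model="gpt-4o-mini"),
    n_samples=15,                     # new: trim the data summary sent to the LLM
    file_name="feature_engineer.py",  # new: name of the logged code file
    log=True,
)

df = pd.DataFrame({"churn": [0, 1, 0], "plan": ["basic", "pro", "basic"]})
result = app.invoke({
    "user_instructions": "One-hot encode the categorical columns.",
    "target_variable": "churn",
    "data_raw": df.to_dict(),
    "max_retries": 3,
    "retry_count": 0,
})
```
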
@@ -135,7 +151,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
          Recommend a series of feature engineering steps based on the input data.
          These recommended steps will be appended to the user_instructions.
          """
-         print("---FEATURE ENGINEERING AGENT----")
+         print(format_agent_name(AGENT_NAME))
          print(" * RECOMMEND FEATURE ENGINEERING STEPS")

          # Prompt to get recommended steps from the LLM
@@ -182,6 +198,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =

              Avoid these:
              1. Do not include steps to save files.
+             2. Do not include unrelated user instructions that are not related to the feature engineering.
              """,
              input_variables=["user_instructions", "recommended_steps", "all_datasets_summary"]
          )
@@ -189,7 +206,7 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
          data_raw = state.get("data_raw")
          df = pd.DataFrame.from_dict(data_raw)

-         all_datasets_summary = get_dataframe_summary([df])
+         all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples)

          all_datasets_summary_str = "\n\n".join(all_datasets_summary)

@@ -217,7 +234,18 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =

      def create_feature_engineering_code(state: GraphState):
          if bypass_recommended_steps:
-             print("---FEATURE ENGINEERING AGENT----")
+             print(format_agent_name(AGENT_NAME))
+
+             data_raw = state.get("data_raw")
+             df = pd.DataFrame.from_dict(data_raw)
+
+             all_datasets_summary = get_dataframe_summary([df], n_sample=n_samples)
+
+             all_datasets_summary_str = "\n\n".join(all_datasets_summary)
+
+         else:
+             all_datasets_summary_str = state.get("all_datasets_summary")
+
          print(" * CREATE FEATURE ENGINEERING CODE")

          feature_engineering_prompt = PromptTemplate(
@@ -272,16 +300,16 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
          response = feature_engineering_agent.invoke({
              "recommended_steps": state.get("recommended_steps"),
              "target_variable": state.get("target_variable"),
-             "all_datasets_summary": state.get("all_datasets_summary"),
+             "all_datasets_summary": all_datasets_summary_str,
          })

          response = relocate_imports_inside_function(response)
          response = add_comments_to_top(response, agent_name=AGENT_NAME)

          # For logging: store the code generated
-         file_path, file_name = log_ai_function(
+         file_path, file_name_2 = log_ai_function(
              response=response,
-             file_name="feature_engineer.py",
+             file_name=file_name,
              log=log,
              log_path=log_path,
              overwrite=overwrite
@@ -290,7 +318,8 @@ def make_feature_engineering_agent(model, log=False, log_path=None, overwrite =
          return {
              "feature_engineer_function": response,
              "feature_engineer_function_path": file_path,
-             "feature_engineer_function_name": file_name
+             "feature_engineer_function_name": file_name_2,
+             "all_datasets_summary": all_datasets_summary_str
          }


@@ -14,7 +14,7 @@ import io
  import pandas as pd
  import sqlalchemy as sql

- from ai_data_science_team.templates.agent_templates import(
+ from ai_data_science_team.templates import(
      node_func_execute_agent_from_sql_connection,
      node_func_human_review,
      node_func_fix_agent_code,
@@ -22,7 +22,7 @@ from ai_data_science_team.templates.agent_templates import(
      create_coding_agent_graph
  )
  from ai_data_science_team.tools.parsers import PythonOutputParser, SQLOutputParser
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
+ from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top, format_agent_name
  from ai_data_science_team.tools.metadata import get_database_metadata
  from ai_data_science_team.tools.logging import log_ai_function

@@ -31,7 +31,16 @@ AGENT_NAME = "sql_database_agent"
  LOG_PATH = os.path.join(os.getcwd(), "logs/")


- def make_sql_database_agent(model, connection, log=False, log_path=None, overwrite = True, human_in_the_loop=False, bypass_recommended_steps=False, bypass_explain_code=False):
+ def make_sql_database_agent(
+     model, connection,
+     n_samples = 10,
+     log=False,
+     log_path=None,
+     file_name="sql_database.py",
+     overwrite = True,
+     human_in_the_loop=False, bypass_recommended_steps=False,
+     bypass_explain_code=False
+ ):
      """
      Creates a SQL Database Agent that can recommend SQL steps and generate SQL code to query a database.

@@ -41,10 +50,16 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
          The language model to use for the agent.
      connection : sqlalchemy.engine.base.Engine
          The connection to the SQL database.
+     n_samples : int, optional
+         The number of samples to retrieve for each column, by default 10.
+         If you get an error due to maximum tokens, try reducing this number.
+         > "This model's maximum context length is 128000 tokens. However, your messages resulted in 333858 tokens. Please reduce the length of the messages."
      log : bool, optional
          Whether to log the generated code, by default False
      log_path : str, optional
          The path to the log directory, by default None
+     file_name : str, optional
+         The name of the file to save the generated code, by default "sql_database.py"
      overwrite : bool, optional
          Whether to overwrite the existing log file, by default True
      human_in_the_loop : bool, optional
@@ -56,7 +71,7 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri

      Returns
      -------
-     app : langchain.graphs.StateGraph
+     app : langchain.graphs.CompiledStateGraph
          The data cleaning agent as a state graph.

      Examples
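
A minimal sketch of the updated `make_sql_database_agent` signature, assuming a local SQLite database; the path, model, and state keys are illustrative:

```python
import sqlalchemy as sql
from langchain_openai import ChatOpenAI

engine = sql.create_engine("sqlite:///data/my_database.db")  # hypothetical path

app = make_sql_database_agent(
    model=ChatOpenAI(model="gpt-4o-mini"),
    connection=engine,
    n_samples=5,                  # new: sample values pulled per column for the metadata summary
    file_name="sql_database.py",  # new: name of the logged code file
)

result = app.invoke({
    "user_instructions": "List the 10 most recent orders.",
    "max_retries": 3,
    "retry_count": 0,
})
```
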
@@ -116,8 +131,8 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri

      def recommend_sql_steps(state: GraphState):

-         print("---SQL DATABASE AGENT---")
-         print(" * RECOMMEND SQL QUERY STEPS")
+         print(format_agent_name(AGENT_NAME))
+         print(" * RECOMMEND STEPS")


          # Prompt to get recommended steps from the LLM
@@ -156,6 +171,8 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
              2. Do not include steps to modify existing tables, create new tables or modify the database schema.
              3. Do not include steps that alter the existing data in the database.
              4. Make sure not to include unsafe code that could cause data loss or corruption or SQL injections.
+             5. Make sure to not include irrelevant steps that do not help in the SQL agent's data collection and processing. Examples include steps to create new tables, modify the schema, save files, create charts, etc.
+

              """,
              input_variables=["user_instructions", "recommended_steps", "all_sql_database_summary"]
@@ -166,7 +183,7 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
          conn = connection.connect() if is_engine else connection

          # Get the database metadata
-         all_sql_database_summary = get_database_metadata(conn, n_values=10)
+         all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)

          steps_agent = recommend_steps_prompt | llm

@@ -183,7 +200,7 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri

      def create_sql_query_code(state: GraphState):
          if bypass_recommended_steps:
-             print("---SQL DATABASE AGENT---")
+             print(format_agent_name(AGENT_NAME))
          print(" * CREATE SQL QUERY CODE")

          # Prompt to get the SQL code from the LLM
@@ -228,7 +245,7 @@ def make_sql_database_agent(model, connection, log=False, log_path=None, overwri
          conn = connection.connect() if is_engine else connection

          # Get the database metadata
-         all_sql_database_summary = get_database_metadata(conn, n_values=10)
+         all_sql_database_summary = get_database_metadata(conn, n_samples=n_samples)

          sql_query_code_agent = sql_query_code_prompt | llm | SQLOutputParser()

@@ -259,9 +276,9 @@ def sql_database_pipeline(connection):
          response = add_comments_to_top(response, AGENT_NAME)

          # For logging: store the code generated
-         file_path, file_name = log_ai_function(
+         file_path, file_name_2 = log_ai_function(
              response=response,
-             file_name="sql_database.py",
+             file_name=file_name,
              log=log,
              log_path=log_path,
              overwrite=overwrite
@@ -271,7 +288,8 @@ def sql_database_pipeline(connection):
              "sql_query_code": sql_query_code,
              "sql_database_function": response,
              "sql_database_function_path": file_path,
-             "sql_database_function_name": file_name
+             "sql_database_function_name": file_name_2,
+             "all_sql_database_summary": all_sql_database_summary
          }

      def human_review(state: GraphState) -> Command[Literal["recommend_sql_steps", "create_sql_query_code"]]:
@@ -0,0 +1,8 @@
+ from ai_data_science_team.templates.agent_templates import(
+     node_func_execute_agent_code_on_data,
+     node_func_human_review,
+     node_func_fix_agent_code,
+     node_func_explain_agent_code,
+     node_func_execute_agent_from_sql_connection,
+     create_coding_agent_graph
+ )
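
This new `templates/__init__.py` re-exports the node functions from `agent_templates`, which is what lets the agent modules above shorten their imports. Both paths now resolve to the same objects:

```python
# Equivalent imports after this release (sketch):
from ai_data_science_team.templates import create_coding_agent_graph
from ai_data_science_team.templates.agent_templates import create_coding_agent_graph
```
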
@@ -4,7 +4,9 @@ import sqlalchemy as sql
  from typing import Union, List, Dict

  def get_dataframe_summary(
-     dataframes: Union[pd.DataFrame, List[pd.DataFrame], Dict[str, pd.DataFrame]]
+     dataframes: Union[pd.DataFrame, List[pd.DataFrame], Dict[str, pd.DataFrame]],
+     n_sample: int = 30,
+     skip_stats: bool = False,
  ) -> List[str]:
      """
      Generate a summary for one or more DataFrames. Accepts a single DataFrame, a list of DataFrames,
@@ -16,6 +18,10 @@ def get_dataframe_summary(
          - Single DataFrame: produce a single summary (returned within a one-element list).
          - List of DataFrames: produce a summary for each DataFrame, using index-based names.
          - Dictionary of DataFrames: produce a summary for each DataFrame, using dictionary keys as names.
+     n_sample : int, default 30
+         Number of rows to display in the "Data (first 30 rows)" section.
+     skip_stats : bool, default False
+         If True, skip the descriptive statistics and DataFrame info sections.

      Example:
      --------
@@ -49,17 +55,17 @@
      # --- Dictionary Case ---
      if isinstance(dataframes, dict):
          for dataset_name, df in dataframes.items():
-             summaries.append(_summarize_dataframe(df, dataset_name))
+             summaries.append(_summarize_dataframe(df, dataset_name, n_sample, skip_stats))

      # --- Single DataFrame Case ---
      elif isinstance(dataframes, pd.DataFrame):
-         summaries.append(_summarize_dataframe(dataframes, "Single_Dataset"))
+         summaries.append(_summarize_dataframe(dataframes, "Single_Dataset", n_sample, skip_stats))

      # --- List of DataFrames Case ---
      elif isinstance(dataframes, list):
          for idx, df in enumerate(dataframes):
              dataset_name = f"Dataset_{idx}"
-             summaries.append(_summarize_dataframe(df, dataset_name))
+             summaries.append(_summarize_dataframe(df, dataset_name, n_sample, skip_stats))

      else:
          raise TypeError(
@@ -69,7 +75,7 @@
      return summaries


- def _summarize_dataframe(df: pd.DataFrame, dataset_name: str) -> str:
+ def _summarize_dataframe(df: pd.DataFrame, dataset_name: str, n_sample=30, skip_stats=False) -> str:
      """Generate a summary string for a single DataFrame."""
      # 1. Convert dictionary-type cells to strings
      #    This prevents unhashable dict errors during df.nunique().
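
A quick sketch of the new `get_dataframe_summary` parameters (the DataFrame is illustrative):

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

# Smaller sample, and skip describe()/info() to keep the summary short:
summaries = get_dataframe_summary({"sales": df}, n_sample=2, skip_stats=True)
print(summaries[0])  # dataset name, shape, dtypes, and the first 2 rows only
```
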
@@ -91,77 +97,134 @@ def _summarize_dataframe(df: pd.DataFrame, dataset_name: str) -> str:
      unique_counts = df.nunique() # Will no longer fail on unhashable dict
      unique_counts_summary = "\n".join([f"{col}: {count}" for col, count in unique_counts.items()])

-     summary_text = f"""
-     Dataset Name: {dataset_name}
-     ----------------------------
-     Shape: {df.shape[0]} rows x {df.shape[1]} columns
+     # 6. Generate the summary text
+     if not skip_stats:
+         summary_text = f"""
+         Dataset Name: {dataset_name}
+         ----------------------------
+         Shape: {df.shape[0]} rows x {df.shape[1]} columns

-     Column Data Types:
-     {column_types}
+         Column Data Types:
+         {column_types}

-     Missing Value Percentage:
-     {missing_summary}
+         Missing Value Percentage:
+         {missing_summary}

-     Unique Value Counts:
-     {unique_counts_summary}
+         Unique Value Counts:
+         {unique_counts_summary}

-     Data (first 30 rows):
-     {df.head(30).to_string()}
+         Data (first {n_sample} rows):
+         {df.head(n_sample).to_string()}

-     Data Description:
-     {df.describe().to_string()}
+         Data Description:
+         {df.describe().to_string()}

-     Data Info:
-     {info_text}
-     """
+         Data Info:
+         {info_text}
+         """
+     else:
+         summary_text = f"""
+         Dataset Name: {dataset_name}
+         ----------------------------
+         Shape: {df.shape[0]} rows x {df.shape[1]} columns
+
+         Column Data Types:
+         {column_types}
+
+         Data (first {n_sample} rows):
+         {df.head(n_sample).to_string()}
+         """
+
      return summary_text.strip()


- def get_database_metadata(connection: Union[sql.engine.base.Connection, sql.engine.base.Engine], n_values: int=10):
-     """
-     Collects metadata and sample data from a database.

-     Parameters:
-     -----------
-     connection (sqlalchemy.engine.base.Connection or sqlalchemy.engine.base.Engine):
+ def get_database_metadata(connection: Union[sql.engine.base.Connection, sql.engine.base.Engine],
+                           n_samples: int = 10) -> str:
+     """
+     Collects metadata and sample data from a database, with safe identifier quoting and
+     basic dialect-aware row limiting. Prevents issues with spaces/reserved words in identifiers.
+
+     Parameters
+     ----------
+     connection : Union[sql.engine.base.Connection, sql.engine.base.Engine]
          An active SQLAlchemy connection or engine.
-     n_values (int):
+     n_samples : int
          Number of sample values to retrieve for each column.

-     Returns:
-     --------
-     str: Formatted text with database metadata.
+     Returns
+     -------
+     str
+         A formatted string with database metadata, including some sample data from each column.
      """
+
      # If a connection is passed, use it; if an engine is passed, connect to it
      is_engine = isinstance(connection, sql.engine.base.Engine)
      conn = connection.connect() if is_engine else connection
-     output = []

+     output = []
      try:
-         # Engine metadata
+         # Grab the engine off the connection
          sql_engine = conn.engine
+         dialect_name = sql_engine.dialect.name.lower()
+
          output.append(f"Database Dialect: {sql_engine.dialect.name}")
          output.append(f"Driver: {sql_engine.driver}")
          output.append(f"Connection URL: {sql_engine.url}")
-
+
          # Inspect the database
          inspector = sql.inspect(sql_engine)
-         output.append(f"Tables: {inspector.get_table_names()}")
+         tables = inspector.get_table_names()
+         output.append(f"Tables: {tables}")
          output.append(f"Schemas: {inspector.get_schema_names()}")
-
-         # For each table, get the columns and their metadata
-         for table_name in inspector.get_table_names():
+
+         # Helper to build a dialect-specific limit clause
+         def build_query(col_name_quoted: str, table_name_quoted: str, n: int) -> str:
+             """
+             Returns a SQL query string to select N rows from the given column/table
+             across different dialects (SQLite, MySQL, Postgres, MSSQL, Oracle, etc.)
+             """
+             if "sqlite" in dialect_name or "mysql" in dialect_name or "postgres" in dialect_name:
+                 # Common dialects supporting LIMIT
+                 return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
+             elif "mssql" in dialect_name:
+                 # Microsoft SQL Server syntax
+                 return f"SELECT TOP {n} {col_name_quoted} FROM {table_name_quoted}"
+             elif "oracle" in dialect_name:
+                 # Oracle syntax
+                 return f"SELECT {col_name_quoted} FROM {table_name_quoted} WHERE ROWNUM <= {n}"
+             else:
+                 # Fallback
+                 return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
+
+         # Prepare for quoting
+         preparer = inspector.bind.dialect.identifier_preparer
+
+         # For each table, get columns and sample data
+         for table_name in tables:
              output.append(f"\nTable: {table_name}")
+             # Properly quote the table name
+             table_name_quoted = preparer.quote_identifier(table_name)
+
              for column in inspector.get_columns(table_name):
-                 output.append(f" Column: {column['name']} Type: {column['type']}")
-                 # Fetch sample values for the column
-                 query = f"SELECT {column['name']} FROM {table_name} LIMIT {n_values}"
-                 data = pd.read_sql(query, sql_engine)
-                 output.append(f" First {n_values} Values: {data.values.flatten().tolist()}")
+                 col_name = column["name"]
+                 col_type = column["type"]
+                 output.append(f" Column: {col_name} Type: {col_type}")
+
+                 # Properly quote the column name
+                 col_name_quoted = preparer.quote_identifier(col_name)
+
+                 # Build a dialect-aware query with safe quoting
+                 query = build_query(col_name_quoted, table_name_quoted, n_samples)
+
+                 # Read a few sample values
+                 df = pd.read_sql(sql.text(query), conn)
+                 first_values = df[col_name].tolist()
+                 output.append(f" First {n_samples} Values: {first_values}")
+
      finally:
-         # Close connection if it was created inside this function
+         # Close connection if created inside the function
          if is_engine:
              conn.close()
-
-     # Join all collected information into a single string
+
      return "\n".join(output)
@@ -71,3 +71,9 @@ def add_comments_to_top(code_text, agent_name="data_wrangler"):
      # Join the header with newlines, then prepend to the existing code_text
      header_block = "\n".join(header_comments)
      return header_block + code_text
+
+ def format_agent_name(agent_name: str) -> str:
+
+     formatted_name = agent_name.strip().replace("_", " ").upper()
+
+     return f"---{formatted_name}----"