ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (25) hide show
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/__init__.py +4 -5
  3. ai_data_science_team/agents/data_cleaning_agent.py +268 -116
  4. ai_data_science_team/agents/data_visualization_agent.py +470 -41
  5. ai_data_science_team/agents/data_wrangling_agent.py +471 -31
  6. ai_data_science_team/agents/feature_engineering_agent.py +426 -41
  7. ai_data_science_team/agents/sql_database_agent.py +458 -58
  8. ai_data_science_team/ml_agents/__init__.py +1 -0
  9. ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
  10. ai_data_science_team/multiagents/__init__.py +1 -0
  11. ai_data_science_team/multiagents/sql_data_analyst.py +398 -0
  12. ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
  13. ai_data_science_team/templates/__init__.py +3 -1
  14. ai_data_science_team/templates/agent_templates.py +319 -43
  15. ai_data_science_team/tools/metadata.py +94 -62
  16. ai_data_science_team/tools/regex.py +86 -1
  17. ai_data_science_team/utils/__init__.py +0 -0
  18. ai_data_science_team/utils/plotly.py +24 -0
  19. ai_data_science_team-0.0.0.9009.dist-info/METADATA +245 -0
  20. ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
  21. ai_data_science_team-0.0.0.9007.dist-info/METADATA +0 -183
  22. ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
  23. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
  24. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
  25. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,225 @@
1
1
  from langchain_core.messages import AIMessage
2
2
  from langgraph.graph import StateGraph, END
3
3
  from langgraph.types import interrupt, Command
4
+ from langgraph.graph.state import CompiledStateGraph
5
+
6
+ from langchain_core.runnables import RunnableConfig
7
+ from langgraph.pregel.types import StreamMode
4
8
 
5
9
  import pandas as pd
6
10
  import sqlalchemy as sql
11
+ import json
7
12
 
8
- from typing import Any, Callable, Dict, Type, Optional
13
+ from typing import Any, Callable, Dict, Type, Optional, Union, List
9
14
 
10
15
  from ai_data_science_team.tools.parsers import PythonOutputParser
11
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
16
+ from ai_data_science_team.tools.regex import (
17
+ relocate_imports_inside_function,
18
+ add_comments_to_top,
19
+ remove_consecutive_duplicates
20
+ )
21
+
22
+ from IPython.display import Image, display
23
+ import pandas as pd
24
+
25
class BaseAgent(CompiledStateGraph):
    """
    A generic base class for agents that interact with compiled state graphs.

    Provides shared functionality for handling parameters, responses, and state
    graph operations. The concrete graph is produced by `_make_compiled_graph()`,
    stored on `self._compiled_graph`, and every run is forwarded to it.
    """

    def __init__(self, **params):
        """
        Initialize the agent with provided parameters.

        Parameters:
            **params: Arbitrary keyword arguments representing the agent's parameters.
        """
        # NOTE(review): the parent initializer is not called here; the wrapped
        # graph is reached through delegation instead. Confirm this is intended.
        self._params = params
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Subclasses should override this method to create a specific compiled graph.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("Subclasses must implement the `_make_compiled_graph` method.")

    def update_params(self, **kwargs):
        """
        Update one or more parameters and rebuild the compiled graph.

        Parameters:
            **kwargs: Parameters to update.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    def __getattr__(self, name: str):
        """
        Delegate attribute access to the compiled graph if the attribute is not found.

        Parameters:
            name (str): The attribute name.

        Returns:
            Any: The attribute from the compiled graph.

        Raises:
            AttributeError: If the compiled graph itself has not been set yet.
        """
        # `__getattr__` only runs after normal lookup failed. If the missing name
        # is `_compiled_graph` itself, delegating would look it up again and
        # recurse without bound, so fail fast instead.
        if name == "_compiled_graph":
            raise AttributeError(name)
        return getattr(self._compiled_graph, name)

    def _dedupe_messages(self, response):
        """
        Collapse consecutive duplicate messages in a dict-shaped response.

        Responses that are not dicts (including None) are returned untouched.
        """
        if isinstance(response, dict) and response.get("messages"):
            response["messages"] = remove_consecutive_duplicates(response["messages"])
        return response

    def invoke(
        self,
        input: Union[Dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.invoke()

        Parameters:
            input: The input data for the graph. It can be a dictionary or any other type.
            config: Optional. The configuration for the graph run.
            **kwargs: Arguments to pass to self._compiled_graph.invoke()

        Returns:
            Any: The agent's response.
        """
        result = self._compiled_graph.invoke(input=input, config=config, **kwargs)
        self.response = self._dedupe_messages(result)
        return self.response

    async def ainvoke(
        self,
        input: Union[Dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.ainvoke()

        Parameters:
            input: The input data for the graph. It can be a dictionary or any other type.
            config: Optional. The configuration for the graph run.
            **kwargs: Arguments to pass to self._compiled_graph.ainvoke()

        Returns:
            Any: The agent's response.
        """
        # The wrapped call returns an awaitable; it must be awaited before the
        # result can be inspected for messages.
        result = await self._compiled_graph.ainvoke(input=input, config=config, **kwargs)
        self.response = self._dedupe_messages(result)
        return self.response

    def stream(
        self,
        input: Union[Dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        stream_mode: Optional[Union[StreamMode, List[StreamMode]]] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.stream()

        Parameters:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode to stream output, defaults to self.stream_mode.
                Options are 'values', 'updates', and 'debug'.
                values: Emit the current values of the state for each step.
                updates: Emit only the updates to the state for each step.
                    Output is a dict with the node name as key and the updated values as value.
                debug: Emit debug events for each step.
            **kwargs: Arguments to pass to self._compiled_graph.stream()

        Returns:
            The stream object produced by the compiled graph, unmodified.
        """
        # The result is a stream, not a response dict, so it is handed back
        # as-is; `self.response` is left unchanged by streaming calls.
        return self._compiled_graph.stream(
            input=input, config=config, stream_mode=stream_mode, **kwargs
        )

    def astream(
        self,
        input: Union[Dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        stream_mode: Optional[Union[StreamMode, List[StreamMode]]] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.astream()

        Parameters:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode to stream output, defaults to self.stream_mode.
                Options are 'values', 'updates', and 'debug'.
                values: Emit the current values of the state for each step.
                updates: Emit only the updates to the state for each step.
                    Output is a dict with the node name as key and the updated values as value.
                debug: Emit debug events for each step.
            **kwargs: Arguments to pass to self._compiled_graph.astream()

        Returns:
            The asynchronous stream object produced by the compiled graph, unmodified.
        """
        # Same reasoning as `stream`: the result is a stream, not a response
        # dict, and `self.response` is left unchanged.
        return self._compiled_graph.astream(
            input=input, config=config, stream_mode=stream_mode, **kwargs
        )

    def get_state_keys(self):
        """
        Returns a list of keys that the state graph response contains.

        Returns:
            list: A list of keys in the response.
        """
        return list(self.get_state_properties().keys())

    def get_state_properties(self):
        """
        Returns detailed properties of the state graph response.

        Returns:
            dict: The properties of the response.
        """
        return self.get_output_jsonschema()['properties']

    def get_response(self):
        """
        Returns the response generated by the agent.

        Returns:
            Any: The agent's response, or None if the agent has not been run yet.
        """
        # `_dedupe_messages` tolerates None, so calling this before any run
        # no longer raises.
        self.response = self._dedupe_messages(self.response)
        return self.response

    def show(self, xray: int = 0):
        """
        Displays the agent's state graph as a Mermaid diagram.

        Parameters:
            xray (int): If set to 1, displays subgraph levels. Defaults to 0.
        """
        display(Image(self.get_graph(xray=xray).draw_mermaid_png()))
220
+
221
+
222
+
12
223
 
13
224
  def create_coding_agent_graph(
14
225
  GraphState: Type,
@@ -79,35 +290,37 @@ def create_coding_agent_graph(
79
290
 
80
291
  workflow = StateGraph(GraphState)
81
292
 
82
- # Conditionally add the recommended-steps node
83
- if not bypass_recommended_steps:
84
- workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
293
+ # * NODES
85
294
 
86
295
  # Always add create, execute, and fix nodes
87
296
  workflow.add_node(create_code_node_name, node_functions[create_code_node_name])
88
297
  workflow.add_node(execute_code_node_name, node_functions[execute_code_node_name])
89
298
  workflow.add_node(fix_code_node_name, node_functions[fix_code_node_name])
90
299
 
300
+ # Conditionally add the recommended-steps node
301
+ if not bypass_recommended_steps:
302
+ workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
303
+
304
+ # Conditionally add the human review node
305
+ if human_in_the_loop:
306
+ workflow.add_node(human_review_node_name, node_functions[human_review_node_name])
307
+
91
308
  # Conditionally add the explanation node
92
309
  if not bypass_explain_code:
93
310
  workflow.add_node(explain_code_node_name, node_functions[explain_code_node_name])
94
311
 
312
+ # * EDGES
313
+
95
314
  # Set the entry point
96
315
  entry_point = create_code_node_name if bypass_recommended_steps else recommended_steps_node_name
316
+
97
317
  workflow.set_entry_point(entry_point)
98
318
 
99
- # Add edges for recommended steps
100
319
  if not bypass_recommended_steps:
101
- if human_in_the_loop:
102
- workflow.add_edge(recommended_steps_node_name, human_review_node_name)
103
- else:
104
- workflow.add_edge(recommended_steps_node_name, create_code_node_name)
105
- elif human_in_the_loop:
106
- # Skip recommended steps but still include human review
107
- workflow.add_edge(create_code_node_name, human_review_node_name)
320
+ workflow.add_edge(recommended_steps_node_name, create_code_node_name)
108
321
 
109
- # Create -> Execute
110
322
  workflow.add_edge(create_code_node_name, execute_code_node_name)
323
+ workflow.add_edge(fix_code_node_name, execute_code_node_name)
111
324
 
112
325
  # Define a helper to check if we have an error & can still retry
113
326
  def error_and_can_retry(state):
@@ -117,39 +330,43 @@ def create_coding_agent_graph(
117
330
  and state.get(max_retries_key) is not None
118
331
  and state[retry_count_key] < state[max_retries_key]
119
332
  )
120
-
121
- # ---- Split into two branches for bypass_explain_code ----
122
- if not bypass_explain_code:
123
- # If we are NOT bypassing explain, the next node is fix_code if error,
124
- # else explain_code. Then we wire explain_code -> END afterward.
333
+
334
+ # If human in the loop, add a branch for human review
335
+ if human_in_the_loop:
125
336
  workflow.add_conditional_edges(
126
337
  execute_code_node_name,
127
- lambda s: "fix_code" if error_and_can_retry(s) else "explain_code",
338
+ lambda s: "fix_code" if error_and_can_retry(s) else "human_review",
128
339
  {
340
+ "human_review": human_review_node_name,
129
341
  "fix_code": fix_code_node_name,
130
- "explain_code": explain_code_node_name,
131
342
  },
132
343
  )
133
- # Fix code -> Execute again
134
- workflow.add_edge(fix_code_node_name, execute_code_node_name)
135
- # explain_code -> END
136
- workflow.add_edge(explain_code_node_name, END)
137
344
  else:
138
- # If we ARE bypassing explain_code, the next node is fix_code if error,
139
- # else straight to END.
140
- workflow.add_conditional_edges(
141
- execute_code_node_name,
142
- lambda s: "fix_code" if error_and_can_retry(s) else "END",
143
- {
144
- "fix_code": fix_code_node_name,
145
- "END": END,
146
- },
147
- )
148
- # Fix code -> Execute again
149
- workflow.add_edge(fix_code_node_name, execute_code_node_name)
345
+ # If no human review, the next node is fix_code if error, else explain_code.
346
+ if not bypass_explain_code:
347
+ workflow.add_conditional_edges(
348
+ execute_code_node_name,
349
+ lambda s: "fix_code" if error_and_can_retry(s) else "explain_code",
350
+ {
351
+ "fix_code": fix_code_node_name,
352
+ "explain_code": explain_code_node_name,
353
+ },
354
+ )
355
+ else:
356
+ workflow.add_conditional_edges(
357
+ execute_code_node_name,
358
+ lambda s: "fix_code" if error_and_can_retry(s) else "END",
359
+ {
360
+ "fix_code": fix_code_node_name,
361
+ "END": END,
362
+ },
363
+ )
364
+
365
+ if not bypass_explain_code:
366
+ workflow.add_edge(explain_code_node_name, END)
150
367
 
151
368
  # Finally, compile
152
- if human_in_the_loop and checkpointer is not None:
369
+ if human_in_the_loop:
153
370
  app = workflow.compile(checkpointer=checkpointer)
154
371
  else:
155
372
  app = workflow.compile()
@@ -165,6 +382,8 @@ def node_func_human_review(
165
382
  no_goto: str,
166
383
  user_instructions_key: str = "user_instructions",
167
384
  recommended_steps_key: str = "recommended_steps",
385
+ code_snippet_key: str = "code_snippet",
386
+ code_type: str = "python"
168
387
  ) -> Command[str]:
169
388
  """
170
389
  A generic function to handle human review steps.
@@ -183,6 +402,10 @@ def node_func_human_review(
183
402
  The key in the state to store user instructions.
184
403
  recommended_steps_key : str, optional
185
404
  The key in the state to store recommended steps.
405
+ code_snippet_key : str, optional
406
+ The key in the state to store the code snippet.
407
+ code_type : str, optional
408
+ The type of code snippet to display (e.g., "python").
186
409
 
187
410
  Returns
188
411
  -------
@@ -190,9 +413,11 @@ def node_func_human_review(
190
413
  A Command object directing the next state and updates to the state.
191
414
  """
192
415
  print(" * HUMAN REVIEW")
416
+
417
+ code_markdown=f"```{code_type}\n" + state.get(code_snippet_key)+"\n```"
193
418
 
194
419
  # Display instructions and get user response
195
- user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '')))
420
+ user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '') + "\n\n" + code_markdown))
196
421
 
197
422
  # Decide next steps based on user input
198
423
  if user_input.strip().lower() == "yes":
@@ -200,11 +425,11 @@ def node_func_human_review(
200
425
  update = {}
201
426
  else:
202
427
  goto = no_goto
203
- modifications = "Modifications: \n" + user_input
428
+ modifications = "User Has Requested Modifications To Previous Code: \n" + user_input
204
429
  if state.get(user_instructions_key) is None:
205
- update = {user_instructions_key: modifications}
430
+ update = {user_instructions_key: modifications + "\n\nPrevious Code:\n" + code_markdown}
206
431
  else:
207
- update = {user_instructions_key: state.get(user_instructions_key) + modifications}
432
+ update = {user_instructions_key: state.get(user_instructions_key) + modifications + "\n\nPrevious Code:\n" + code_markdown}
208
433
 
209
434
  return Command(goto=goto, update=update)
210
435
 
@@ -394,6 +619,7 @@ def node_func_fix_agent_code(
394
619
  retry_count_key: str = "retry_count",
395
620
  log: bool = False,
396
621
  file_path: str = "logs/agent_function.py",
622
+ function_name: str = "agent_function"
397
623
  ) -> dict:
398
624
  """
399
625
  Generic function to fix a given piece of agent code using an LLM and a prompt template.
@@ -420,6 +646,8 @@ def node_func_fix_agent_code(
420
646
  Whether to log the returned code to a file.
421
647
  file_path : str, optional
422
648
  The path to the file where the code will be logged.
649
+ function_name : str, optional
650
+ The name of the function in the code snippet that will be fixed.
423
651
 
424
652
  Returns
425
653
  -------
@@ -436,7 +664,8 @@ def node_func_fix_agent_code(
436
664
  # Format the prompt with the code snippet and the error
437
665
  prompt = prompt_template.format(
438
666
  code_snippet=code_snippet,
439
- error=error_message
667
+ error=error_message,
668
+ function_name=function_name,
440
669
  )
441
670
 
442
671
  # Execute the prompt with the LLM
@@ -524,3 +753,50 @@ def node_func_explain_agent_code(
524
753
  # Return an error message if there was a problem with the code
525
754
  message = AIMessage(content=error_message)
526
755
  return {result_key: [message]}
756
+
757
+
758
+
759
def node_func_report_agent_outputs(
    state: Dict[str, Any],
    keys_to_include: List[str],
    result_key: str,
    role: str,
    custom_title: str = "Agent Output Summary"
) -> Dict[str, Any]:
    """
    Gathers relevant data directly from the state (filtered by `keys_to_include`)
    and returns them as a structured message in `state[result_key]`.

    No LLM is used.

    Parameters
    ----------
    state : Dict[str, Any]
        The current state dictionary holding all agent variables.
    keys_to_include : List[str]
        The list of keys in `state` to include in the output.
    result_key : str
        The key in `state` under which we'll store the final structured message.
    role : str
        The role that will be used in the final AIMessage (e.g., "DataCleaningAgent").
    custom_title : str, optional
        A title or heading for your report. Defaults to "Agent Output Summary".

    Returns
    -------
    Dict[str, Any]
        A single-entry dict mapping `result_key` to a one-element list holding an
        AIMessage whose content is the JSON-serialized report.
    """
    print("    * REPORT AGENT OUTPUTS")

    final_report = {"report_title": custom_title}

    # Keys missing from the state get a visible placeholder rather than being dropped.
    for key in keys_to_include:
        final_report[key] = state.get(key, f"<{key}_not_found_in_state>")

    # State values are arbitrary objects; `default=str` keeps one value that is
    # not JSON-native from aborting the whole report with a TypeError.
    report_json = json.dumps(final_report, indent=2, default=str)

    # Wrap it in a list of messages (like the current "messages" pattern).
    return {
        result_key: [
            AIMessage(
                content=report_json,
                role=role
            )
        ]
    }
802
+
@@ -1,6 +1,7 @@
1
1
  import io
2
2
  import pandas as pd
3
3
  import sqlalchemy as sql
4
+ from sqlalchemy import inspect
4
5
  from typing import Union, List, Dict
5
6
 
6
7
  def get_dataframe_summary(
@@ -139,8 +140,7 @@ def _summarize_dataframe(df: pd.DataFrame, dataset_name: str, n_sample=30, skip_
139
140
 
140
141
 
141
142
 
142
- def get_database_metadata(connection: Union[sql.engine.base.Connection, sql.engine.base.Engine],
143
- n_samples: int = 10) -> str:
143
+ def get_database_metadata(connection, n_samples=10) -> dict:
144
144
  """
145
145
  Collects metadata and sample data from a database, with safe identifier quoting and
146
146
  basic dialect-aware row limiting. Prevents issues with spaces/reserved words in identifiers.
@@ -154,77 +154,109 @@ def get_database_metadata(connection: Union[sql.engine.base.Connection, sql.engi
154
154
 
155
155
  Returns
156
156
  -------
157
- str
158
- A formatted string with database metadata, including some sample data from each column.
157
+ dict
158
+ A dictionary with database metadata, including some sample data from each column.
159
159
  """
160
-
161
- # If a connection is passed, use it; if an engine is passed, connect to it
162
160
  is_engine = isinstance(connection, sql.engine.base.Engine)
163
161
  conn = connection.connect() if is_engine else connection
164
162
 
165
- output = []
163
+ metadata = {
164
+ "dialect": None,
165
+ "driver": None,
166
+ "connection_url": None,
167
+ "schemas": [],
168
+ }
169
+
166
170
  try:
167
- # Grab the engine off the connection
168
171
  sql_engine = conn.engine
169
172
  dialect_name = sql_engine.dialect.name.lower()
170
173
 
171
- output.append(f"Database Dialect: {sql_engine.dialect.name}")
172
- output.append(f"Driver: {sql_engine.driver}")
173
- output.append(f"Connection URL: {sql_engine.url}")
174
-
175
- # Inspect the database
176
- inspector = sql.inspect(sql_engine)
177
- tables = inspector.get_table_names()
178
- output.append(f"Tables: {tables}")
179
- output.append(f"Schemas: {inspector.get_schema_names()}")
180
-
181
- # Helper to build a dialect-specific limit clause
182
- def build_query(col_name_quoted: str, table_name_quoted: str, n: int) -> str:
183
- """
184
- Returns a SQL query string to select N rows from the given column/table
185
- across different dialects (SQLite, MySQL, Postgres, MSSQL, Oracle, etc.)
186
- """
187
- if "sqlite" in dialect_name or "mysql" in dialect_name or "postgres" in dialect_name:
188
- # Common dialects supporting LIMIT
189
- return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
190
- elif "mssql" in dialect_name:
191
- # Microsoft SQL Server syntax
192
- return f"SELECT TOP {n} {col_name_quoted} FROM {table_name_quoted}"
193
- elif "oracle" in dialect_name:
194
- # Oracle syntax
195
- return f"SELECT {col_name_quoted} FROM {table_name_quoted} WHERE ROWNUM <= {n}"
196
- else:
197
- # Fallback
198
- return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
199
-
200
- # Prepare for quoting
201
- preparer = inspector.bind.dialect.identifier_preparer
202
-
203
- # For each table, get columns and sample data
204
- for table_name in tables:
205
- output.append(f"\nTable: {table_name}")
206
- # Properly quote the table name
207
- table_name_quoted = preparer.quote_identifier(table_name)
208
-
209
- for column in inspector.get_columns(table_name):
210
- col_name = column["name"]
211
- col_type = column["type"]
212
- output.append(f" Column: {col_name} Type: {col_type}")
174
+ metadata["dialect"] = sql_engine.dialect.name
175
+ metadata["driver"] = sql_engine.driver
176
+ metadata["connection_url"] = str(sql_engine.url)
213
177
 
214
- # Properly quote the column name
215
- col_name_quoted = preparer.quote_identifier(col_name)
216
-
217
- # Build a dialect-aware query with safe quoting
218
- query = build_query(col_name_quoted, table_name_quoted, n_samples)
219
-
220
- # Read a few sample values
221
- df = pd.read_sql(sql.text(query), conn)
222
- first_values = df[col_name].tolist()
223
- output.append(f" First {n_samples} Values: {first_values}")
178
+ inspector = inspect(sql_engine)
179
+ preparer = inspector.bind.dialect.identifier_preparer
224
180
 
181
+ # For each schema
182
+ for schema_name in inspector.get_schema_names():
183
+ schema_obj = {
184
+ "schema_name": schema_name,
185
+ "tables": []
186
+ }
187
+
188
+ tables = inspector.get_table_names(schema=schema_name)
189
+ for table_name in tables:
190
+ table_info = {
191
+ "table_name": table_name,
192
+ "columns": [],
193
+ "primary_key": [],
194
+ "foreign_keys": [],
195
+ "indexes": []
196
+ }
197
+ # Get columns
198
+ columns = inspector.get_columns(table_name, schema=schema_name)
199
+ for col in columns:
200
+ col_name = col["name"]
201
+ col_type = str(col["type"])
202
+ table_name_quoted = f"{preparer.quote_identifier(schema_name)}.{preparer.quote_identifier(table_name)}"
203
+ col_name_quoted = preparer.quote_identifier(col_name)
204
+
205
+ # Build query for sample data
206
+ query = build_query(col_name_quoted, table_name_quoted, n_samples, dialect_name)
207
+
208
+ # Retrieve sample data
209
+ try:
210
+ df = pd.read_sql(query, conn)
211
+ samples = df[col_name].head(n_samples).tolist()
212
+ except Exception as e:
213
+ samples = [f"Error retrieving data: {str(e)}"]
214
+
215
+ table_info["columns"].append({
216
+ "name": col_name,
217
+ "type": col_type,
218
+ "sample_values": samples
219
+ })
220
+
221
+ # Primary keys
222
+ pk_constraint = inspector.get_pk_constraint(table_name, schema=schema_name)
223
+ table_info["primary_key"] = pk_constraint.get("constrained_columns", [])
224
+
225
+ # Foreign keys
226
+ fks = inspector.get_foreign_keys(table_name, schema=schema_name)
227
+ table_info["foreign_keys"] = [
228
+ {
229
+ "local_cols": fk["constrained_columns"],
230
+ "referred_table": fk["referred_table"],
231
+ "referred_cols": fk["referred_columns"]
232
+ }
233
+ for fk in fks
234
+ ]
235
+
236
+ # Indexes
237
+ idxs = inspector.get_indexes(table_name, schema=schema_name)
238
+ table_info["indexes"] = idxs
239
+
240
+ schema_obj["tables"].append(table_info)
241
+
242
+ metadata["schemas"].append(schema_obj)
243
+
225
244
  finally:
226
- # Close connection if created inside the function
227
245
  if is_engine:
228
246
  conn.close()
229
247
 
230
- return "\n".join(output)
248
+ return metadata
249
+
250
def build_query(col_name_quoted: str, table_name_quoted: str, n: int, dialect_name: str) -> str:
    """
    Build a dialect-aware SQL query that selects up to `n` sample values of one column.

    Parameters
    ----------
    col_name_quoted : str
        The column identifier, already quoted for the target dialect.
    table_name_quoted : str
        The (optionally schema-qualified) table identifier, already quoted.
    n : int
        Maximum number of rows to return.
    dialect_name : str
        Name of the SQL dialect; matched case-insensitively by substring.

    Returns
    -------
    str
        The SQL text. PostgreSQL, SQLite, MySQL and MSSQL queries order rows
        randomly; Oracle uses ROWNUM; any other dialect gets a plain LIMIT.
    """
    # Coerce so only an integer literal is ever interpolated into the SQL text.
    limit = int(n)
    dialect = (dialect_name or "").lower()
    select = f"SELECT {col_name_quoted} FROM {table_name_quoted}"

    if "postgres" in dialect or "sqlite" in dialect:
        return f"{select} ORDER BY RANDOM() LIMIT {limit}"
    if "mysql" in dialect:
        return f"{select} ORDER BY RAND() LIMIT {limit}"
    if "mssql" in dialect:
        return f"SELECT TOP {limit} {col_name_quoted} FROM {table_name_quoted} ORDER BY NEWID()"
    if "oracle" in dialect:
        return f"{select} WHERE ROWNUM <= {limit}"
    # ROWNUM is Oracle-specific; unrecognised dialects get LIMIT instead.
    return f"{select} LIMIT {limit}"
262
+