ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25):
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/__init__.py +4 -5
  3. ai_data_science_team/agents/data_cleaning_agent.py +268 -116
  4. ai_data_science_team/agents/data_visualization_agent.py +470 -41
  5. ai_data_science_team/agents/data_wrangling_agent.py +471 -31
  6. ai_data_science_team/agents/feature_engineering_agent.py +426 -41
  7. ai_data_science_team/agents/sql_database_agent.py +458 -58
  8. ai_data_science_team/ml_agents/__init__.py +1 -0
  9. ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
  10. ai_data_science_team/multiagents/__init__.py +1 -0
  11. ai_data_science_team/multiagents/sql_data_analyst.py +398 -0
  12. ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
  13. ai_data_science_team/templates/__init__.py +3 -1
  14. ai_data_science_team/templates/agent_templates.py +319 -43
  15. ai_data_science_team/tools/metadata.py +94 -62
  16. ai_data_science_team/tools/regex.py +86 -1
  17. ai_data_science_team/utils/__init__.py +0 -0
  18. ai_data_science_team/utils/plotly.py +24 -0
  19. ai_data_science_team-0.0.0.9009.dist-info/METADATA +245 -0
  20. ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
  21. ai_data_science_team-0.0.0.9007.dist-info/METADATA +0 -183
  22. ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
  23. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
  24. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
  25. {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,225 @@
1
1
  from langchain_core.messages import AIMessage
2
2
  from langgraph.graph import StateGraph, END
3
3
  from langgraph.types import interrupt, Command
4
+ from langgraph.graph.state import CompiledStateGraph
5
+
6
+ from langchain_core.runnables import RunnableConfig
7
+ from langgraph.pregel.types import StreamMode
4
8
 
5
9
  import pandas as pd
6
10
  import sqlalchemy as sql
11
+ import json
7
12
 
8
- from typing import Any, Callable, Dict, Type, Optional
13
+ from typing import Any, Callable, Dict, Type, Optional, Union, List
9
14
 
10
15
  from ai_data_science_team.tools.parsers import PythonOutputParser
11
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
16
+ from ai_data_science_team.tools.regex import (
17
+ relocate_imports_inside_function,
18
+ add_comments_to_top,
19
+ remove_consecutive_duplicates
20
+ )
21
+
22
+ from IPython.display import Image, display
23
+ import pandas as pd
24
+
25
class BaseAgent(CompiledStateGraph):
    """
    A generic base class for agents built on LangGraph compiled state graphs.

    Stores the constructor parameters, builds the compiled graph through the
    subclass hook ``_make_compiled_graph()``, and exposes thin wrappers around
    the graph's invoke/stream APIs. The most recent result is cached on
    ``self.response``.
    """

    def __init__(self, **params):
        """
        Initialize the agent with provided parameters.

        Parameters:
            **params: Arbitrary keyword arguments representing the agent's parameters.
        """
        self._params = params
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Subclasses should override this method to create a specific compiled graph.
        """
        raise NotImplementedError("Subclasses must implement the `_make_compiled_graph` method.")

    def update_params(self, **kwargs):
        """
        Update one or more parameters and rebuild the compiled graph.

        Parameters:
            **kwargs: Parameters to update.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    def __getattr__(self, name: str):
        """
        Delegate attribute access to the compiled graph if the attribute is not found.

        Parameters:
            name (str): The attribute name.

        Returns:
            Any: The attribute from the compiled graph.
        """
        # Guard against infinite recursion when `_compiled_graph` itself is
        # missing (e.g. during unpickling or after a failed __init__).
        if name == "_compiled_graph":
            raise AttributeError(name)
        return getattr(self._compiled_graph, name)

    def _postprocess_response(self):
        """
        Remove consecutive duplicate messages from a dict-shaped cached response.

        No-op for streaming responses (iterators) or when no messages exist.
        """
        if isinstance(self.response, dict) and self.response.get("messages"):
            self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])

    def invoke(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.invoke()

        Parameters:
            input: The input data for the graph. It can be a dictionary or any other type.
            config: Optional. The configuration for the graph run.
            **kwargs: Arguments to pass to self._compiled_graph.invoke()

        Returns:
            Any: The agent's response.
        """
        self.response = self._compiled_graph.invoke(input=input, config=config, **kwargs)
        self._postprocess_response()
        return self.response

    async def ainvoke(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.ainvoke()

        Parameters:
            input: The input data for the graph. It can be a dictionary or any other type.
            config: Optional. The configuration for the graph run.
            **kwargs: Arguments to pass to self._compiled_graph.ainvoke()

        Returns:
            Any: The agent's response.
        """
        # Bug fix: `ainvoke` returns a coroutine, which must be awaited before
        # its dict result can be post-processed. The previous implementation
        # called `.get()` on the coroutine object and always crashed.
        self.response = await self._compiled_graph.ainvoke(input=input, config=config, **kwargs)
        self._postprocess_response()
        return self.response

    def stream(
        self,
        input: dict[str, Any] | Any,
        config: RunnableConfig | None = None,
        stream_mode: StreamMode | list[StreamMode] | None = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.stream()

        Parameters:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode to stream output, defaults to self.stream_mode.
                Options are 'values', 'updates', and 'debug'.
                values: Emit the current values of the state for each step.
                updates: Emit only the updates to the state for each step.
                    Output is a dict with the node name as key and the updated values as value.
                debug: Emit debug events for each step.
            **kwargs: Arguments to pass to self._compiled_graph.stream()

        Returns:
            Any: An iterator over the streamed chunks.
        """
        # Bug fix: `stream()` yields chunks lazily, so the result is an
        # iterator rather than a dict — message deduplication cannot be
        # applied here and calling `.get()` on it raised AttributeError.
        self.response = self._compiled_graph.stream(input=input, config=config, stream_mode=stream_mode, **kwargs)
        return self.response

    def astream(
        self,
        input: dict[str, Any] | Any,
        config: RunnableConfig | None = None,
        stream_mode: StreamMode | list[StreamMode] | None = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.astream()

        Parameters:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode to stream output, defaults to self.stream_mode.
                Options are 'values', 'updates', and 'debug'.
                values: Emit the current values of the state for each step.
                updates: Emit only the updates to the state for each step.
                    Output is a dict with the node name as key and the updated values as value.
                debug: Emit debug events for each step.
            **kwargs: Arguments to pass to self._compiled_graph.astream()

        Returns:
            Any: An async iterator over the streamed chunks.
        """
        # Bug fix: `astream()` returns an async generator — it has no `.get()`
        # method, so no post-processing is possible; hand it back directly.
        self.response = self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
        return self.response

    def get_state_keys(self):
        """
        Returns a list of keys that the state graph response contains.

        Returns:
            list: A list of keys in the response.
        """
        return list(self.get_output_jsonschema()['properties'].keys())

    def get_state_properties(self):
        """
        Returns detailed properties of the state graph response.

        Returns:
            dict: The properties of the response.
        """
        return self.get_output_jsonschema()['properties']

    def get_response(self):
        """
        Returns the response generated by the agent.

        Returns:
            Any: The agent's response.
        """
        self._postprocess_response()
        return self.response

    def show(self, xray: int = 0):
        """
        Displays the agent's state graph as a Mermaid diagram.

        Parameters:
            xray (int): If set to 1, displays subgraph levels. Defaults to 0.
        """
        display(Image(self.get_graph(xray=xray).draw_mermaid_png()))
220
+
221
+
222
+
12
223
 
13
224
  def create_coding_agent_graph(
14
225
  GraphState: Type,
@@ -79,35 +290,37 @@ def create_coding_agent_graph(
79
290
 
80
291
  workflow = StateGraph(GraphState)
81
292
 
82
- # Conditionally add the recommended-steps node
83
- if not bypass_recommended_steps:
84
- workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
293
+ # * NODES
85
294
 
86
295
  # Always add create, execute, and fix nodes
87
296
  workflow.add_node(create_code_node_name, node_functions[create_code_node_name])
88
297
  workflow.add_node(execute_code_node_name, node_functions[execute_code_node_name])
89
298
  workflow.add_node(fix_code_node_name, node_functions[fix_code_node_name])
90
299
 
300
+ # Conditionally add the recommended-steps node
301
+ if not bypass_recommended_steps:
302
+ workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
303
+
304
+ # Conditionally add the human review node
305
+ if human_in_the_loop:
306
+ workflow.add_node(human_review_node_name, node_functions[human_review_node_name])
307
+
91
308
  # Conditionally add the explanation node
92
309
  if not bypass_explain_code:
93
310
  workflow.add_node(explain_code_node_name, node_functions[explain_code_node_name])
94
311
 
312
+ # * EDGES
313
+
95
314
  # Set the entry point
96
315
  entry_point = create_code_node_name if bypass_recommended_steps else recommended_steps_node_name
316
+
97
317
  workflow.set_entry_point(entry_point)
98
318
 
99
- # Add edges for recommended steps
100
319
  if not bypass_recommended_steps:
101
- if human_in_the_loop:
102
- workflow.add_edge(recommended_steps_node_name, human_review_node_name)
103
- else:
104
- workflow.add_edge(recommended_steps_node_name, create_code_node_name)
105
- elif human_in_the_loop:
106
- # Skip recommended steps but still include human review
107
- workflow.add_edge(create_code_node_name, human_review_node_name)
320
+ workflow.add_edge(recommended_steps_node_name, create_code_node_name)
108
321
 
109
- # Create -> Execute
110
322
  workflow.add_edge(create_code_node_name, execute_code_node_name)
323
+ workflow.add_edge(fix_code_node_name, execute_code_node_name)
111
324
 
112
325
  # Define a helper to check if we have an error & can still retry
113
326
  def error_and_can_retry(state):
@@ -117,39 +330,43 @@ def create_coding_agent_graph(
117
330
  and state.get(max_retries_key) is not None
118
331
  and state[retry_count_key] < state[max_retries_key]
119
332
  )
120
-
121
- # ---- Split into two branches for bypass_explain_code ----
122
- if not bypass_explain_code:
123
- # If we are NOT bypassing explain, the next node is fix_code if error,
124
- # else explain_code. Then we wire explain_code -> END afterward.
333
+
334
+ # If human in the loop, add a branch for human review
335
+ if human_in_the_loop:
125
336
  workflow.add_conditional_edges(
126
337
  execute_code_node_name,
127
- lambda s: "fix_code" if error_and_can_retry(s) else "explain_code",
338
+ lambda s: "fix_code" if error_and_can_retry(s) else "human_review",
128
339
  {
340
+ "human_review": human_review_node_name,
129
341
  "fix_code": fix_code_node_name,
130
- "explain_code": explain_code_node_name,
131
342
  },
132
343
  )
133
- # Fix code -> Execute again
134
- workflow.add_edge(fix_code_node_name, execute_code_node_name)
135
- # explain_code -> END
136
- workflow.add_edge(explain_code_node_name, END)
137
344
  else:
138
- # If we ARE bypassing explain_code, the next node is fix_code if error,
139
- # else straight to END.
140
- workflow.add_conditional_edges(
141
- execute_code_node_name,
142
- lambda s: "fix_code" if error_and_can_retry(s) else "END",
143
- {
144
- "fix_code": fix_code_node_name,
145
- "END": END,
146
- },
147
- )
148
- # Fix code -> Execute again
149
- workflow.add_edge(fix_code_node_name, execute_code_node_name)
345
+ # If no human review, the next node is fix_code if error, else explain_code.
346
+ if not bypass_explain_code:
347
+ workflow.add_conditional_edges(
348
+ execute_code_node_name,
349
+ lambda s: "fix_code" if error_and_can_retry(s) else "explain_code",
350
+ {
351
+ "fix_code": fix_code_node_name,
352
+ "explain_code": explain_code_node_name,
353
+ },
354
+ )
355
+ else:
356
+ workflow.add_conditional_edges(
357
+ execute_code_node_name,
358
+ lambda s: "fix_code" if error_and_can_retry(s) else "END",
359
+ {
360
+ "fix_code": fix_code_node_name,
361
+ "END": END,
362
+ },
363
+ )
364
+
365
+ if not bypass_explain_code:
366
+ workflow.add_edge(explain_code_node_name, END)
150
367
 
151
368
  # Finally, compile
152
- if human_in_the_loop and checkpointer is not None:
369
+ if human_in_the_loop:
153
370
  app = workflow.compile(checkpointer=checkpointer)
154
371
  else:
155
372
  app = workflow.compile()
@@ -165,6 +382,8 @@ def node_func_human_review(
165
382
  no_goto: str,
166
383
  user_instructions_key: str = "user_instructions",
167
384
  recommended_steps_key: str = "recommended_steps",
385
+ code_snippet_key: str = "code_snippet",
386
+ code_type: str = "python"
168
387
  ) -> Command[str]:
169
388
  """
170
389
  A generic function to handle human review steps.
@@ -183,6 +402,10 @@ def node_func_human_review(
183
402
  The key in the state to store user instructions.
184
403
  recommended_steps_key : str, optional
185
404
  The key in the state to store recommended steps.
405
+ code_snippet_key : str, optional
406
+ The key in the state to store the code snippet.
407
+ code_type : str, optional
408
+ The type of code snippet to display (e.g., "python").
186
409
 
187
410
  Returns
188
411
  -------
@@ -190,9 +413,11 @@ def node_func_human_review(
190
413
  A Command object directing the next state and updates to the state.
191
414
  """
192
415
  print(" * HUMAN REVIEW")
416
+
417
+ code_markdown=f"```{code_type}\n" + state.get(code_snippet_key)+"\n```"
193
418
 
194
419
  # Display instructions and get user response
195
- user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '')))
420
+ user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '') + "\n\n" + code_markdown))
196
421
 
197
422
  # Decide next steps based on user input
198
423
  if user_input.strip().lower() == "yes":
@@ -200,11 +425,11 @@ def node_func_human_review(
200
425
  update = {}
201
426
  else:
202
427
  goto = no_goto
203
- modifications = "Modifications: \n" + user_input
428
+ modifications = "User Has Requested Modifications To Previous Code: \n" + user_input
204
429
  if state.get(user_instructions_key) is None:
205
- update = {user_instructions_key: modifications}
430
+ update = {user_instructions_key: modifications + "\n\nPrevious Code:\n" + code_markdown}
206
431
  else:
207
- update = {user_instructions_key: state.get(user_instructions_key) + modifications}
432
+ update = {user_instructions_key: state.get(user_instructions_key) + modifications + "\n\nPrevious Code:\n" + code_markdown}
208
433
 
209
434
  return Command(goto=goto, update=update)
210
435
 
@@ -394,6 +619,7 @@ def node_func_fix_agent_code(
394
619
  retry_count_key: str = "retry_count",
395
620
  log: bool = False,
396
621
  file_path: str = "logs/agent_function.py",
622
+ function_name: str = "agent_function"
397
623
  ) -> dict:
398
624
  """
399
625
  Generic function to fix a given piece of agent code using an LLM and a prompt template.
@@ -420,6 +646,8 @@ def node_func_fix_agent_code(
420
646
  Whether to log the returned code to a file.
421
647
  file_path : str, optional
422
648
  The path to the file where the code will be logged.
649
+ function_name : str, optional
650
+ The name of the function in the code snippet that will be fixed.
423
651
 
424
652
  Returns
425
653
  -------
@@ -436,7 +664,8 @@ def node_func_fix_agent_code(
436
664
  # Format the prompt with the code snippet and the error
437
665
  prompt = prompt_template.format(
438
666
  code_snippet=code_snippet,
439
- error=error_message
667
+ error=error_message,
668
+ function_name=function_name,
440
669
  )
441
670
 
442
671
  # Execute the prompt with the LLM
@@ -524,3 +753,50 @@ def node_func_explain_agent_code(
524
753
  # Return an error message if there was a problem with the code
525
754
  message = AIMessage(content=error_message)
526
755
  return {result_key: [message]}
756
+
757
+
758
+
759
def node_func_report_agent_outputs(
    state: Dict[str, Any],
    keys_to_include: List[str],
    result_key: str,
    role: str,
    custom_title: str = "Agent Output Summary"
) -> Dict[str, Any]:
    """
    Gathers relevant data directly from the state (filtered by `keys_to_include`)
    and returns them as a structured message in `state[result_key]`.

    No LLM is used.

    Parameters
    ----------
    state : Dict[str, Any]
        The current state dictionary holding all agent variables.
    keys_to_include : List[str]
        The list of keys in `state` to include in the output.
    result_key : str
        The key in `state` under which we'll store the final structured message.
    role : str
        The role that will be used in the final AIMessage (e.g., "DataCleaningAgent").
    custom_title : str, optional
        A title or heading for your report. Defaults to "Agent Output Summary".

    Returns
    -------
    Dict[str, Any]
        A state update of the form ``{result_key: [AIMessage(...)]}`` whose
        message content is the JSON-serialized report.
    """
    print(" * REPORT AGENT OUTPUTS")

    final_report = {"report_title": custom_title}

    # Missing keys are reported with an explicit placeholder rather than
    # silently dropped, so gaps in the state are visible in the report.
    for key in keys_to_include:
        final_report[key] = state.get(key, f"<{key}_not_found_in_state>")

    # Wrap it in a list of messages (like the current "messages" pattern).
    # `default=str` is deliberate: state values may not be JSON-serializable
    # (e.g. DataFrames); stringifying them keeps the report from raising
    # a TypeError mid-pipeline.
    return {
        result_key: [
            AIMessage(
                content=json.dumps(final_report, indent=2, default=str),
                role=role
            )
        ]
    }
802
+
@@ -1,6 +1,7 @@
1
1
  import io
2
2
  import pandas as pd
3
3
  import sqlalchemy as sql
4
+ from sqlalchemy import inspect
4
5
  from typing import Union, List, Dict
5
6
 
6
7
  def get_dataframe_summary(
@@ -139,8 +140,7 @@ def _summarize_dataframe(df: pd.DataFrame, dataset_name: str, n_sample=30, skip_
139
140
 
140
141
 
141
142
 
142
- def get_database_metadata(connection: Union[sql.engine.base.Connection, sql.engine.base.Engine],
143
- n_samples: int = 10) -> str:
143
+ def get_database_metadata(connection, n_samples=10) -> dict:
144
144
  """
145
145
  Collects metadata and sample data from a database, with safe identifier quoting and
146
146
  basic dialect-aware row limiting. Prevents issues with spaces/reserved words in identifiers.
@@ -154,77 +154,109 @@ def get_database_metadata(connection: Union[sql.engine.base.Connection, sql.engi
154
154
 
155
155
  Returns
156
156
  -------
157
- str
158
- A formatted string with database metadata, including some sample data from each column.
157
+ dict
158
+ A dictionary with database metadata, including some sample data from each column.
159
159
  """
160
-
161
- # If a connection is passed, use it; if an engine is passed, connect to it
162
160
  is_engine = isinstance(connection, sql.engine.base.Engine)
163
161
  conn = connection.connect() if is_engine else connection
164
162
 
165
- output = []
163
+ metadata = {
164
+ "dialect": None,
165
+ "driver": None,
166
+ "connection_url": None,
167
+ "schemas": [],
168
+ }
169
+
166
170
  try:
167
- # Grab the engine off the connection
168
171
  sql_engine = conn.engine
169
172
  dialect_name = sql_engine.dialect.name.lower()
170
173
 
171
- output.append(f"Database Dialect: {sql_engine.dialect.name}")
172
- output.append(f"Driver: {sql_engine.driver}")
173
- output.append(f"Connection URL: {sql_engine.url}")
174
-
175
- # Inspect the database
176
- inspector = sql.inspect(sql_engine)
177
- tables = inspector.get_table_names()
178
- output.append(f"Tables: {tables}")
179
- output.append(f"Schemas: {inspector.get_schema_names()}")
180
-
181
- # Helper to build a dialect-specific limit clause
182
- def build_query(col_name_quoted: str, table_name_quoted: str, n: int) -> str:
183
- """
184
- Returns a SQL query string to select N rows from the given column/table
185
- across different dialects (SQLite, MySQL, Postgres, MSSQL, Oracle, etc.)
186
- """
187
- if "sqlite" in dialect_name or "mysql" in dialect_name or "postgres" in dialect_name:
188
- # Common dialects supporting LIMIT
189
- return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
190
- elif "mssql" in dialect_name:
191
- # Microsoft SQL Server syntax
192
- return f"SELECT TOP {n} {col_name_quoted} FROM {table_name_quoted}"
193
- elif "oracle" in dialect_name:
194
- # Oracle syntax
195
- return f"SELECT {col_name_quoted} FROM {table_name_quoted} WHERE ROWNUM <= {n}"
196
- else:
197
- # Fallback
198
- return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
199
-
200
- # Prepare for quoting
201
- preparer = inspector.bind.dialect.identifier_preparer
202
-
203
- # For each table, get columns and sample data
204
- for table_name in tables:
205
- output.append(f"\nTable: {table_name}")
206
- # Properly quote the table name
207
- table_name_quoted = preparer.quote_identifier(table_name)
208
-
209
- for column in inspector.get_columns(table_name):
210
- col_name = column["name"]
211
- col_type = column["type"]
212
- output.append(f" Column: {col_name} Type: {col_type}")
174
+ metadata["dialect"] = sql_engine.dialect.name
175
+ metadata["driver"] = sql_engine.driver
176
+ metadata["connection_url"] = str(sql_engine.url)
213
177
 
214
- # Properly quote the column name
215
- col_name_quoted = preparer.quote_identifier(col_name)
216
-
217
- # Build a dialect-aware query with safe quoting
218
- query = build_query(col_name_quoted, table_name_quoted, n_samples)
219
-
220
- # Read a few sample values
221
- df = pd.read_sql(sql.text(query), conn)
222
- first_values = df[col_name].tolist()
223
- output.append(f" First {n_samples} Values: {first_values}")
178
+ inspector = inspect(sql_engine)
179
+ preparer = inspector.bind.dialect.identifier_preparer
224
180
 
181
+ # For each schema
182
+ for schema_name in inspector.get_schema_names():
183
+ schema_obj = {
184
+ "schema_name": schema_name,
185
+ "tables": []
186
+ }
187
+
188
+ tables = inspector.get_table_names(schema=schema_name)
189
+ for table_name in tables:
190
+ table_info = {
191
+ "table_name": table_name,
192
+ "columns": [],
193
+ "primary_key": [],
194
+ "foreign_keys": [],
195
+ "indexes": []
196
+ }
197
+ # Get columns
198
+ columns = inspector.get_columns(table_name, schema=schema_name)
199
+ for col in columns:
200
+ col_name = col["name"]
201
+ col_type = str(col["type"])
202
+ table_name_quoted = f"{preparer.quote_identifier(schema_name)}.{preparer.quote_identifier(table_name)}"
203
+ col_name_quoted = preparer.quote_identifier(col_name)
204
+
205
+ # Build query for sample data
206
+ query = build_query(col_name_quoted, table_name_quoted, n_samples, dialect_name)
207
+
208
+ # Retrieve sample data
209
+ try:
210
+ df = pd.read_sql(query, conn)
211
+ samples = df[col_name].head(n_samples).tolist()
212
+ except Exception as e:
213
+ samples = [f"Error retrieving data: {str(e)}"]
214
+
215
+ table_info["columns"].append({
216
+ "name": col_name,
217
+ "type": col_type,
218
+ "sample_values": samples
219
+ })
220
+
221
+ # Primary keys
222
+ pk_constraint = inspector.get_pk_constraint(table_name, schema=schema_name)
223
+ table_info["primary_key"] = pk_constraint.get("constrained_columns", [])
224
+
225
+ # Foreign keys
226
+ fks = inspector.get_foreign_keys(table_name, schema=schema_name)
227
+ table_info["foreign_keys"] = [
228
+ {
229
+ "local_cols": fk["constrained_columns"],
230
+ "referred_table": fk["referred_table"],
231
+ "referred_cols": fk["referred_columns"]
232
+ }
233
+ for fk in fks
234
+ ]
235
+
236
+ # Indexes
237
+ idxs = inspector.get_indexes(table_name, schema=schema_name)
238
+ table_info["indexes"] = idxs
239
+
240
+ schema_obj["tables"].append(table_info)
241
+
242
+ metadata["schemas"].append(schema_obj)
243
+
225
244
  finally:
226
- # Close connection if created inside the function
227
245
  if is_engine:
228
246
  conn.close()
229
247
 
230
- return "\n".join(output)
248
+ return metadata
249
+
250
def build_query(col_name_quoted: str, table_name_quoted: str, n: int, dialect_name: str) -> str:
    """
    Build a dialect-specific SQL statement that samples up to *n* values
    from one column.

    Parameters
    ----------
    col_name_quoted : str
        Column identifier, already quoted for the target dialect.
    table_name_quoted : str
        Table identifier (optionally schema-qualified), already quoted.
    n : int
        Maximum number of rows to fetch.
    dialect_name : str
        Lower-cased SQLAlchemy dialect name; matched by substring.

    Returns
    -------
    str
        A SELECT statement using random sampling where the dialect
        supports it, falling back to Oracle-style ROWNUM filtering.
    """
    selection = f"{col_name_quoted} FROM {table_name_quoted}"

    # Postgres and SQLite share the RANDOM() spelling; MySQL uses RAND().
    if "postgres" in dialect_name:
        return f"SELECT {selection} ORDER BY RANDOM() LIMIT {n}"
    if "mysql" in dialect_name:
        return f"SELECT {selection} ORDER BY RAND() LIMIT {n}"
    if "sqlite" in dialect_name:
        return f"SELECT {selection} ORDER BY RANDOM() LIMIT {n}"
    if "mssql" in dialect_name:
        # SQL Server has no LIMIT clause; TOP with NEWID() randomizes rows.
        return f"SELECT TOP {n} {selection} ORDER BY NEWID()"

    # Oracle or any unrecognized dialect: classic ROWNUM filter.
    return f"SELECT {selection} WHERE ROWNUM <= {n}"
262
+