ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +4 -5
- ai_data_science_team/agents/data_cleaning_agent.py +268 -116
- ai_data_science_team/agents/data_visualization_agent.py +470 -41
- ai_data_science_team/agents/data_wrangling_agent.py +471 -31
- ai_data_science_team/agents/feature_engineering_agent.py +426 -41
- ai_data_science_team/agents/sql_database_agent.py +458 -58
- ai_data_science_team/ml_agents/__init__.py +1 -0
- ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
- ai_data_science_team/multiagents/__init__.py +1 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +398 -0
- ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
- ai_data_science_team/templates/__init__.py +3 -1
- ai_data_science_team/templates/agent_templates.py +319 -43
- ai_data_science_team/tools/metadata.py +94 -62
- ai_data_science_team/tools/regex.py +86 -1
- ai_data_science_team/utils/__init__.py +0 -0
- ai_data_science_team/utils/plotly.py +24 -0
- ai_data_science_team-0.0.0.9009.dist-info/METADATA +245 -0
- ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
- ai_data_science_team-0.0.0.9007.dist-info/METADATA +0 -183
- ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,225 @@
|
|
1
1
|
from langchain_core.messages import AIMessage
|
2
2
|
from langgraph.graph import StateGraph, END
|
3
3
|
from langgraph.types import interrupt, Command
|
4
|
+
from langgraph.graph.state import CompiledStateGraph
|
5
|
+
|
6
|
+
from langchain_core.runnables import RunnableConfig
|
7
|
+
from langgraph.pregel.types import StreamMode
|
4
8
|
|
5
9
|
import pandas as pd
|
6
10
|
import sqlalchemy as sql
|
11
|
+
import json
|
7
12
|
|
8
|
-
from typing import Any, Callable, Dict, Type, Optional
|
13
|
+
from typing import Any, Callable, Dict, Type, Optional, Union, List
|
9
14
|
|
10
15
|
from ai_data_science_team.tools.parsers import PythonOutputParser
|
11
|
-
from ai_data_science_team.tools.regex import
|
16
|
+
from ai_data_science_team.tools.regex import (
|
17
|
+
relocate_imports_inside_function,
|
18
|
+
add_comments_to_top,
|
19
|
+
remove_consecutive_duplicates
|
20
|
+
)
|
21
|
+
|
22
|
+
from IPython.display import Image, display
|
23
|
+
import pandas as pd
|
24
|
+
|
25
|
+
class BaseAgent(CompiledStateGraph):
    """
    A generic base class for agents that interact with compiled state graphs.

    Provides shared functionality for handling parameters, responses, and state
    graph operations. Subclasses implement `_make_compiled_graph()`, which must
    return the compiled graph that this agent wraps.

    Notes
    -----
    `CompiledStateGraph.__init__` is not called, so the agent object itself is
    only a shell around `self._compiled_graph`. Graph operations used by this
    class are therefore delegated to `self._compiled_graph` explicitly, and any
    other attribute lookup that fails on the agent is forwarded to the compiled
    graph by `__getattr__`.
    """

    def __init__(self, **params):
        """
        Initialize the agent with provided parameters.

        Parameters:
            **params: Arbitrary keyword arguments representing the agent's parameters.
                Stored in `self._params` (presumably read by the subclass's
                `_make_compiled_graph` -- confirm per subclass).
        """
        self._params = params
        # Set before building the graph so the attribute exists even if the
        # subclass factory raises.
        self.response = None
        self._compiled_graph = self._make_compiled_graph()

    def _make_compiled_graph(self):
        """
        Subclasses should override this method to create a specific compiled graph.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("Subclasses must implement the `_make_compiled_graph` method.")

    def update_params(self, **kwargs):
        """
        Update one or more parameters and rebuild the compiled graph.

        Parameters:
            **kwargs: Parameters to update (merged into `self._params`).

        The previous `self.response` is left untouched.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    def __getattr__(self, name: str):
        """
        Delegate attribute access to the compiled graph if the attribute is not found.

        Parameters:
            name (str): The attribute name.

        Returns:
            Any: The attribute from the compiled graph.

        Raises:
            AttributeError: If `_compiled_graph` itself is missing (e.g. the graph
                factory failed), or the compiled graph lacks the attribute.
        """
        # `__getattr__` only runs after normal lookup fails. Without this guard a
        # missing `_compiled_graph` would re-enter `__getattr__` forever.
        if name == "_compiled_graph":
            raise AttributeError(
                f"{type(self).__name__!s} has no compiled graph yet; "
                "`_make_compiled_graph` did not complete."
            )
        return getattr(self._compiled_graph, name)

    @staticmethod
    def _clean_response(response):
        """
        Collapse consecutive duplicate messages in a dict-shaped graph result.

        The "messages" entry is replaced in place. Non-dict results (e.g. tuples
        yielded when streaming with several modes) are returned unchanged.
        """
        if isinstance(response, dict) and response.get("messages"):
            response["messages"] = remove_consecutive_duplicates(response["messages"])
        return response

    def invoke(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.invoke()

        Parameters:
            input: The input data for the graph. It can be a dictionary or any other type.
            config: Optional. The configuration for the graph run.
            **kwargs: Arguments to pass to self._compiled_graph.invoke()

        Returns:
            Any: The agent's response (also stored in `self.response`).
        """
        self.response = self._compiled_graph.invoke(input=input, config=config, **kwargs)
        return self._clean_response(self.response)

    async def ainvoke(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        **kwargs
    ):
        """
        Async wrapper for self._compiled_graph.ainvoke(); must be awaited.

        Parameters:
            input: The input data for the graph. It can be a dictionary or any other type.
            config: Optional. The configuration for the graph run.
            **kwargs: Arguments to pass to self._compiled_graph.ainvoke()

        Returns:
            Any: The agent's response (also stored in `self.response`).
        """
        # The compiled graph's ainvoke returns a coroutine; it must be awaited
        # before the result can be inspected.
        self.response = await self._compiled_graph.ainvoke(input=input, config=config, **kwargs)
        return self._clean_response(self.response)

    def stream(
        self,
        input: dict[str, Any] | Any,
        config: RunnableConfig | None = None,
        stream_mode: StreamMode | list[StreamMode] | None = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.stream()

        Parameters:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode to stream output; None uses the compiled graph's default.
                Options are 'values', 'updates', and 'debug'.
                    values: Emit the current values of the state for each step.
                    updates: Emit only the updates to the state for each step.
                        Output is a dict with the node name as key and the updated values as value.
                    debug: Emit debug events for each step.
            **kwargs: Arguments to pass to self._compiled_graph.stream()

        Yields:
            Each chunk produced by the compiled graph. The most recent chunk is
            recorded in `self.response` before it is yielded, so after full
            consumption `self.response` holds the last chunk.
        """
        for chunk in self._compiled_graph.stream(
            input=input, config=config, stream_mode=stream_mode, **kwargs
        ):
            self.response = self._clean_response(chunk)
            yield chunk

    async def astream(
        self,
        input: dict[str, Any] | Any,
        config: RunnableConfig | None = None,
        stream_mode: StreamMode | list[StreamMode] | None = None,
        **kwargs
    ):
        """
        Async wrapper for self._compiled_graph.astream(); iterate with `async for`.

        Parameters:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode to stream output; None uses the compiled graph's default.
                Options are 'values', 'updates', and 'debug'.
                    values: Emit the current values of the state for each step.
                    updates: Emit only the updates to the state for each step.
                        Output is a dict with the node name as key and the updated values as value.
                    debug: Emit debug events for each step.
            **kwargs: Arguments to pass to self._compiled_graph.astream()

        Yields:
            Each chunk produced by the compiled graph; the most recent chunk is
            recorded in `self.response` before it is yielded.
        """
        async for chunk in self._compiled_graph.astream(
            input=input, config=config, stream_mode=stream_mode, **kwargs
        ):
            self.response = self._clean_response(chunk)
            yield chunk

    def get_state_keys(self):
        """
        Returns a list of keys that the state graph response contains.

        Returns:
            list: A list of keys in the response.
        """
        return list(self.get_state_properties().keys())

    def get_state_properties(self):
        """
        Returns detailed properties of the state graph response.

        Returns:
            dict: The "properties" section of the compiled graph's output JSON schema.
        """
        return self._compiled_graph.get_output_jsonschema()['properties']

    def get_response(self):
        """
        Returns the response generated by the agent.

        Returns:
            Any: The agent's most recent response, or None if the agent has not run yet.
        """
        return self._clean_response(self.response)

    def show(self, xray: int = 0):
        """
        Displays the agent's state graph as a Mermaid diagram.

        Parameters:
            xray (int): If set to 1, displays subgraph levels. Defaults to 0.
        """
        display(Image(self._compiled_graph.get_graph(xray=xray).draw_mermaid_png()))
|
220
|
+
|
221
|
+
|
222
|
+
|
12
223
|
|
13
224
|
def create_coding_agent_graph(
|
14
225
|
GraphState: Type,
|
@@ -79,35 +290,37 @@ def create_coding_agent_graph(
|
|
79
290
|
|
80
291
|
workflow = StateGraph(GraphState)
|
81
292
|
|
82
|
-
#
|
83
|
-
if not bypass_recommended_steps:
|
84
|
-
workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
|
293
|
+
# * NODES
|
85
294
|
|
86
295
|
# Always add create, execute, and fix nodes
|
87
296
|
workflow.add_node(create_code_node_name, node_functions[create_code_node_name])
|
88
297
|
workflow.add_node(execute_code_node_name, node_functions[execute_code_node_name])
|
89
298
|
workflow.add_node(fix_code_node_name, node_functions[fix_code_node_name])
|
90
299
|
|
300
|
+
# Conditionally add the recommended-steps node
|
301
|
+
if not bypass_recommended_steps:
|
302
|
+
workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
|
303
|
+
|
304
|
+
# Conditionally add the human review node
|
305
|
+
if human_in_the_loop:
|
306
|
+
workflow.add_node(human_review_node_name, node_functions[human_review_node_name])
|
307
|
+
|
91
308
|
# Conditionally add the explanation node
|
92
309
|
if not bypass_explain_code:
|
93
310
|
workflow.add_node(explain_code_node_name, node_functions[explain_code_node_name])
|
94
311
|
|
312
|
+
# * EDGES
|
313
|
+
|
95
314
|
# Set the entry point
|
96
315
|
entry_point = create_code_node_name if bypass_recommended_steps else recommended_steps_node_name
|
316
|
+
|
97
317
|
workflow.set_entry_point(entry_point)
|
98
318
|
|
99
|
-
# Add edges for recommended steps
|
100
319
|
if not bypass_recommended_steps:
|
101
|
-
|
102
|
-
workflow.add_edge(recommended_steps_node_name, human_review_node_name)
|
103
|
-
else:
|
104
|
-
workflow.add_edge(recommended_steps_node_name, create_code_node_name)
|
105
|
-
elif human_in_the_loop:
|
106
|
-
# Skip recommended steps but still include human review
|
107
|
-
workflow.add_edge(create_code_node_name, human_review_node_name)
|
320
|
+
workflow.add_edge(recommended_steps_node_name, create_code_node_name)
|
108
321
|
|
109
|
-
# Create -> Execute
|
110
322
|
workflow.add_edge(create_code_node_name, execute_code_node_name)
|
323
|
+
workflow.add_edge(fix_code_node_name, execute_code_node_name)
|
111
324
|
|
112
325
|
# Define a helper to check if we have an error & can still retry
|
113
326
|
def error_and_can_retry(state):
|
@@ -117,39 +330,43 @@ def create_coding_agent_graph(
|
|
117
330
|
and state.get(max_retries_key) is not None
|
118
331
|
and state[retry_count_key] < state[max_retries_key]
|
119
332
|
)
|
120
|
-
|
121
|
-
#
|
122
|
-
if
|
123
|
-
# If we are NOT bypassing explain, the next node is fix_code if error,
|
124
|
-
# else explain_code. Then we wire explain_code -> END afterward.
|
333
|
+
|
334
|
+
# If human in the loop, add a branch for human review
|
335
|
+
if human_in_the_loop:
|
125
336
|
workflow.add_conditional_edges(
|
126
337
|
execute_code_node_name,
|
127
|
-
lambda s: "fix_code" if error_and_can_retry(s) else "
|
338
|
+
lambda s: "fix_code" if error_and_can_retry(s) else "human_review",
|
128
339
|
{
|
340
|
+
"human_review": human_review_node_name,
|
129
341
|
"fix_code": fix_code_node_name,
|
130
|
-
"explain_code": explain_code_node_name,
|
131
342
|
},
|
132
343
|
)
|
133
|
-
# Fix code -> Execute again
|
134
|
-
workflow.add_edge(fix_code_node_name, execute_code_node_name)
|
135
|
-
# explain_code -> END
|
136
|
-
workflow.add_edge(explain_code_node_name, END)
|
137
344
|
else:
|
138
|
-
# If
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
345
|
+
# If no human review, the next node is fix_code if error, else explain_code.
|
346
|
+
if not bypass_explain_code:
|
347
|
+
workflow.add_conditional_edges(
|
348
|
+
execute_code_node_name,
|
349
|
+
lambda s: "fix_code" if error_and_can_retry(s) else "explain_code",
|
350
|
+
{
|
351
|
+
"fix_code": fix_code_node_name,
|
352
|
+
"explain_code": explain_code_node_name,
|
353
|
+
},
|
354
|
+
)
|
355
|
+
else:
|
356
|
+
workflow.add_conditional_edges(
|
357
|
+
execute_code_node_name,
|
358
|
+
lambda s: "fix_code" if error_and_can_retry(s) else "END",
|
359
|
+
{
|
360
|
+
"fix_code": fix_code_node_name,
|
361
|
+
"END": END,
|
362
|
+
},
|
363
|
+
)
|
364
|
+
|
365
|
+
if not bypass_explain_code:
|
366
|
+
workflow.add_edge(explain_code_node_name, END)
|
150
367
|
|
151
368
|
# Finally, compile
|
152
|
-
if human_in_the_loop
|
369
|
+
if human_in_the_loop:
|
153
370
|
app = workflow.compile(checkpointer=checkpointer)
|
154
371
|
else:
|
155
372
|
app = workflow.compile()
|
@@ -165,6 +382,8 @@ def node_func_human_review(
|
|
165
382
|
no_goto: str,
|
166
383
|
user_instructions_key: str = "user_instructions",
|
167
384
|
recommended_steps_key: str = "recommended_steps",
|
385
|
+
code_snippet_key: str = "code_snippet",
|
386
|
+
code_type: str = "python"
|
168
387
|
) -> Command[str]:
|
169
388
|
"""
|
170
389
|
A generic function to handle human review steps.
|
@@ -183,6 +402,10 @@ def node_func_human_review(
|
|
183
402
|
The key in the state to store user instructions.
|
184
403
|
recommended_steps_key : str, optional
|
185
404
|
The key in the state to store recommended steps.
|
405
|
+
code_snippet_key : str, optional
|
406
|
+
The key in the state to store the code snippet.
|
407
|
+
code_type : str, optional
|
408
|
+
The type of code snippet to display (e.g., "python").
|
186
409
|
|
187
410
|
Returns
|
188
411
|
-------
|
@@ -190,9 +413,11 @@ def node_func_human_review(
|
|
190
413
|
A Command object directing the next state and updates to the state.
|
191
414
|
"""
|
192
415
|
print(" * HUMAN REVIEW")
|
416
|
+
|
417
|
+
code_markdown=f"```{code_type}\n" + state.get(code_snippet_key)+"\n```"
|
193
418
|
|
194
419
|
# Display instructions and get user response
|
195
|
-
user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '')))
|
420
|
+
user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '') + "\n\n" + code_markdown))
|
196
421
|
|
197
422
|
# Decide next steps based on user input
|
198
423
|
if user_input.strip().lower() == "yes":
|
@@ -200,11 +425,11 @@ def node_func_human_review(
|
|
200
425
|
update = {}
|
201
426
|
else:
|
202
427
|
goto = no_goto
|
203
|
-
modifications = "Modifications: \n" + user_input
|
428
|
+
modifications = "User Has Requested Modifications To Previous Code: \n" + user_input
|
204
429
|
if state.get(user_instructions_key) is None:
|
205
|
-
update = {user_instructions_key: modifications}
|
430
|
+
update = {user_instructions_key: modifications + "\n\nPrevious Code:\n" + code_markdown}
|
206
431
|
else:
|
207
|
-
update = {user_instructions_key: state.get(user_instructions_key) + modifications}
|
432
|
+
update = {user_instructions_key: state.get(user_instructions_key) + modifications + "\n\nPrevious Code:\n" + code_markdown}
|
208
433
|
|
209
434
|
return Command(goto=goto, update=update)
|
210
435
|
|
@@ -394,6 +619,7 @@ def node_func_fix_agent_code(
|
|
394
619
|
retry_count_key: str = "retry_count",
|
395
620
|
log: bool = False,
|
396
621
|
file_path: str = "logs/agent_function.py",
|
622
|
+
function_name: str = "agent_function"
|
397
623
|
) -> dict:
|
398
624
|
"""
|
399
625
|
Generic function to fix a given piece of agent code using an LLM and a prompt template.
|
@@ -420,6 +646,8 @@ def node_func_fix_agent_code(
|
|
420
646
|
Whether to log the returned code to a file.
|
421
647
|
file_path : str, optional
|
422
648
|
The path to the file where the code will be logged.
|
649
|
+
function_name : str, optional
|
650
|
+
The name of the function in the code snippet that will be fixed.
|
423
651
|
|
424
652
|
Returns
|
425
653
|
-------
|
@@ -436,7 +664,8 @@ def node_func_fix_agent_code(
|
|
436
664
|
# Format the prompt with the code snippet and the error
|
437
665
|
prompt = prompt_template.format(
|
438
666
|
code_snippet=code_snippet,
|
439
|
-
error=error_message
|
667
|
+
error=error_message,
|
668
|
+
function_name=function_name,
|
440
669
|
)
|
441
670
|
|
442
671
|
# Execute the prompt with the LLM
|
@@ -524,3 +753,50 @@ def node_func_explain_agent_code(
|
|
524
753
|
# Return an error message if there was a problem with the code
|
525
754
|
message = AIMessage(content=error_message)
|
526
755
|
return {result_key: [message]}
|
756
|
+
|
757
|
+
|
758
|
+
|
759
|
+
def node_func_report_agent_outputs(
    state: Dict[str, Any],
    keys_to_include: List[str],
    result_key: str,
    role: str,
    custom_title: str = "Agent Output Summary"
) -> Dict[str, Any]:
    """
    Gathers relevant data directly from the state (filtered by `keys_to_include`)
    and returns them as a structured message in `state[result_key]`.

    No LLM is used.

    Parameters
    ----------
    state : Dict[str, Any]
        The current state dictionary holding all agent variables.
    keys_to_include : List[str]
        The list of keys in `state` to include in the output. Keys missing from
        `state` are reported as the placeholder string "<{key}_not_found_in_state>".
    result_key : str
        The key in `state` under which we'll store the final structured message.
    role : str
        The role that will be used in the final AIMessage (e.g., "DataCleaningAgent").
    custom_title : str, optional
        A title or heading for your report. Defaults to "Agent Output Summary".

    Returns
    -------
    Dict[str, Any]
        `{result_key: [AIMessage]}` where the message content is the report
        serialized as indented JSON.
    """
    print(" * REPORT AGENT OUTPUTS")

    final_report = {"report_title": custom_title}
    for key in keys_to_include:
        final_report[key] = state.get(key, f"<{key}_not_found_in_state>")

    # State values are not guaranteed to be JSON-native (e.g. datetimes or
    # pandas Timestamps from `DataFrame.to_dict()`). This node is a best-effort
    # summary, so stringify such values rather than failing the whole graph.
    content = json.dumps(final_report, indent=2, default=str)

    # Wrap it in a list of messages (like the current "messages" pattern).
    return {
        result_key: [
            AIMessage(
                content=content,
                role=role
            )
        ]
    }
|
802
|
+
|
@@ -1,6 +1,7 @@
|
|
1
1
|
import io
|
2
2
|
import pandas as pd
|
3
3
|
import sqlalchemy as sql
|
4
|
+
from sqlalchemy import inspect
|
4
5
|
from typing import Union, List, Dict
|
5
6
|
|
6
7
|
def get_dataframe_summary(
|
@@ -139,8 +140,7 @@ def _summarize_dataframe(df: pd.DataFrame, dataset_name: str, n_sample=30, skip_
|
|
139
140
|
|
140
141
|
|
141
142
|
|
142
|
-
def get_database_metadata(connection
|
143
|
-
n_samples: int = 10) -> str:
|
143
|
+
def get_database_metadata(connection, n_samples=10) -> dict:
    """
    Collects metadata and sample data from a database, with safe identifier quoting and
    basic dialect-aware row limiting. Prevents issues with spaces/reserved words in identifiers.

    Parameters
    ----------
    connection : sqlalchemy Connection or Engine
        An open connection, or an engine. When an engine is passed, a connection
        is opened for the duration of the call and closed before returning; a
        passed-in connection is left open.
    n_samples : int, optional
        Maximum number of sample values collected per column. Defaults to 10.

    Returns
    -------
    dict
        A dictionary with database metadata, including some sample data from each column.
        Keys: "dialect", "driver", "connection_url" (password masked) and "schemas",
        a list of {"schema_name", "tables"} entries. Each table entry holds
        "table_name", "columns" (name, type, sample_values), "primary_key",
        "foreign_keys" and "indexes". If sampling a column fails, its
        "sample_values" is a one-element list holding the error text.
    """
    is_engine = isinstance(connection, sql.engine.base.Engine)
    conn = connection.connect() if is_engine else connection

    metadata = {
        "dialect": None,
        "driver": None,
        "connection_url": None,
        "schemas": [],
    }

    try:
        sql_engine = conn.engine
        dialect_name = sql_engine.dialect.name.lower()

        metadata["dialect"] = sql_engine.dialect.name
        metadata["driver"] = sql_engine.driver
        # Mask the password: this metadata may be shown to users or sent to an
        # LLM, and str(url) only hides it on SQLAlchemy >= 2.0.
        metadata["connection_url"] = sql_engine.url.render_as_string(hide_password=True)

        inspector = inspect(sql_engine)
        preparer = inspector.bind.dialect.identifier_preparer

        # For each schema
        for schema_name in inspector.get_schema_names():
            schema_obj = {
                "schema_name": schema_name,
                "tables": []
            }

            for table_name in inspector.get_table_names(schema=schema_name):
                table_info = {
                    "table_name": table_name,
                    "columns": [],
                    "primary_key": [],
                    "foreign_keys": [],
                    "indexes": []
                }

                # Conditional quoting: identifiers are quoted only when needed
                # (reserved words, special characters, mixed case, or names the
                # inspector flagged as quoted). Unconditional quoting breaks
                # dialects such as Oracle, whose inspector reports normalized
                # lowercase names for case-insensitive (uppercase) objects.
                table_name_quoted = f"{preparer.quote_schema(schema_name)}.{preparer.quote(table_name)}"

                # Get columns
                for col in inspector.get_columns(table_name, schema=schema_name):
                    col_name = col["name"]
                    col_type = str(col["type"])
                    col_name_quoted = preparer.quote(col_name)

                    # Build query for sample data
                    query = build_query(col_name_quoted, table_name_quoted, n_samples, dialect_name)

                    # Retrieve sample data; a failing column is recorded, not fatal.
                    try:
                        df = pd.read_sql(query, conn)
                        # The query selects exactly one column; read it by
                        # position since drivers may relabel it (e.g. case).
                        samples = df.iloc[:, 0].head(n_samples).tolist()
                    except Exception as e:
                        samples = [f"Error retrieving data: {str(e)}"]

                    table_info["columns"].append({
                        "name": col_name,
                        "type": col_type,
                        "sample_values": samples
                    })

                # Primary keys
                pk_constraint = inspector.get_pk_constraint(table_name, schema=schema_name)
                table_info["primary_key"] = pk_constraint.get("constrained_columns", [])

                # Foreign keys
                fks = inspector.get_foreign_keys(table_name, schema=schema_name)
                table_info["foreign_keys"] = [
                    {
                        "local_cols": fk["constrained_columns"],
                        "referred_table": fk["referred_table"],
                        "referred_cols": fk["referred_columns"]
                    }
                    for fk in fks
                ]

                # Indexes
                table_info["indexes"] = inspector.get_indexes(table_name, schema=schema_name)

                schema_obj["tables"].append(table_info)

            metadata["schemas"].append(schema_obj)

    finally:
        # Only close connections this function opened itself.
        if is_engine:
            conn.close()

    return metadata
|
249
|
+
|
250
|
+
def build_query(col_name_quoted: str, table_name_quoted: str, n: int, dialect_name: str) -> str:
    """
    Build a SQL query that selects up to `n` values of one column from one table.

    Random sampling is used where the dialect has a well-known random-ordering
    function (PostgreSQL, MySQL/MariaDB, SQLite, SQL Server). Oracle uses
    ROWNUM (first rows, not random). Any other dialect falls back to a plain
    `LIMIT n`, which is the most widely supported row-limiting syntax.

    Parameters
    ----------
    col_name_quoted : str
        The column identifier, already quoted/escaped for the dialect.
    table_name_quoted : str
        The (optionally schema-qualified) table identifier, already quoted.
    n : int
        Maximum number of rows to return. Coerced to int before being
        interpolated so that no non-integer text reaches the SQL string.
    dialect_name : str
        SQLAlchemy dialect name (e.g. "postgresql", "mysql"); matched
        case-insensitively by substring.

    Returns
    -------
    str
        The SQL query text.
    """
    n = int(n)
    dialect = (dialect_name or "").lower()

    if "postgres" in dialect:
        return f"SELECT {col_name_quoted} FROM {table_name_quoted} ORDER BY RANDOM() LIMIT {n}"
    if "mysql" in dialect or "mariadb" in dialect:
        return f"SELECT {col_name_quoted} FROM {table_name_quoted} ORDER BY RAND() LIMIT {n}"
    if "sqlite" in dialect:
        return f"SELECT {col_name_quoted} FROM {table_name_quoted} ORDER BY RANDOM() LIMIT {n}"
    if "mssql" in dialect:
        return f"SELECT TOP {n} {col_name_quoted} FROM {table_name_quoted} ORDER BY NEWID()"
    if "oracle" in dialect:
        # ROWNUM is Oracle-specific; it is not valid SQL on other engines.
        return f"SELECT {col_name_quoted} FROM {table_name_quoted} WHERE ROWNUM <= {n}"
    # Fallback for other dialects (DuckDB, Snowflake, Trino, ...): LIMIT.
    return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
|
262
|
+
|