ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9009__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +4 -5
- ai_data_science_team/agents/data_cleaning_agent.py +268 -116
- ai_data_science_team/agents/data_visualization_agent.py +470 -41
- ai_data_science_team/agents/data_wrangling_agent.py +471 -31
- ai_data_science_team/agents/feature_engineering_agent.py +426 -41
- ai_data_science_team/agents/sql_database_agent.py +458 -58
- ai_data_science_team/ml_agents/__init__.py +1 -0
- ai_data_science_team/ml_agents/h2o_ml_agent.py +1032 -0
- ai_data_science_team/multiagents/__init__.py +1 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +398 -0
- ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
- ai_data_science_team/templates/__init__.py +3 -1
- ai_data_science_team/templates/agent_templates.py +319 -43
- ai_data_science_team/tools/metadata.py +94 -62
- ai_data_science_team/tools/regex.py +86 -1
- ai_data_science_team/utils/__init__.py +0 -0
- ai_data_science_team/utils/plotly.py +24 -0
- ai_data_science_team-0.0.0.9009.dist-info/METADATA +245 -0
- ai_data_science_team-0.0.0.9009.dist-info/RECORD +28 -0
- ai_data_science_team-0.0.0.9007.dist-info/METADATA +0 -183
- ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9009.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,225 @@
|
|
1
1
|
from langchain_core.messages import AIMessage
|
2
2
|
from langgraph.graph import StateGraph, END
|
3
3
|
from langgraph.types import interrupt, Command
|
4
|
+
from langgraph.graph.state import CompiledStateGraph
|
5
|
+
|
6
|
+
from langchain_core.runnables import RunnableConfig
|
7
|
+
from langgraph.pregel.types import StreamMode
|
4
8
|
|
5
9
|
import pandas as pd
|
6
10
|
import sqlalchemy as sql
|
11
|
+
import json
|
7
12
|
|
8
|
-
from typing import Any, Callable, Dict, Type, Optional
|
13
|
+
from typing import Any, Callable, Dict, Type, Optional, Union, List
|
9
14
|
|
10
15
|
from ai_data_science_team.tools.parsers import PythonOutputParser
|
11
|
-
from ai_data_science_team.tools.regex import
|
16
|
+
from ai_data_science_team.tools.regex import (
|
17
|
+
relocate_imports_inside_function,
|
18
|
+
add_comments_to_top,
|
19
|
+
remove_consecutive_duplicates
|
20
|
+
)
|
21
|
+
|
22
|
+
from IPython.display import Image, display
|
23
|
+
import pandas as pd
|
24
|
+
|
25
|
+
class BaseAgent(CompiledStateGraph):
    """
    A generic base class for agents that interact with compiled state graphs.

    Provides shared functionality for handling parameters, responses, and state
    graph operations. Subclasses build their workflow in `_make_compiled_graph`;
    the result is stored on `self._compiled_graph`, and any attribute that is not
    found on the agent itself is looked up on that compiled graph.

    Note: `CompiledStateGraph.__init__` is never called on the agent, so the
    methods below forward explicitly to `self._compiled_graph` instead of relying
    on inherited graph methods running against this (uninitialized) instance.
    """

    def __init__(self, **params):
        """
        Initialize the agent with provided parameters.

        Parameters:
            **params: Arbitrary keyword arguments representing the agent's parameters.
                They are stored on `self._params` before the graph is built.
        """
        self._params = params
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Subclasses should override this method to create a specific compiled graph.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("Subclasses must implement the `_make_compiled_graph` method.")

    def update_params(self, **kwargs):
        """
        Update one or more parameters and rebuild the compiled graph.

        Parameters:
            **kwargs: Parameters to update; merged into `self._params`.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    def __getattr__(self, name: str):
        """
        Delegate attribute access to the compiled graph if the attribute is not found.

        Parameters:
            name (str): The attribute name.

        Returns:
            Any: The attribute from the compiled graph.

        Raises:
            AttributeError: If the compiled graph has not been assigned yet, or if
                it does not have the attribute either.
        """
        # `__getattr__` only runs after normal lookup has failed. Read the graph
        # straight from the instance dict: going through `self._compiled_graph`
        # here would re-enter `__getattr__` and recurse forever whenever the
        # graph is not assigned yet (e.g. while `_make_compiled_graph` runs, or
        # on instances created without `__init__`, such as by copy/unpickling).
        compiled_graph = self.__dict__.get("_compiled_graph")
        if compiled_graph is None:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(compiled_graph, name)

    def _dedupe_response_messages(self) -> None:
        """
        Pass `self.response["messages"]` through `remove_consecutive_duplicates`.

        Does nothing when there is no response yet (None), when the response is
        not a dict, or when it has no (or empty) "messages" entry.
        """
        if isinstance(self.response, dict) and self.response.get("messages"):
            self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])

    def invoke(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.invoke()

        Parameters:
            input: The input data for the graph. It can be a dictionary or any other type.
            config: Optional. The configuration for the graph run.
            **kwargs: Arguments to pass to self._compiled_graph.invoke()

        Returns:
            Any: The agent's response (also stored on `self.response`).
        """
        self.response = self._compiled_graph.invoke(input=input, config=config, **kwargs)
        self._dedupe_response_messages()
        return self.response

    async def ainvoke(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        **kwargs
    ):
        """
        Async wrapper for self._compiled_graph.ainvoke(); must be awaited.

        Parameters:
            input: The input data for the graph. It can be a dictionary or any other type.
            config: Optional. The configuration for the graph run.
            **kwargs: Arguments to pass to self._compiled_graph.ainvoke()

        Returns:
            Any: The agent's response (also stored on `self.response`).
        """
        # `ainvoke` returns a coroutine; it has to be awaited before the result
        # can be inspected as a state dict.
        self.response = await self._compiled_graph.ainvoke(input=input, config=config, **kwargs)
        self._dedupe_response_messages()
        return self.response

    def stream(
        self,
        input: dict[str, Any] | Any,
        config: RunnableConfig | None = None,
        stream_mode: StreamMode | list[StreamMode] | None = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.stream()

        Parameters:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode to stream output, defaults to self.stream_mode.
                Options are 'values', 'updates', and 'debug'.
                values: Emit the current values of the state for each step.
                updates: Emit only the updates to the state for each step.
                    Output is a dict with the node name as key and the updated values as value.
                debug: Emit debug events for each step.
            **kwargs: Arguments to pass to self._compiled_graph.stream()

        Returns:
            The iterator of chunks produced by self._compiled_graph.stream().
            `self.response` is not modified.
        """
        # The compiled graph yields chunks lazily; it does not return a state
        # dict, so there is no "messages" entry to post-process here.
        return self._compiled_graph.stream(
            input=input, config=config, stream_mode=stream_mode, **kwargs
        )

    def astream(
        self,
        input: dict[str, Any] | Any,
        config: RunnableConfig | None = None,
        stream_mode: StreamMode | list[StreamMode] | None = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.astream(); consume with `async for`.

        Parameters:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode to stream output, defaults to self.stream_mode.
                Options are 'values', 'updates', and 'debug'.
                values: Emit the current values of the state for each step.
                updates: Emit only the updates to the state for each step.
                    Output is a dict with the node name as key and the updated values as value.
                debug: Emit debug events for each step.
            **kwargs: Arguments to pass to self._compiled_graph.astream()

        Returns:
            The async iterator of chunks produced by self._compiled_graph.astream().
            `self.response` is not modified.
        """
        return self._compiled_graph.astream(
            input=input, config=config, stream_mode=stream_mode, **kwargs
        )

    def get_state_keys(self) -> list:
        """
        Returns a list of keys that the state graph response contains.

        Returns:
            list: The property names of the compiled graph's output JSON schema.
        """
        return list(self.get_state_properties().keys())

    def get_state_properties(self) -> dict:
        """
        Returns detailed properties of the state graph response.

        Returns:
            dict: The "properties" section of the compiled graph's output JSON schema.
        """
        return self._compiled_graph.get_output_jsonschema()['properties']

    def get_response(self):
        """
        Returns the response generated by the agent.

        Returns:
            Any: The agent's last response, or None if the agent has not been
            invoked yet.
        """
        self._dedupe_response_messages()
        return self.response

    def show(self, xray: int = 0):
        """
        Displays the agent's state graph as a Mermaid diagram.

        Parameters:
            xray (int): If set to 1, displays subgraph levels. Defaults to 0.
        """
        display(Image(self._compiled_graph.get_graph(xray=xray).draw_mermaid_png()))
|
220
|
+
|
221
|
+
|
222
|
+
|
12
223
|
|
13
224
|
def create_coding_agent_graph(
|
14
225
|
GraphState: Type,
|
@@ -79,35 +290,37 @@ def create_coding_agent_graph(
|
|
79
290
|
|
80
291
|
workflow = StateGraph(GraphState)
|
81
292
|
|
82
|
-
#
|
83
|
-
if not bypass_recommended_steps:
|
84
|
-
workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
|
293
|
+
# * NODES
|
85
294
|
|
86
295
|
# Always add create, execute, and fix nodes
|
87
296
|
workflow.add_node(create_code_node_name, node_functions[create_code_node_name])
|
88
297
|
workflow.add_node(execute_code_node_name, node_functions[execute_code_node_name])
|
89
298
|
workflow.add_node(fix_code_node_name, node_functions[fix_code_node_name])
|
90
299
|
|
300
|
+
# Conditionally add the recommended-steps node
|
301
|
+
if not bypass_recommended_steps:
|
302
|
+
workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
|
303
|
+
|
304
|
+
# Conditionally add the human review node
|
305
|
+
if human_in_the_loop:
|
306
|
+
workflow.add_node(human_review_node_name, node_functions[human_review_node_name])
|
307
|
+
|
91
308
|
# Conditionally add the explanation node
|
92
309
|
if not bypass_explain_code:
|
93
310
|
workflow.add_node(explain_code_node_name, node_functions[explain_code_node_name])
|
94
311
|
|
312
|
+
# * EDGES
|
313
|
+
|
95
314
|
# Set the entry point
|
96
315
|
entry_point = create_code_node_name if bypass_recommended_steps else recommended_steps_node_name
|
316
|
+
|
97
317
|
workflow.set_entry_point(entry_point)
|
98
318
|
|
99
|
-
# Add edges for recommended steps
|
100
319
|
if not bypass_recommended_steps:
|
101
|
-
|
102
|
-
workflow.add_edge(recommended_steps_node_name, human_review_node_name)
|
103
|
-
else:
|
104
|
-
workflow.add_edge(recommended_steps_node_name, create_code_node_name)
|
105
|
-
elif human_in_the_loop:
|
106
|
-
# Skip recommended steps but still include human review
|
107
|
-
workflow.add_edge(create_code_node_name, human_review_node_name)
|
320
|
+
workflow.add_edge(recommended_steps_node_name, create_code_node_name)
|
108
321
|
|
109
|
-
# Create -> Execute
|
110
322
|
workflow.add_edge(create_code_node_name, execute_code_node_name)
|
323
|
+
workflow.add_edge(fix_code_node_name, execute_code_node_name)
|
111
324
|
|
112
325
|
# Define a helper to check if we have an error & can still retry
|
113
326
|
def error_and_can_retry(state):
|
@@ -117,39 +330,43 @@ def create_coding_agent_graph(
|
|
117
330
|
and state.get(max_retries_key) is not None
|
118
331
|
and state[retry_count_key] < state[max_retries_key]
|
119
332
|
)
|
120
|
-
|
121
|
-
#
|
122
|
-
if
|
123
|
-
# If we are NOT bypassing explain, the next node is fix_code if error,
|
124
|
-
# else explain_code. Then we wire explain_code -> END afterward.
|
333
|
+
|
334
|
+
# If human in the loop, add a branch for human review
|
335
|
+
if human_in_the_loop:
|
125
336
|
workflow.add_conditional_edges(
|
126
337
|
execute_code_node_name,
|
127
|
-
lambda s: "fix_code" if error_and_can_retry(s) else "
|
338
|
+
lambda s: "fix_code" if error_and_can_retry(s) else "human_review",
|
128
339
|
{
|
340
|
+
"human_review": human_review_node_name,
|
129
341
|
"fix_code": fix_code_node_name,
|
130
|
-
"explain_code": explain_code_node_name,
|
131
342
|
},
|
132
343
|
)
|
133
|
-
# Fix code -> Execute again
|
134
|
-
workflow.add_edge(fix_code_node_name, execute_code_node_name)
|
135
|
-
# explain_code -> END
|
136
|
-
workflow.add_edge(explain_code_node_name, END)
|
137
344
|
else:
|
138
|
-
# If
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
345
|
+
# If no human review, the next node is fix_code if error, else explain_code.
|
346
|
+
if not bypass_explain_code:
|
347
|
+
workflow.add_conditional_edges(
|
348
|
+
execute_code_node_name,
|
349
|
+
lambda s: "fix_code" if error_and_can_retry(s) else "explain_code",
|
350
|
+
{
|
351
|
+
"fix_code": fix_code_node_name,
|
352
|
+
"explain_code": explain_code_node_name,
|
353
|
+
},
|
354
|
+
)
|
355
|
+
else:
|
356
|
+
workflow.add_conditional_edges(
|
357
|
+
execute_code_node_name,
|
358
|
+
lambda s: "fix_code" if error_and_can_retry(s) else "END",
|
359
|
+
{
|
360
|
+
"fix_code": fix_code_node_name,
|
361
|
+
"END": END,
|
362
|
+
},
|
363
|
+
)
|
364
|
+
|
365
|
+
if not bypass_explain_code:
|
366
|
+
workflow.add_edge(explain_code_node_name, END)
|
150
367
|
|
151
368
|
# Finally, compile
|
152
|
-
if human_in_the_loop
|
369
|
+
if human_in_the_loop:
|
153
370
|
app = workflow.compile(checkpointer=checkpointer)
|
154
371
|
else:
|
155
372
|
app = workflow.compile()
|
@@ -165,6 +382,8 @@ def node_func_human_review(
|
|
165
382
|
no_goto: str,
|
166
383
|
user_instructions_key: str = "user_instructions",
|
167
384
|
recommended_steps_key: str = "recommended_steps",
|
385
|
+
code_snippet_key: str = "code_snippet",
|
386
|
+
code_type: str = "python"
|
168
387
|
) -> Command[str]:
|
169
388
|
"""
|
170
389
|
A generic function to handle human review steps.
|
@@ -183,6 +402,10 @@ def node_func_human_review(
|
|
183
402
|
The key in the state to store user instructions.
|
184
403
|
recommended_steps_key : str, optional
|
185
404
|
The key in the state to store recommended steps.
|
405
|
+
code_snippet_key : str, optional
|
406
|
+
The key in the state to store the code snippet.
|
407
|
+
code_type : str, optional
|
408
|
+
The type of code snippet to display (e.g., "python").
|
186
409
|
|
187
410
|
Returns
|
188
411
|
-------
|
@@ -190,9 +413,11 @@ def node_func_human_review(
|
|
190
413
|
A Command object directing the next state and updates to the state.
|
191
414
|
"""
|
192
415
|
print(" * HUMAN REVIEW")
|
416
|
+
|
417
|
+
code_markdown=f"```{code_type}\n" + state.get(code_snippet_key)+"\n```"
|
193
418
|
|
194
419
|
# Display instructions and get user response
|
195
|
-
user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '')))
|
420
|
+
user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '') + "\n\n" + code_markdown))
|
196
421
|
|
197
422
|
# Decide next steps based on user input
|
198
423
|
if user_input.strip().lower() == "yes":
|
@@ -200,11 +425,11 @@ def node_func_human_review(
|
|
200
425
|
update = {}
|
201
426
|
else:
|
202
427
|
goto = no_goto
|
203
|
-
modifications = "Modifications: \n" + user_input
|
428
|
+
modifications = "User Has Requested Modifications To Previous Code: \n" + user_input
|
204
429
|
if state.get(user_instructions_key) is None:
|
205
|
-
update = {user_instructions_key: modifications}
|
430
|
+
update = {user_instructions_key: modifications + "\n\nPrevious Code:\n" + code_markdown}
|
206
431
|
else:
|
207
|
-
update = {user_instructions_key: state.get(user_instructions_key) + modifications}
|
432
|
+
update = {user_instructions_key: state.get(user_instructions_key) + modifications + "\n\nPrevious Code:\n" + code_markdown}
|
208
433
|
|
209
434
|
return Command(goto=goto, update=update)
|
210
435
|
|
@@ -394,6 +619,7 @@ def node_func_fix_agent_code(
|
|
394
619
|
retry_count_key: str = "retry_count",
|
395
620
|
log: bool = False,
|
396
621
|
file_path: str = "logs/agent_function.py",
|
622
|
+
function_name: str = "agent_function"
|
397
623
|
) -> dict:
|
398
624
|
"""
|
399
625
|
Generic function to fix a given piece of agent code using an LLM and a prompt template.
|
@@ -420,6 +646,8 @@ def node_func_fix_agent_code(
|
|
420
646
|
Whether to log the returned code to a file.
|
421
647
|
file_path : str, optional
|
422
648
|
The path to the file where the code will be logged.
|
649
|
+
function_name : str, optional
|
650
|
+
The name of the function in the code snippet that will be fixed.
|
423
651
|
|
424
652
|
Returns
|
425
653
|
-------
|
@@ -436,7 +664,8 @@ def node_func_fix_agent_code(
|
|
436
664
|
# Format the prompt with the code snippet and the error
|
437
665
|
prompt = prompt_template.format(
|
438
666
|
code_snippet=code_snippet,
|
439
|
-
error=error_message
|
667
|
+
error=error_message,
|
668
|
+
function_name=function_name,
|
440
669
|
)
|
441
670
|
|
442
671
|
# Execute the prompt with the LLM
|
@@ -524,3 +753,50 @@ def node_func_explain_agent_code(
|
|
524
753
|
# Return an error message if there was a problem with the code
|
525
754
|
message = AIMessage(content=error_message)
|
526
755
|
return {result_key: [message]}
|
756
|
+
|
757
|
+
|
758
|
+
|
759
|
+
def node_func_report_agent_outputs(
    state: Dict[str, Any],
    keys_to_include: List[str],
    result_key: str,
    role: str,
    custom_title: str = "Agent Output Summary"
) -> Dict[str, Any]:
    """
    Gathers relevant data directly from the state (filtered by `keys_to_include`)
    and returns them as a structured message in `state[result_key]`.

    No LLM is used.

    Parameters
    ----------
    state : Dict[str, Any]
        The current state dictionary holding all agent variables.
    keys_to_include : List[str]
        The list of keys in `state` to include in the output.
    result_key : str
        The key in `state` under which we'll store the final structured message.
    role : str
        The role that will be used in the final AIMessage (e.g., "DataCleaningAgent").
    custom_title : str, optional
        A title or heading for your report. Defaults to "Agent Output Summary".

    Returns
    -------
    Dict[str, Any]
        ``{result_key: [AIMessage]}``. The message content is the report
        serialized as JSON with 2-space indentation. The report contains a
        ``"report_title"`` entry plus one entry per requested key; keys missing
        from ``state`` map to the placeholder ``"<{key}_not_found_in_state>"``.
        Values that are not JSON-serializable are rendered with ``str()``.
    """
    print(" * REPORT AGENT OUTPUTS")

    final_report = {"report_title": custom_title}
    for key in keys_to_include:
        final_report[key] = state.get(key, f"<{key}_not_found_in_state>")

    # State values may not be JSON-native (for example, timestamp objects inside
    # a data payload). Fall back to `str()` for those so the report is still
    # produced instead of the whole node failing with a TypeError.
    content = json.dumps(final_report, indent=2, default=str)

    # Wrap it in a list of messages (like the current "messages" pattern).
    return {
        result_key: [
            AIMessage(
                content=content,
                role=role
            )
        ]
    }
|
802
|
+
|
@@ -1,6 +1,7 @@
|
|
1
1
|
import io
|
2
2
|
import pandas as pd
|
3
3
|
import sqlalchemy as sql
|
4
|
+
from sqlalchemy import inspect
|
4
5
|
from typing import Union, List, Dict
|
5
6
|
|
6
7
|
def get_dataframe_summary(
|
@@ -139,8 +140,7 @@ def _summarize_dataframe(df: pd.DataFrame, dataset_name: str, n_sample=30, skip_
|
|
139
140
|
|
140
141
|
|
141
142
|
|
142
|
-
def get_database_metadata(connection
|
143
|
-
n_samples: int = 10) -> str:
|
143
|
+
def get_database_metadata(connection, n_samples=10) -> dict:
|
144
144
|
"""
|
145
145
|
Collects metadata and sample data from a database, with safe identifier quoting and
|
146
146
|
basic dialect-aware row limiting. Prevents issues with spaces/reserved words in identifiers.
|
@@ -154,77 +154,109 @@ def get_database_metadata(connection: Union[sql.engine.base.Connection, sql.engi
|
|
154
154
|
|
155
155
|
Returns
|
156
156
|
-------
|
157
|
-
|
158
|
-
A
|
157
|
+
dict
|
158
|
+
A dictionary with database metadata, including some sample data from each column.
|
159
159
|
"""
|
160
|
-
|
161
|
-
# If a connection is passed, use it; if an engine is passed, connect to it
|
162
160
|
is_engine = isinstance(connection, sql.engine.base.Engine)
|
163
161
|
conn = connection.connect() if is_engine else connection
|
164
162
|
|
165
|
-
|
163
|
+
metadata = {
|
164
|
+
"dialect": None,
|
165
|
+
"driver": None,
|
166
|
+
"connection_url": None,
|
167
|
+
"schemas": [],
|
168
|
+
}
|
169
|
+
|
166
170
|
try:
|
167
|
-
# Grab the engine off the connection
|
168
171
|
sql_engine = conn.engine
|
169
172
|
dialect_name = sql_engine.dialect.name.lower()
|
170
173
|
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
# Inspect the database
|
176
|
-
inspector = sql.inspect(sql_engine)
|
177
|
-
tables = inspector.get_table_names()
|
178
|
-
output.append(f"Tables: {tables}")
|
179
|
-
output.append(f"Schemas: {inspector.get_schema_names()}")
|
180
|
-
|
181
|
-
# Helper to build a dialect-specific limit clause
|
182
|
-
def build_query(col_name_quoted: str, table_name_quoted: str, n: int) -> str:
|
183
|
-
"""
|
184
|
-
Returns a SQL query string to select N rows from the given column/table
|
185
|
-
across different dialects (SQLite, MySQL, Postgres, MSSQL, Oracle, etc.)
|
186
|
-
"""
|
187
|
-
if "sqlite" in dialect_name or "mysql" in dialect_name or "postgres" in dialect_name:
|
188
|
-
# Common dialects supporting LIMIT
|
189
|
-
return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
|
190
|
-
elif "mssql" in dialect_name:
|
191
|
-
# Microsoft SQL Server syntax
|
192
|
-
return f"SELECT TOP {n} {col_name_quoted} FROM {table_name_quoted}"
|
193
|
-
elif "oracle" in dialect_name:
|
194
|
-
# Oracle syntax
|
195
|
-
return f"SELECT {col_name_quoted} FROM {table_name_quoted} WHERE ROWNUM <= {n}"
|
196
|
-
else:
|
197
|
-
# Fallback
|
198
|
-
return f"SELECT {col_name_quoted} FROM {table_name_quoted} LIMIT {n}"
|
199
|
-
|
200
|
-
# Prepare for quoting
|
201
|
-
preparer = inspector.bind.dialect.identifier_preparer
|
202
|
-
|
203
|
-
# For each table, get columns and sample data
|
204
|
-
for table_name in tables:
|
205
|
-
output.append(f"\nTable: {table_name}")
|
206
|
-
# Properly quote the table name
|
207
|
-
table_name_quoted = preparer.quote_identifier(table_name)
|
208
|
-
|
209
|
-
for column in inspector.get_columns(table_name):
|
210
|
-
col_name = column["name"]
|
211
|
-
col_type = column["type"]
|
212
|
-
output.append(f" Column: {col_name} Type: {col_type}")
|
174
|
+
metadata["dialect"] = sql_engine.dialect.name
|
175
|
+
metadata["driver"] = sql_engine.driver
|
176
|
+
metadata["connection_url"] = str(sql_engine.url)
|
213
177
|
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
# Build a dialect-aware query with safe quoting
|
218
|
-
query = build_query(col_name_quoted, table_name_quoted, n_samples)
|
219
|
-
|
220
|
-
# Read a few sample values
|
221
|
-
df = pd.read_sql(sql.text(query), conn)
|
222
|
-
first_values = df[col_name].tolist()
|
223
|
-
output.append(f" First {n_samples} Values: {first_values}")
|
178
|
+
inspector = inspect(sql_engine)
|
179
|
+
preparer = inspector.bind.dialect.identifier_preparer
|
224
180
|
|
181
|
+
# For each schema
|
182
|
+
for schema_name in inspector.get_schema_names():
|
183
|
+
schema_obj = {
|
184
|
+
"schema_name": schema_name,
|
185
|
+
"tables": []
|
186
|
+
}
|
187
|
+
|
188
|
+
tables = inspector.get_table_names(schema=schema_name)
|
189
|
+
for table_name in tables:
|
190
|
+
table_info = {
|
191
|
+
"table_name": table_name,
|
192
|
+
"columns": [],
|
193
|
+
"primary_key": [],
|
194
|
+
"foreign_keys": [],
|
195
|
+
"indexes": []
|
196
|
+
}
|
197
|
+
# Get columns
|
198
|
+
columns = inspector.get_columns(table_name, schema=schema_name)
|
199
|
+
for col in columns:
|
200
|
+
col_name = col["name"]
|
201
|
+
col_type = str(col["type"])
|
202
|
+
table_name_quoted = f"{preparer.quote_identifier(schema_name)}.{preparer.quote_identifier(table_name)}"
|
203
|
+
col_name_quoted = preparer.quote_identifier(col_name)
|
204
|
+
|
205
|
+
# Build query for sample data
|
206
|
+
query = build_query(col_name_quoted, table_name_quoted, n_samples, dialect_name)
|
207
|
+
|
208
|
+
# Retrieve sample data
|
209
|
+
try:
|
210
|
+
df = pd.read_sql(query, conn)
|
211
|
+
samples = df[col_name].head(n_samples).tolist()
|
212
|
+
except Exception as e:
|
213
|
+
samples = [f"Error retrieving data: {str(e)}"]
|
214
|
+
|
215
|
+
table_info["columns"].append({
|
216
|
+
"name": col_name,
|
217
|
+
"type": col_type,
|
218
|
+
"sample_values": samples
|
219
|
+
})
|
220
|
+
|
221
|
+
# Primary keys
|
222
|
+
pk_constraint = inspector.get_pk_constraint(table_name, schema=schema_name)
|
223
|
+
table_info["primary_key"] = pk_constraint.get("constrained_columns", [])
|
224
|
+
|
225
|
+
# Foreign keys
|
226
|
+
fks = inspector.get_foreign_keys(table_name, schema=schema_name)
|
227
|
+
table_info["foreign_keys"] = [
|
228
|
+
{
|
229
|
+
"local_cols": fk["constrained_columns"],
|
230
|
+
"referred_table": fk["referred_table"],
|
231
|
+
"referred_cols": fk["referred_columns"]
|
232
|
+
}
|
233
|
+
for fk in fks
|
234
|
+
]
|
235
|
+
|
236
|
+
# Indexes
|
237
|
+
idxs = inspector.get_indexes(table_name, schema=schema_name)
|
238
|
+
table_info["indexes"] = idxs
|
239
|
+
|
240
|
+
schema_obj["tables"].append(table_info)
|
241
|
+
|
242
|
+
metadata["schemas"].append(schema_obj)
|
243
|
+
|
225
244
|
finally:
|
226
|
-
# Close connection if created inside the function
|
227
245
|
if is_engine:
|
228
246
|
conn.close()
|
229
247
|
|
230
|
-
return
|
248
|
+
return metadata
|
249
|
+
|
250
|
+
def build_query(col_name_quoted: str, table_name_quoted: str, n: int, dialect_name: str) -> str:
    """
    Build a dialect-aware SQL query that samples up to `n` values of one column.

    Identifiers must already be quoted by the caller (e.g. with the dialect's
    identifier preparer); they are interpolated into the SQL verbatim.

    Parameters
    ----------
    col_name_quoted : str
        Quoted column identifier.
    table_name_quoted : str
        Quoted (optionally schema-qualified) table identifier.
    n : int
        Maximum number of rows to return. Coerced with ``int()`` so that only an
        integer literal is ever spliced into the SQL text.
    dialect_name : str
        SQLAlchemy dialect name (e.g. ``"postgresql"``, ``"mysql"``); matched
        case-insensitively by substring.

    Returns
    -------
    str
        The query text:

        - PostgreSQL and SQLite: random sample via ``ORDER BY RANDOM() LIMIT n``.
        - MySQL and MariaDB: random sample via ``ORDER BY RAND() LIMIT n``.
        - SQL Server: random sample via ``TOP n ... ORDER BY NEWID()``.
        - Oracle: the first ``n`` rows via ``WHERE ROWNUM <= n``.
        - Any other dialect: ``LIMIT n``.
    """
    dialect = dialect_name.lower()
    limit = int(n)
    select = f"SELECT {col_name_quoted} FROM {table_name_quoted}"

    if "postgres" in dialect or "sqlite" in dialect:
        return f"{select} ORDER BY RANDOM() LIMIT {limit}"
    if "mysql" in dialect or "mariadb" in dialect:
        return f"{select} ORDER BY RAND() LIMIT {limit}"
    if "mssql" in dialect:
        return f"SELECT TOP {limit} {col_name_quoted} FROM {table_name_quoted} ORDER BY NEWID()"
    if "oracle" in dialect:
        # ROWNUM is Oracle-specific, so it must not be used as the generic fallback.
        return f"{select} WHERE ROWNUM <= {limit}"

    # Fallback: LIMIT is the most widely supported row-limiting syntax among
    # other engines. Random ordering is skipped because the name of the random
    # function differs between engines.
    return f"{select} LIMIT {limit}"
|
262
|
+
|