openchatbi 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openchatbi/__init__.py +35 -0
- openchatbi/agent_graph.py +373 -0
- openchatbi/catalog/__init__.py +14 -0
- openchatbi/catalog/catalog_loader.py +208 -0
- openchatbi/catalog/catalog_store.py +202 -0
- openchatbi/catalog/entry.py +5 -0
- openchatbi/catalog/factory.py +81 -0
- openchatbi/catalog/helper.py +49 -0
- openchatbi/catalog/retrival_helper.py +74 -0
- openchatbi/catalog/schema_retrival.py +144 -0
- openchatbi/catalog/store/__init__.py +3 -0
- openchatbi/catalog/store/file_system.py +789 -0
- openchatbi/catalog/token_service.py +48 -0
- openchatbi/code/docker_executor.py +179 -0
- openchatbi/code/executor_base.py +21 -0
- openchatbi/code/local_executor.py +21 -0
- openchatbi/code/restricted_local_executor.py +47 -0
- openchatbi/config.yaml.template +74 -0
- openchatbi/config_loader.py +225 -0
- openchatbi/constants.py +17 -0
- openchatbi/graph_state.py +59 -0
- openchatbi/llm/llm.py +94 -0
- openchatbi/prompts/agent_prompt.md +48 -0
- openchatbi/prompts/extraction_prompt.md +175 -0
- openchatbi/prompts/schema_linking_prompt.md +56 -0
- openchatbi/prompts/sql_dialect/presto.md +57 -0
- openchatbi/prompts/system_prompt.py +92 -0
- openchatbi/prompts/text2sql_prompt.md +35 -0
- openchatbi/prompts/visualization_prompt.md +34 -0
- openchatbi/text2sql/__init__.py +1 -0
- openchatbi/text2sql/data.py +12 -0
- openchatbi/text2sql/extraction.py +122 -0
- openchatbi/text2sql/generate_sql.py +400 -0
- openchatbi/text2sql/schema_linking.py +239 -0
- openchatbi/text2sql/sql_graph.py +150 -0
- openchatbi/text2sql/text2sql_utils.py +57 -0
- openchatbi/text2sql/visualization.py +315 -0
- openchatbi/tool/ask_human.py +15 -0
- openchatbi/tool/mcp_tools.py +257 -0
- openchatbi/tool/memory.py +181 -0
- openchatbi/tool/run_python_code.py +70 -0
- openchatbi/tool/save_report.py +65 -0
- openchatbi/tool/search_knowledge.py +107 -0
- openchatbi/utils.py +183 -0
- openchatbi-0.0.1.dist-info/METADATA +674 -0
- openchatbi-0.0.1.dist-info/RECORD +48 -0
- openchatbi-0.0.1.dist-info/WHEEL +4 -0
- openchatbi-0.0.1.dist-info/licenses/LICENSE +21 -0
openchatbi/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""OpenChatBI core module initialization."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
6
|
+
|
|
7
|
+
from openchatbi.config_loader import ConfigLoader
|
|
8
|
+
|
|
9
|
+
# Global configuration instance
config = ConfigLoader()

# Skip config loading during documentation build: when SPHINX_BUILD is set an empty
# configuration is installed instead of calling load().
if os.environ.get("SPHINX_BUILD"):
    config.set({})
else:
    config.load()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_default_graph():
    """
    Build the synchronous agent graph from the catalog store in the global config.

    The graph gets an in-memory checkpointer (``MemorySaver``) and the store returned by
    ``get_sync_memory_store()`` as long-term memory. Heavy imports are deferred to call
    time so importing the package stays cheap.

    Returns:
        CompiledStateGraph | None: Compiled agent graph ready for execution, or ``None``
        when the ``SPHINX_BUILD`` environment variable is set (documentation builds).
    """
    if not os.environ.get("SPHINX_BUILD"):
        from langgraph.checkpoint.memory import MemorySaver
        from openchatbi.agent_graph import build_agent_graph_sync
        from openchatbi.tool.memory import get_sync_memory_store

        short_term_memory = MemorySaver()
        catalog_store = config.get().catalog_store
        long_term_memory = get_sync_memory_store()
        return build_agent_graph_sync(catalog_store, checkpointer=short_term_memory, memory_store=long_term_memory)
    return None
|
|
@@ -0,0 +1,373 @@
|
|
|
1
|
+
"""Main agent graph construction and execution logic."""
|
|
2
|
+
|
|
3
|
+
import datetime
|
|
4
|
+
import logging
|
|
5
|
+
import traceback
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from typing import Any, Optional
|
|
8
|
+
|
|
9
|
+
from langchain_core.language_models import BaseChatModel
|
|
10
|
+
from langchain_core.messages import AIMessage, SystemMessage
|
|
11
|
+
from langchain_core.tools import StructuredTool
|
|
12
|
+
from langchain_openai.chat_models.base import BaseChatOpenAI
|
|
13
|
+
from langgraph.constants import START
|
|
14
|
+
from langgraph.errors import GraphInterrupt
|
|
15
|
+
from langgraph.graph import END, StateGraph
|
|
16
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
17
|
+
from langgraph.prebuilt import ToolNode
|
|
18
|
+
from langgraph.store.base import BaseStore
|
|
19
|
+
from langgraph.types import Checkpointer, interrupt, Send
|
|
20
|
+
from pydantic import BaseModel, Field
|
|
21
|
+
|
|
22
|
+
from openchatbi import config
|
|
23
|
+
from openchatbi.catalog import CatalogStore
|
|
24
|
+
from openchatbi.constants import datetime_format
|
|
25
|
+
from openchatbi.graph_state import AgentState, InputState, OutputState
|
|
26
|
+
from openchatbi.llm.llm import call_llm_chat_model_with_retry, default_llm
|
|
27
|
+
from openchatbi.prompts.system_prompt import AGENT_PROMPT_TEMPLATE
|
|
28
|
+
from openchatbi.text2sql.sql_graph import build_sql_graph
|
|
29
|
+
from openchatbi.tool.ask_human import AskHuman
|
|
30
|
+
from openchatbi.tool.mcp_tools import create_mcp_tools_sync, get_mcp_tools_async
|
|
31
|
+
from openchatbi.tool.memory import get_memory_tools
|
|
32
|
+
from openchatbi.tool.run_python_code import run_python_code
|
|
33
|
+
from openchatbi.tool.save_report import save_report
|
|
34
|
+
from openchatbi.tool.search_knowledge import search_knowledge, show_schema
|
|
35
|
+
from openchatbi.utils import log
|
|
36
|
+
|
|
37
|
+
# Standard module logger. Note: the functions in this module report through the
# `log` helper imported from openchatbi.utils rather than through this logger.
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def ask_human(state: AgentState) -> dict[str, Any]:
    """Pause the graph and ask the user for clarification via a LangGraph interrupt.

    Reads the first tool call on the latest message (the router only routes ``AskHuman``
    calls here), interrupts with its ``question`` and optional ``options``, and turns the
    value the graph is resumed with into a tool message answering that call.

    Args:
        state (AgentState): Current graph state; ``state["messages"][-1]`` must carry tool calls.

    Returns:
        dict: State update with the tool message under ``messages`` and the raw user
        reply under ``user_input``.
    """
    first_call = state["messages"][-1].tool_calls[0]
    call_id = first_call["id"]
    call_args = first_call["args"]

    question_payload = {
        "text": call_args["question"],
        "buttons": call_args.get("options", None),
    }
    reply = interrupt(question_payload)

    answer_message = {"tool_call_id": call_id, "type": "tool", "content": reply}
    return {"messages": [answer_message], "user_input": reply}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Argument schema for the ``text2sql`` StructuredTool built in get_sql_tools().
# NOTE: the Field descriptions below are runtime text that becomes part of the tool
# schema shown to the LLM. A class docstring is deliberately not added here, because
# pydantic uses a model docstring as the JSON-schema description, which could also
# reach the LLM.
class CallSQLGraphInput(BaseModel):
    # Justification supplied by the agent; the tool functions in get_sql_tools only log it.
    reasoning: str = Field(
        description="Explanation of why Text2SQL tool is needed",
    )
    # Full conversational context; passed to the SQL graph as its "messages" input.
    context: str = Field(
        description="""The full context pass to Text2SQL tool, make sure do not miss any potential information that related to user's question.
Following the format: History Conversation: (user and assistant history dialog)
Information: (the knowledge you retrival that is relevant, like metrics and dimensions)
User's latest question:""",
    )
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# Description for SQL tools.
# Passed as `description=` to StructuredTool.from_function in get_sql_tools(); it is
# prompt text seen by the LLM, so any wording change alters agent behavior.
TEXT2SQL_TOOL_DESCRIPTION = """Text2SQL tool to generate and execute SQL query and build visualization DSL for UI
based on user's question and context.

Returns:
str: A formatted response containing SQL, data, and visualization status.

Important notes:
- If user want to change the visualization chart type or style, add the requirement in the question
- Make sure to provide question in English
"""
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _format_sql_response(sql_graph_response: dict) -> str:
|
|
83
|
+
"""Format SQL graph response into a standardized string format.
|
|
84
|
+
|
|
85
|
+
Args:
|
|
86
|
+
sql_graph_response: The response dictionary from the SQL graph
|
|
87
|
+
|
|
88
|
+
Returns:
|
|
89
|
+
str: Formatted response string
|
|
90
|
+
"""
|
|
91
|
+
sql = sql_graph_response.get("sql", "")
|
|
92
|
+
data = sql_graph_response.get("data", "")
|
|
93
|
+
visualization_dsl = sql_graph_response.get("visualization_dsl", {})
|
|
94
|
+
|
|
95
|
+
response_parts = []
|
|
96
|
+
if sql:
|
|
97
|
+
response_parts.append(f"SQL Query:\n```sql\n{sql}\n```")
|
|
98
|
+
if data:
|
|
99
|
+
response_parts.append(f"\nQuery Results (CSV format):\n```csv\n{data}\n```")
|
|
100
|
+
|
|
101
|
+
# Include visualization status
|
|
102
|
+
if visualization_dsl and "error" not in visualization_dsl:
|
|
103
|
+
chart_type = visualization_dsl.get("chart_type", "unknown")
|
|
104
|
+
response_parts.append(
|
|
105
|
+
f"\nVisualization Created: {chart_type} chart has been automatically generated and will be displayed in the UI."
|
|
106
|
+
)
|
|
107
|
+
elif visualization_dsl and "error" in visualization_dsl:
|
|
108
|
+
response_parts.append(f"\nVisualization Error: {visualization_dsl['error']}")
|
|
109
|
+
|
|
110
|
+
return "\n\n".join(response_parts) if response_parts else "No results returned."
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def get_sql_tools(sql_graph: CompiledStateGraph, sync_mode: bool = False) -> Callable:
    """Wrap the compiled Text2SQL subgraph as the ``text2sql`` StructuredTool.

    The tool invokes ``sql_graph`` with the agent-provided context as its ``messages``
    input and formats the result with ``_format_sql_response``. ``GraphInterrupt`` is
    logged and re-raised so human-in-the-loop interrupts propagate. Any other exception
    is logged with a traceback and reported to the agent as an error string.

    Args:
        sql_graph (CompiledStateGraph): The compiled SQL generation subgraph.
        sync_mode (bool): True builds a tool backed by ``invoke``; False builds one
            backed by ``ainvoke``.

    Returns:
        StructuredTool: The ``text2sql`` tool (annotated as ``Callable`` for compatibility).
    """

    def call_sql_graph_sync(reasoning: str, context: str) -> str:
        """Run the Text2SQL subgraph synchronously."""
        log(f"Call SQL graph (sync) with reasoning: {reasoning}, context: {context}")
        try:
            graph_result = sql_graph.invoke({"messages": context})
            return _format_sql_response(graph_result)
        except GraphInterrupt as interruption:
            log(f"Sql graph interrupted:\n{repr(interruption)}")
            raise interruption
        except Exception as failure:
            log(f"Run sql graph error:\n{repr(failure)}")
            traceback.print_exc()
            return "Error occurred when calling Text2SQL tool."

    async def call_sql_graph_async(reasoning: str, context: str) -> str:
        """Run the Text2SQL subgraph asynchronously."""
        log(f"Call SQL graph (async) with reasoning: {reasoning}, context: {context}")
        try:
            graph_result = await sql_graph.ainvoke({"messages": context})
            return _format_sql_response(graph_result)
        except GraphInterrupt as interruption:
            log(f"Sql graph interrupted:\n{repr(interruption)}")
            raise interruption
        except Exception as failure:
            log(f"Run sql graph error:\n{repr(failure)}")
            traceback.print_exc()
            return "Error occurred when calling Text2SQL tool."

    # Everything except the callable itself is shared by the sync and async variants.
    shared_tool_options = {
        "name": "text2sql",
        "description": TEXT2SQL_TOOL_DESCRIPTION,
        "args_schema": CallSQLGraphInput,
        "return_direct": False,
    }
    if sync_mode:
        return StructuredTool.from_function(func=call_sql_graph_sync, **shared_tool_options)
    return StructuredTool.from_function(coroutine=call_sql_graph_async, **shared_tool_options)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def agent_router(llm: BaseChatModel, tools: list) -> Callable:
    """Create the router node that asks the LLM for the agent's next step.

    The returned node calls the LLM (with ``tools`` bound) on the system prompt plus the
    conversation so far. It then either finishes the turn or prepares ``Send`` objects that
    fan the requested tool calls out to the ``ask_human`` and ``use_tool`` nodes.

    Args:
        llm (BaseChatModel): The LLM for decision-making. OpenAI chat models
            (``BaseChatOpenAI``) are bound with ``strict=True`` tool calling.
        tools: List of tools to bind to the LLM.

    Returns:
        function: Router node function that takes ``AgentState`` and returns a state update
        containing either ``sends`` (tool calls pending) or ``agent_next_node=END``.
    """
    # OpenAI models support strict tool calling
    if isinstance(llm, BaseChatOpenAI):
        llm_with_tools = llm.bind_tools(tools, strict=True)
    else:
        llm_with_tools = llm.bind_tools(tools)

    def _call_model(state: AgentState):
        """Invoke the LLM and translate its reply into a routing state update."""
        messages = state["messages"]
        # Inject the current time into the prompt template.
        system_prompt = AGENT_PROMPT_TEMPLATE.replace(
            "[time_field_placeholder]", datetime.datetime.now().strftime(datetime_format)
        )

        response = call_llm_chat_model_with_retry(
            llm_with_tools, ([SystemMessage(system_prompt)] + messages), bound_tools=tools, parallel_tool_call=True
        )

        if response is None:
            # The retry helper gave up: end the turn with a fallback message.
            return {"messages": [AIMessage("Sorry, the LLM service is currently unavailable.")], "agent_next_node": END}
        if not isinstance(response, AIMessage):
            return {"messages": [response], "agent_next_node": END}

        tool_calls = response.tool_calls
        if not tool_calls:
            # Plain answer: the turn is finished.
            return {"messages": [response], "final_answer": response.content, "agent_next_node": END}

        # Report through the module's logging helper instead of printing to stdout, and
        # only when there is actually something to report.
        log("Tool Call: " + ", ".join(call["name"] for call in tool_calls))

        # AskHuman is served by its own interrupting node; all other calls go to the ToolNode.
        ask_human_calls = [call for call in tool_calls if call["name"] == "AskHuman"]
        normal_tool_calls = [call for call in tool_calls if call["name"] != "AskHuman"]

        # One Send per destination so both groups can run in parallel.
        sends = []
        if ask_human_calls:
            ask_human_msg = AIMessage(content=response.content, tool_calls=ask_human_calls)
            sends.append(Send("ask_human", {"messages": [ask_human_msg]}))
        if normal_tool_calls:
            tool_msg = AIMessage(content=response.content, tool_calls=normal_tool_calls)
            sends.append(Send("use_tool", {"messages": [tool_msg]}))

        return {"messages": [response], "sends": sends}

    return _call_model
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def _build_graph_core(
    catalog: CatalogStore,
    sync_mode: bool,
    checkpointer: Checkpointer,
    memory_store: BaseStore,
    memory_tools: Optional[tuple[Callable, Callable]],
    mcp_tools: list,
) -> CompiledStateGraph:
    """Assemble and compile the agent graph (shared by the sync and async builders).

    Graph layout:
      * START goes to the router.
      * The router fans out to ``ask_human`` and/or ``use_tool`` via ``Send``, or goes to END.
      * Both ``ask_human`` and ``use_tool`` loop back to the router.

    Args:
        catalog: Catalog store containing schema information; used to build the Text2SQL subgraph.
        sync_mode: Whether to create synchronous (True) or asynchronous (False) tools.
        checkpointer: The Checkpointer for state persistence, also given to the Text2SQL subgraph.
        memory_store: The BaseStore for long-term memory. It is given to the Text2SQL subgraph,
            to the memory tools (when they are created here) and to the compiled graph.
        memory_tools: Optional ``(manage_memory_tool, search_memory_tool)`` pair; when falsy,
            the pair is created with ``get_memory_tools``.
        mcp_tools: Pre-initialized MCP tools, appended after the built-in tools.

    Returns:
        CompiledStateGraph: Compiled agent graph ready for execution.
    """
    text2sql_subgraph = build_sql_graph(catalog, checkpointer, memory_store)
    text2sql_tool = get_sql_tools(sql_graph=text2sql_subgraph, sync_mode=sync_mode)

    memory_pair = (
        memory_tools if memory_tools else get_memory_tools(default_llm, sync_mode=sync_mode, store=memory_store)
    )
    manage_memory_tool, search_memory_tool = memory_pair

    log(str(mcp_tools))
    builtin_tools = [
        search_knowledge,
        show_schema,
        text2sql_tool,
        run_python_code,
        manage_memory_tool,
        search_memory_tool,
        save_report,
    ]
    normal_tools = builtin_tools + mcp_tools
    tool_node = ToolNode(normal_tools)

    builder = StateGraph(AgentState, input_schema=InputState, output_schema=OutputState)

    # The router's LLM also sees AskHuman, which is served by its own node rather than the ToolNode.
    builder.add_node("router", agent_router(default_llm, normal_tools + [AskHuman]))
    builder.add_node("ask_human", ask_human)
    builder.add_node("use_tool", tool_node)

    builder.add_edge(START, "router")
    for tool_stage in ("ask_human", "use_tool"):
        builder.add_edge(tool_stage, "router")

    def route_tools(state: AgentState):
        """Pick the next hop after the router node has run."""
        history = state["messages"]
        last_message = history[-1] if history else None
        router_requested_tools = last_message and isinstance(last_message, AIMessage) and last_message.tool_calls
        # Fan out only right after the router emitted tool calls and prepared Sends.
        if router_requested_tools and "sends" in state and state["sends"]:
            return state["sends"]
        if "agent_next_node" in state:
            return state["agent_next_node"]
        return END

    # The path map covers the single-name routes; Send lists bypass it.
    single_route_targets = {
        "ask_human": "ask_human",
        "use_tool": "use_tool",
        END: END,
    }
    builder.add_conditional_edges("router", route_tools, single_route_targets)

    return builder.compile(name="agent_graph", checkpointer=checkpointer, store=memory_store)
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def build_agent_graph_sync(
    catalog: CatalogStore,
    checkpointer: Checkpointer = None,
    memory_store: BaseStore = None,
) -> CompiledStateGraph:
    """Build the main agent graph with all nodes and edges for synchronous use.

    MCP tools are created synchronously from ``config.get().mcp_servers``. The memory
    tools are always created by the core builder, in sync mode.

    Args:
        catalog: Catalog store containing schema information.
        checkpointer: The Checkpointer for state persistence (short memory). If None, no short memory.
        memory_store: The BaseStore for long-term memory, forwarded unchanged to the core
            builder (and from there to the SQL graph, the memory tools and graph compilation).

    Returns:
        CompiledStateGraph: Compiled agent graph ready for execution.
    """
    server_settings = config.get().mcp_servers
    return _build_graph_core(
        catalog,
        True,  # sync_mode
        checkpointer,
        memory_store,
        None,  # memory_tools: the core creates sync memory tools itself
        create_mcp_tools_sync(server_settings),
    )
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
async def build_agent_graph_async(
    catalog: CatalogStore,
    checkpointer: Checkpointer = None,
    memory_store: BaseStore = None,
    memory_tools: tuple[Callable, Callable] = None,
) -> CompiledStateGraph:
    """Build the main agent graph with all nodes and edges for asynchronous use.

    Same graph as ``build_agent_graph_sync``, but MCP tools are loaded with
    ``await get_mcp_tools_async(...)`` and the tools run in async mode.

    Args:
        catalog: Catalog store containing schema information.
        checkpointer: The Checkpointer for state persistence (short memory). If None, no short memory.
        memory_store: The BaseStore for long-term memory, forwarded unchanged to the core
            builder (and from there to the SQL graph, the memory tools and graph compilation).
        memory_tools: Tuple of (manage_memory_tool, search_memory_tool). If None, creates async tools.

    Returns:
        CompiledStateGraph: Compiled agent graph ready for execution.
    """
    server_settings = config.get().mcp_servers
    async_mcp_tools = await get_mcp_tools_async(server_settings)
    return _build_graph_core(
        catalog,
        False,  # sync_mode
        checkpointer,
        memory_store,
        memory_tools,
        async_mcp_tools,
    )
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Data catalog management module for OpenChatBI."""
|
|
2
|
+
|
|
3
|
+
from openchatbi.catalog.catalog_loader import (
|
|
4
|
+
DataCatalogLoader,
|
|
5
|
+
load_catalog_from_data_warehouse,
|
|
6
|
+
)
|
|
7
|
+
from openchatbi.catalog.catalog_store import CatalogStore
|
|
8
|
+
from openchatbi.catalog.factory import create_catalog_store
|
|
9
|
+
|
|
10
|
+
# Public API of the catalog package. `create_catalog_store` is imported above as part of
# that API, so it is listed here as well to be exported by `from openchatbi.catalog import *`.
__all__ = [
    "CatalogStore",
    "DataCatalogLoader",
    "create_catalog_store",
    "load_catalog_from_data_warehouse",
]
|
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from sqlalchemy import MetaData, inspect
|
|
5
|
+
from sqlalchemy.engine import Engine
|
|
6
|
+
|
|
7
|
+
from .catalog_store import CatalogStore
|
|
8
|
+
|
|
9
|
+
# Module-level logger used by DataCatalogLoader and load_catalog_from_data_warehouse.
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DataCatalogLoader:
    """
    The loader to load data catalog from data warehouse metadata and save to catalog store.

    Table and column metadata (including comments) are read through a SQLAlchemy
    inspector and written to a ``CatalogStore``.
    """

    # Generic column names whose meaning depends on the table they belong to. Columns
    # with these names are saved as table-specific (``is_common=False``); all other
    # columns are marked common.
    _TABLE_SPECIFIC_COLUMN_NAMES = ("id", "name", "type", "status")

    def __init__(self, engine: Engine, include_tables: list[str] | None = None):
        """
        Initialize catalog loader.

        Args:
            engine (Engine): SQLAlchemy engine instance
            include_tables (Optional[List[str]]): List of table names to include, None (or empty) for all
        """
        self.engine = engine
        self.include_tables = include_tables
        self.metadata = MetaData()
        self.inspector = inspect(engine)

    def _build_column_info(self, column: dict[str, Any]) -> dict[str, Any]:
        """Convert one SQLAlchemy inspector column dict into a catalog column record."""
        column_name = column["name"]
        default = column.get("default")
        return {
            "column_name": column_name,
            "display_name": "",
            "alias": "",
            "type": str(column["type"]),
            "category": "",
            "tag": "",
            "description": column.get("comment", "") or "",
            "dimension_table": "",
            "default": "" if default is None else str(default),
            # Compare the column *name*: the inspector returns a dict per column, so the
            # dict itself never matches any of these strings.
            "is_common": column_name not in self._TABLE_SPECIFIC_COLUMN_NAMES,
        }

    def _get_table_comment(self, table_name: str) -> str:
        """Return the table comment, or "" when it is missing or unsupported by the backend."""
        try:
            comment_info = self.inspector.get_table_comment(table_name)
        except Exception:
            # Some databases don't support table comments
            return ""
        # The comment text may be None for tables without a comment; normalize to "".
        return (comment_info or {}).get("text") or ""

    def get_tables_and_columns(self) -> dict[str, list[dict[str, Any]]]:
        """
        Extract table and column metadata including comments using SQLAlchemy inspector.

        Tables whose columns cannot be read are logged and skipped.

        Returns:
            Dict[str, List[Dict[str, Any]]]: Dictionary mapping table names to list of column
            information; an empty dict if the table list itself cannot be read.
        """
        try:
            tables_columns = {}

            # Get all table names
            table_names = self.inspector.get_table_names()

            # Filter to specific tables if configured
            if self.include_tables:
                table_names = [name for name in table_names if name in self.include_tables]

            logger.info(f"Found {len(table_names)} tables to process")

            for table_name in table_names:
                try:
                    column_list = [self._build_column_info(column) for column in self.inspector.get_columns(table_name)]
                    tables_columns[table_name] = column_list
                    logger.debug(f"Processed table {table_name} with {len(column_list)} columns")
                except Exception as e:
                    logger.error(f"Failed to process table {table_name}: {e}")
                    continue

            logger.info(f"Successfully processed {len(tables_columns)} tables")
            return tables_columns

        except Exception as e:
            logger.error(f"Failed to get tables and columns from data warehouse: {e}")
            return {}

    def get_table_indexes(self, table_name: str) -> list[dict[str, Any]]:
        """
        Get index information for a specific table.

        Args:
            table_name (str): Name of the table

        Returns:
            List[Dict[str, Any]]: List of index information, or [] if it cannot be read
        """
        try:
            return self.inspector.get_indexes(table_name)
        except Exception as e:
            logger.warning(f"Failed to get indexes for table {table_name}: {e}")
            return []

    def get_foreign_keys(self, table_name: str) -> list[dict[str, Any]]:
        """
        Get foreign key information for a specific table.

        Args:
            table_name (str): Name of the table

        Returns:
            List[Dict[str, Any]]: List of foreign key information, or [] if it cannot be read
        """
        try:
            return self.inspector.get_foreign_keys(table_name)
        except Exception as e:
            logger.warning(f"Failed to get foreign keys for table {table_name}: {e}")
            return []

    def save_to_catalog_store(
        self, catalog_store: CatalogStore, database_name: str | None = None, update: bool = False
    ) -> bool:
        """
        Extract warehouse metadata and save to catalog store.

        For every table this saves the table information and columns, and writes a
        placeholder SQL example. Once at the end, it writes an empty table-selection example.

        Args:
            catalog_store (CatalogStore): Target catalog store to load data to
            database_name (Optional[str]): Database name in catalog, defaults to 'default'
            update (bool): Update existing catalog store to sync with data warehouse.
                NOTE(review): currently not used by this implementation.

        Returns:
            bool: True if every table was saved (or no tables were found), False otherwise
        """
        try:
            if database_name is None:
                database_name = "default"

            # Get tables and columns from data warehouse
            tables_columns = self.get_tables_and_columns()

            if not tables_columns:
                logger.warning("No tables found in data warehouse")
                return True

            # Import each table
            success_count = 0
            total_count = len(tables_columns)

            for table_name, columns in tables_columns.items():
                try:
                    table_info = {
                        "description": self._get_table_comment(table_name),
                        "selection_rule": "",
                        "sql_rule": "",
                    }
                    if catalog_store.save_table_information(table_name, table_info, columns, database_name):
                        success_count += 1
                        logger.info(f"Successfully loaded table: {database_name}.{table_name}")
                    else:
                        logger.error(f"Failed to load table: {database_name}.{table_name}")

                    # init null SQL examples
                    catalog_store.save_table_sql_examples(
                        table_name, [{"question": "null", "answer": "null"}], database_name
                    )

                except Exception as e:
                    logger.error(f"Error loading table {table_name}: {e}")

            # init empty table selection examples
            catalog_store.save_table_selection_examples([("", [])])

            logger.info(f"Load completed: {success_count}/{total_count} tables loaded successfully")
            return success_count == total_count

        except Exception as e:
            logger.error(f"Failed to load data warehouse to catalog store: {e}")
            return False
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _redact_database_uri(database_uri):
    """Return ``database_uri`` with any password masked, for safe logging."""
    if not database_uri:
        return database_uri
    try:
        from sqlalchemy.engine import make_url

        return make_url(database_uri).render_as_string(hide_password=True)
    except Exception:
        # Never let log formatting hide the original error.
        return "<unparseable database URI>"


def load_catalog_from_data_warehouse(catalog_store: CatalogStore) -> bool:
    """
    Load catalog data from data warehouse using SQLAlchemy based on data warehouse config (URI)

    Main entry point for catalog loading. It reads ``uri``, ``include_tables`` and
    ``database_name`` (default ``"default"``) from the store's data warehouse config,
    and the engine from ``catalog_store.get_sql_engine()``.

    Args:
        catalog_store (CatalogStore): Target catalog store

    Returns:
        bool: True if load was successful, False otherwise (errors are logged, not raised)
    """
    # Bound before the try so the error log below can always reference it, even when
    # reading the config itself fails.
    database_uri = None
    try:
        data_warehouse_config = catalog_store.get_data_warehouse_config()
        database_uri = data_warehouse_config.get("uri")
        include_tables = data_warehouse_config.get("include_tables")
        database_name = data_warehouse_config.get("database_name", "default")
        engine = catalog_store.get_sql_engine()

        loader = DataCatalogLoader(engine, include_tables)
        return loader.save_to_catalog_store(catalog_store, database_name)

    except Exception as e:
        # The URI may embed credentials, so log a redacted form.
        logger.error(
            f"Failed to import catalog from data warehouse URI {_redact_database_uri(database_uri)}: {e}"
        )
        return False
|