openchatbi 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. openchatbi/__init__.py +35 -0
  2. openchatbi/agent_graph.py +373 -0
  3. openchatbi/catalog/__init__.py +14 -0
  4. openchatbi/catalog/catalog_loader.py +208 -0
  5. openchatbi/catalog/catalog_store.py +202 -0
  6. openchatbi/catalog/entry.py +5 -0
  7. openchatbi/catalog/factory.py +81 -0
  8. openchatbi/catalog/helper.py +49 -0
  9. openchatbi/catalog/retrival_helper.py +74 -0
  10. openchatbi/catalog/schema_retrival.py +144 -0
  11. openchatbi/catalog/store/__init__.py +3 -0
  12. openchatbi/catalog/store/file_system.py +789 -0
  13. openchatbi/catalog/token_service.py +48 -0
  14. openchatbi/code/docker_executor.py +179 -0
  15. openchatbi/code/executor_base.py +21 -0
  16. openchatbi/code/local_executor.py +21 -0
  17. openchatbi/code/restricted_local_executor.py +47 -0
  18. openchatbi/config.yaml.template +74 -0
  19. openchatbi/config_loader.py +225 -0
  20. openchatbi/constants.py +17 -0
  21. openchatbi/graph_state.py +59 -0
  22. openchatbi/llm/llm.py +94 -0
  23. openchatbi/prompts/agent_prompt.md +48 -0
  24. openchatbi/prompts/extraction_prompt.md +175 -0
  25. openchatbi/prompts/schema_linking_prompt.md +56 -0
  26. openchatbi/prompts/sql_dialect/presto.md +57 -0
  27. openchatbi/prompts/system_prompt.py +92 -0
  28. openchatbi/prompts/text2sql_prompt.md +35 -0
  29. openchatbi/prompts/visualization_prompt.md +34 -0
  30. openchatbi/text2sql/__init__.py +1 -0
  31. openchatbi/text2sql/data.py +12 -0
  32. openchatbi/text2sql/extraction.py +122 -0
  33. openchatbi/text2sql/generate_sql.py +400 -0
  34. openchatbi/text2sql/schema_linking.py +239 -0
  35. openchatbi/text2sql/sql_graph.py +150 -0
  36. openchatbi/text2sql/text2sql_utils.py +57 -0
  37. openchatbi/text2sql/visualization.py +315 -0
  38. openchatbi/tool/ask_human.py +15 -0
  39. openchatbi/tool/mcp_tools.py +257 -0
  40. openchatbi/tool/memory.py +181 -0
  41. openchatbi/tool/run_python_code.py +70 -0
  42. openchatbi/tool/save_report.py +65 -0
  43. openchatbi/tool/search_knowledge.py +107 -0
  44. openchatbi/utils.py +183 -0
  45. openchatbi-0.0.1.dist-info/METADATA +674 -0
  46. openchatbi-0.0.1.dist-info/RECORD +48 -0
  47. openchatbi-0.0.1.dist-info/WHEEL +4 -0
  48. openchatbi-0.0.1.dist-info/licenses/LICENSE +21 -0
openchatbi/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ """OpenChatBI core module initialization."""
2
+
3
+ import os
4
+
5
+ from langgraph.graph.state import CompiledStateGraph
6
+
7
+ from openchatbi.config_loader import ConfigLoader
8
+
9
# Package-wide configuration object; other modules import it as `from openchatbi import config`.
config = ConfigLoader()
if os.environ.get("SPHINX_BUILD"):
    # Documentation builds run without a real config file, so install an empty configuration.
    config.set({})
else:
    config.load()
16
+
17
+
18
def get_default_graph():
    """
    Build the agent graph in synchronous mode using the default catalog from config.

    The graph uses an in-memory ``MemorySaver`` checkpointer for short-term state
    and the store returned by ``get_sync_memory_store()`` for long-term memory.

    Returns:
        CompiledStateGraph | None: Compiled agent graph ready for execution, or
        ``None`` when the ``SPHINX_BUILD`` environment variable is set
        (documentation builds).
    """
    if not os.environ.get("SPHINX_BUILD"):
        # Deferred imports: openchatbi.agent_graph itself does `from openchatbi import config`,
        # so importing it at module level here would be circular.
        from langgraph.checkpoint.memory import MemorySaver
        from openchatbi.agent_graph import build_agent_graph_sync
        from openchatbi.tool.memory import get_sync_memory_store

        short_term_memory = MemorySaver()
        return build_agent_graph_sync(
            config.get().catalog_store,
            checkpointer=short_term_memory,
            memory_store=get_sync_memory_store(),
        )
    return None
openchatbi/agent_graph.py ADDED
@@ -0,0 +1,373 @@
1
+ """Main agent graph construction and execution logic."""
2
+
3
+ import datetime
4
+ import logging
5
+ import traceback
6
+ from collections.abc import Callable
7
+ from typing import Any, Optional
8
+
9
+ from langchain_core.language_models import BaseChatModel
10
+ from langchain_core.messages import AIMessage, SystemMessage
11
+ from langchain_core.tools import StructuredTool
12
+ from langchain_openai.chat_models.base import BaseChatOpenAI
13
+ from langgraph.constants import START
14
+ from langgraph.errors import GraphInterrupt
15
+ from langgraph.graph import END, StateGraph
16
+ from langgraph.graph.state import CompiledStateGraph
17
+ from langgraph.prebuilt import ToolNode
18
+ from langgraph.store.base import BaseStore
19
+ from langgraph.types import Checkpointer, interrupt, Send
20
+ from pydantic import BaseModel, Field
21
+
22
+ from openchatbi import config
23
+ from openchatbi.catalog import CatalogStore
24
+ from openchatbi.constants import datetime_format
25
+ from openchatbi.graph_state import AgentState, InputState, OutputState
26
+ from openchatbi.llm.llm import call_llm_chat_model_with_retry, default_llm
27
+ from openchatbi.prompts.system_prompt import AGENT_PROMPT_TEMPLATE
28
+ from openchatbi.text2sql.sql_graph import build_sql_graph
29
+ from openchatbi.tool.ask_human import AskHuman
30
+ from openchatbi.tool.mcp_tools import create_mcp_tools_sync, get_mcp_tools_async
31
+ from openchatbi.tool.memory import get_memory_tools
32
+ from openchatbi.tool.run_python_code import run_python_code
33
+ from openchatbi.tool.save_report import save_report
34
+ from openchatbi.tool.search_knowledge import search_knowledge, show_schema
35
+ from openchatbi.utils import log
36
+
37
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
38
+
39
+
40
def ask_human(state: AgentState) -> dict[str, Any]:
    """Pause the graph to ask the user a clarifying question.

    Takes the first tool call on the last message in ``state["messages"]``
    (the router sends this node a message holding only ``AskHuman`` calls).
    It calls LangGraph's ``interrupt`` with the question text and optional
    answer buttons, and uses the value ``interrupt`` returns as the user's reply.

    Args:
        state (AgentState): Current graph state; its last message must carry tool calls.

    Returns:
        dict: ``messages`` holding one tool message that answers the tool call,
        and ``user_input`` set to the user's reply.
    """
    first_call = state["messages"][-1].tool_calls[0]
    answered_call_id = first_call["id"]
    call_args = first_call["args"]

    reply = interrupt({"text": call_args["question"], "buttons": call_args.get("options", None)})

    return {
        "messages": [{"tool_call_id": answered_call_id, "type": "tool", "content": reply}],
        "user_input": reply,
    }
55
+
56
+
57
# Argument schema of the "text2sql" StructuredTool (passed as args_schema in get_sql_tools).
# The Field descriptions are exposed to the LLM through that schema, so they are runtime
# prompt text, not comments.
# NOTE(review): the `context` description has wording issues ("pass to", "retrival");
# any fix changes the prompt the LLM sees, so it is left untouched here.
class CallSQLGraphInput(BaseModel):
    # Why the agent wants to call Text2SQL; the tool functions only log this value.
    reasoning: str = Field(
        description="Explanation of why Text2SQL tool is needed",
    )
    # Full context for the SQL subgraph; the tool passes it as the graph's "messages" input.
    context: str = Field(
        description="""The full context pass to Text2SQL tool, make sure do not miss any potential information that related to user's question.
Following the format: History Conversation: (user and assistant history dialog)
Information: (the knowledge you retrival that is relevant, like metrics and dimensions)
User's latest question:""",
    )
67
+
68
+
69
# Description of the "text2sql" StructuredTool built in get_sql_tools.
# The LLM reads this text when deciding whether and how to call the tool,
# so it is runtime prompt content and is kept verbatim.
TEXT2SQL_TOOL_DESCRIPTION = """Text2SQL tool to generate and execute SQL query and build visualization DSL for UI
based on user's question and context.

Returns:
str: A formatted response containing SQL, data, and visualization status.

Important notes:
- If user want to change the visualization chart type or style, add the requirement in the question
- Make sure to provide question in English
"""
80
+
81
+
82
+ def _format_sql_response(sql_graph_response: dict) -> str:
83
+ """Format SQL graph response into a standardized string format.
84
+
85
+ Args:
86
+ sql_graph_response: The response dictionary from the SQL graph
87
+
88
+ Returns:
89
+ str: Formatted response string
90
+ """
91
+ sql = sql_graph_response.get("sql", "")
92
+ data = sql_graph_response.get("data", "")
93
+ visualization_dsl = sql_graph_response.get("visualization_dsl", {})
94
+
95
+ response_parts = []
96
+ if sql:
97
+ response_parts.append(f"SQL Query:\n```sql\n{sql}\n```")
98
+ if data:
99
+ response_parts.append(f"\nQuery Results (CSV format):\n```csv\n{data}\n```")
100
+
101
+ # Include visualization status
102
+ if visualization_dsl and "error" not in visualization_dsl:
103
+ chart_type = visualization_dsl.get("chart_type", "unknown")
104
+ response_parts.append(
105
+ f"\nVisualization Created: {chart_type} chart has been automatically generated and will be displayed in the UI."
106
+ )
107
+ elif visualization_dsl and "error" in visualization_dsl:
108
+ response_parts.append(f"\nVisualization Error: {visualization_dsl['error']}")
109
+
110
+ return "\n\n".join(response_parts) if response_parts else "No results returned."
111
+
112
+
113
def get_sql_tools(sql_graph: CompiledStateGraph, sync_mode: bool = False) -> Callable:
    """Wrap the compiled SQL subgraph as the ``text2sql`` LangChain tool.

    The tool invokes ``sql_graph`` with ``{"messages": context}`` and formats the
    result with ``_format_sql_response``. A ``GraphInterrupt`` from the subgraph is
    re-raised unchanged so the parent graph can handle the interruption. Any other
    exception is logged, with its traceback, through the module logger, and the
    tool returns a fixed error string so the agent can continue.

    Args:
        sql_graph (CompiledStateGraph): The compiled SQL generation subgraph.
        sync_mode (bool): If True, the tool calls ``sql_graph.invoke``; otherwise it is a
            coroutine tool that awaits ``sql_graph.ainvoke``.

    Returns:
        StructuredTool: The ``text2sql`` tool, with ``CallSQLGraphInput`` as its argument schema.
    """
    error_reply = "Error occurred when calling Text2SQL tool."

    def _handle_failure(exc: Exception) -> str:
        # Must be called from inside an `except` block: logger.exception attaches the
        # active traceback, routed through logging instead of printed to stderr.
        log(f"Run sql graph error:\n{repr(exc)}")
        logger.exception("Run sql graph error")
        return error_reply

    def call_sql_graph_sync(reasoning: str, context: str) -> str:
        """Sync node function for Text2SQL tool"""
        log(f"Call SQL graph (sync) with reasoning: {reasoning}, context: {context}")
        try:
            sql_graph_response = sql_graph.invoke({"messages": context})
            return _format_sql_response(sql_graph_response)
        except GraphInterrupt as e:
            log(f"Sql graph interrupted:\n{repr(e)}")
            raise
        except Exception as e:
            return _handle_failure(e)

    async def call_sql_graph_async(reasoning: str, context: str) -> str:
        """Async node function for Text2SQL tool"""
        log(f"Call SQL graph (async) with reasoning: {reasoning}, context: {context}")
        try:
            sql_graph_response = await sql_graph.ainvoke({"messages": context})
            return _format_sql_response(sql_graph_response)
        except GraphInterrupt as e:
            log(f"Sql graph interrupted:\n{repr(e)}")
            raise
        except Exception as e:
            return _handle_failure(e)

    # Settings shared by the sync and async variants of the tool.
    tool_settings = {
        "name": "text2sql",
        "description": TEXT2SQL_TOOL_DESCRIPTION,
        "args_schema": CallSQLGraphInput,
        "return_direct": False,
    }
    if sync_mode:
        return StructuredTool.from_function(func=call_sql_graph_sync, **tool_settings)
    return StructuredTool.from_function(coroutine=call_sql_graph_async, **tool_settings)
168
+
169
+
170
def agent_router(llm: BaseChatModel, tools: list) -> Callable:
    """Create the router node: call the LLM and decide which node(s) run next.

    Args:
        llm (BaseChatModel): The LLM for decision-making.
        tools: Tools bound to the LLM (the graph builder passes the normal tools plus ``AskHuman``).

    Returns:
        Callable: Node function ``_call_model(state)`` that returns a state update:

        * The LLM requested tools: ``messages`` plus ``sends``. ``sends`` holds a ``Send``
          to ``ask_human`` carrying only the ``AskHuman`` calls and/or a ``Send`` to
          ``use_tool`` carrying every other call, so both can run in parallel.
        * Plain AI answer: ``messages``, ``final_answer`` and ``agent_next_node=END``.
        * ``call_llm_chat_model_with_retry`` returned ``None``: an apology message and
          ``agent_next_node=END``.
        * Any other (non-AI) response: ``messages`` and ``agent_next_node=END``.
    """
    # OpenAI models support strict tool calling
    if isinstance(llm, BaseChatOpenAI):
        llm_with_tools = llm.bind_tools(tools, strict=True)
    else:
        llm_with_tools = llm.bind_tools(tools)

    def _call_model(state: AgentState):
        messages = state["messages"]
        # The agent prompt embeds the current local time, refreshed on every call.
        system_prompt = AGENT_PROMPT_TEMPLATE.replace(
            "[time_field_placeholder]", datetime.datetime.now().strftime(datetime_format)
        )

        response = call_llm_chat_model_with_retry(
            llm_with_tools, ([SystemMessage(system_prompt)] + messages), bound_tools=tools, parallel_tool_call=True
        )

        if response is None:
            return {"messages": [AIMessage("Sorry, the LLM service is currently unavailable.")], "agent_next_node": END}
        if not isinstance(response, AIMessage):
            return {"messages": [response], "agent_next_node": END}

        tool_calls = response.tool_calls
        if not tool_calls:
            return {"messages": [response], "final_answer": response.content, "agent_next_node": END}

        # Report the requested tools through logging rather than print() so library code
        # does not write to stdout.
        logger.info("Tool Call: %s", ", ".join(call["name"] for call in tool_calls))

        # Group tool calls by destination so AskHuman and regular tools can run in parallel.
        ask_human_calls = [call for call in tool_calls if call["name"] == "AskHuman"]
        normal_tool_calls = [call for call in tool_calls if call["name"] != "AskHuman"]

        sends = []
        if ask_human_calls:
            # Message carrying only the AskHuman calls
            ask_human_msg = AIMessage(content=response.content, tool_calls=ask_human_calls)
            sends.append(Send("ask_human", {"messages": [ask_human_msg]}))
        if normal_tool_calls:
            # Message carrying only the regular tool calls
            tool_msg = AIMessage(content=response.content, tool_calls=normal_tool_calls)
            sends.append(Send("use_tool", {"messages": [tool_msg]}))

        return {"messages": [response], "sends": sends}

    return _call_model
226
+
227
+
228
def _build_graph_core(
    catalog: CatalogStore,
    sync_mode: bool,
    checkpointer: Checkpointer,
    memory_store: BaseStore,
    memory_tools: Optional[tuple[Callable, Callable]],
    mcp_tools: list,
) -> CompiledStateGraph:
    """Assemble and compile the agent graph; shared by the sync and async builders.

    Topology: START -> "router". After "router", ``route_tools`` does one of three things:

    * fans out with the ``Send`` objects stored in ``state["sends"]`` (to "ask_human"
      and/or "use_tool", both of which lead back to "router");
    * otherwise follows ``state["agent_next_node"]``;
    * otherwise goes to END.

    Args:
        catalog: Catalog store containing schema information.
        sync_mode: Build synchronous tools (True) or asynchronous ones (False).
        checkpointer: Checkpointer for state persistence, also given to the SQL subgraph.
        memory_store: Long-term memory store, given to the SQL subgraph, the memory-tool
            factory and the compiled graph.
        memory_tools: Optional ``(manage_memory_tool, search_memory_tool)`` pair. When falsy,
            a pair is created with ``get_memory_tools``.
        mcp_tools: Pre-initialized MCP tools, appended after the built-in tools.

    Returns:
        CompiledStateGraph: The compiled graph, named "agent_graph".
    """
    text2sql_tool = get_sql_tools(
        sql_graph=build_sql_graph(catalog, checkpointer, memory_store),
        sync_mode=sync_mode,
    )

    manage_memory_tool, search_memory_tool = (
        memory_tools
        if memory_tools
        else get_memory_tools(default_llm, sync_mode=sync_mode, store=memory_store)
    )

    log(str(mcp_tools))
    builtin_tools = [
        search_knowledge,
        show_schema,
        text2sql_tool,
        run_python_code,
        manage_memory_tool,
        search_memory_tool,
        save_report,
    ]
    normal_tools = builtin_tools + mcp_tools
    executor_node = ToolNode(normal_tools)

    builder = StateGraph(AgentState, input_schema=InputState, output_schema=OutputState)

    # Nodes
    builder.add_node("router", agent_router(default_llm, normal_tools + [AskHuman]))
    builder.add_node("ask_human", ask_human)
    builder.add_node("use_tool", executor_node)

    # Static edges: entry point, and both worker nodes loop back to the router.
    builder.add_edge(START, "router")
    builder.add_edge("ask_human", "router")
    builder.add_edge("use_tool", "router")

    def route_tools(state: AgentState):
        # Fan out only when the newest message is the router's tool-calling AI message.
        history = state["messages"]
        newest = history[-1] if history else None
        if newest and isinstance(newest, AIMessage) and newest.tool_calls and "sends" in state and state["sends"]:
            return state["sends"]  # Send objects -> parallel execution
        if "agent_next_node" in state:
            return state["agent_next_node"]  # a single node name
        return END

    builder.add_conditional_edges(
        "router",
        route_tools,
        # Path map used when route_tools returns a single node name.
        {
            "ask_human": "ask_human",
            "use_tool": "use_tool",
            END: END,
        },
    )

    return builder.compile(name="agent_graph", checkpointer=checkpointer, store=memory_store)
313
+
314
+
315
def build_agent_graph_sync(
    catalog: CatalogStore,
    checkpointer: Checkpointer = None,
    memory_store: BaseStore = None,
) -> CompiledStateGraph:
    """Build and compile the agent graph for synchronous execution.

    MCP tools are created with ``create_mcp_tools_sync`` from the configured
    ``mcp_servers``. Memory tools cannot be injected here: the shared graph builder
    always creates synchronous ones.

    Args:
        catalog: Catalog store containing schema information.
        checkpointer: Checkpointer for short-term state persistence; None disables it.
        memory_store: Long-term memory store, passed through unchanged to the SQL subgraph,
            the memory-tool factory and the compiled graph.

    Returns:
        CompiledStateGraph: Compiled agent graph ready for execution.
    """
    server_settings = config.get().mcp_servers
    sync_mcp_tools = create_mcp_tools_sync(server_settings)

    return _build_graph_core(
        mcp_tools=sync_mcp_tools,
        memory_tools=None,
        memory_store=memory_store,
        checkpointer=checkpointer,
        sync_mode=True,
        catalog=catalog,
    )
341
+
342
+
343
async def build_agent_graph_async(
    catalog: CatalogStore,
    checkpointer: Checkpointer = None,
    memory_store: BaseStore = None,
    memory_tools: tuple[Callable, Callable] = None,
) -> CompiledStateGraph:
    """Build and compile the agent graph for asynchronous execution.

    Unlike ``build_agent_graph_sync``, this awaits ``get_mcp_tools_async`` to load MCP
    tools, builds async tools, and accepts a pre-built pair of memory tools.

    Args:
        catalog: Catalog store containing schema information.
        checkpointer: Checkpointer for short-term state persistence; None disables it.
        memory_store: Long-term memory store, passed through unchanged to the SQL subgraph,
            the memory-tool factory and the compiled graph.
        memory_tools: Optional ``(manage_memory_tool, search_memory_tool)`` pair. When None,
            async memory tools are created.

    Returns:
        CompiledStateGraph: Compiled agent graph ready for execution.
    """
    server_settings = config.get().mcp_servers
    async_mcp_tools = await get_mcp_tools_async(server_settings)

    return _build_graph_core(
        mcp_tools=async_mcp_tools,
        memory_tools=memory_tools,
        memory_store=memory_store,
        checkpointer=checkpointer,
        sync_mode=False,
        catalog=catalog,
    )
openchatbi/catalog/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ """Data catalog management module for OpenChatBI."""
2
+
3
+ from openchatbi.catalog.catalog_loader import (
4
+ DataCatalogLoader,
5
+ load_catalog_from_data_warehouse,
6
+ )
7
+ from openchatbi.catalog.catalog_store import CatalogStore
8
+ from openchatbi.catalog.factory import create_catalog_store
9
+
10
# Public API of the catalog package. create_catalog_store is imported above as the
# package-level factory, so it is exported alongside the other public names.
__all__ = [
    "CatalogStore",
    "DataCatalogLoader",
    "create_catalog_store",
    "load_catalog_from_data_warehouse",
]
openchatbi/catalog/catalog_loader.py ADDED
@@ -0,0 +1,208 @@
1
+ import logging
2
+ from typing import Any
3
+
4
+ from sqlalchemy import MetaData, inspect
5
+ from sqlalchemy.engine import Engine
6
+
7
+ from .catalog_store import CatalogStore
8
+
9
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
10
+
11
+
12
class DataCatalogLoader:
    """
    The loader to load data catalog from data warehouse metadata and save to catalog store.

    Metadata (table names, columns, comments) is read through a SQLAlchemy inspector.
    It is written with the catalog store's ``save_table_information``,
    ``save_table_sql_examples`` and ``save_table_selection_examples`` methods.
    """

    # Generic column names saved with ``is_common=False``; every other column is saved
    # with ``is_common=True``.
    # NOTE(review): presumably these names mean different things in different tables and
    # so should not share one catalog definition. Confirm against the catalog store.
    _NON_COMMON_COLUMN_NAMES = ("id", "name", "type", "status")

    def __init__(self, engine: Engine, include_tables: list[str] | None = None):
        """
        Initialize catalog loader.

        Args:
            engine (Engine): SQLAlchemy engine instance
            include_tables (Optional[List[str]]): Table names to include; None (or empty) for all
        """
        self.engine = engine
        self.include_tables = include_tables
        self.metadata = MetaData()  # not referenced by this class's methods
        self.inspector = inspect(engine)

    @staticmethod
    def _build_column_info(column: dict[str, Any]) -> dict[str, Any]:
        """Convert one reflected column dict (from ``Inspector.get_columns``) into a catalog column entry."""
        default = column.get("default")
        return {
            "column_name": column["name"],
            "display_name": "",
            "alias": "",
            "type": str(column["type"]),
            "category": "",
            "tag": "",
            "description": column.get("comment", "") or "",
            "dimension_table": "",
            "default": str(default) if default is not None else "",
            # Compare the column *name*. Comparing the reflected dict itself against
            # these strings is always True, which marked every column as common.
            "is_common": column["name"] not in DataCatalogLoader._NON_COMMON_COLUMN_NAMES,
        }

    def _get_table_comment(self, table_name: str) -> str:
        """Return the table comment, or "" when it is missing or the dialect cannot provide it."""
        try:
            comment_info = self.inspector.get_table_comment(table_name)
        except Exception:
            # Some databases don't support table comments
            return ""
        # A table without a comment can be reported as {"text": None}; normalise that to "".
        return (comment_info or {}).get("text") or ""

    def get_tables_and_columns(self) -> dict[str, list[dict[str, Any]]]:
        """
        Extract table and column metadata including comments using SQLAlchemy inspector.

        A table whose columns cannot be read is logged and skipped.

        Returns:
            Dict[str, List[Dict[str, Any]]]: Table name -> list of catalog column entries;
            an empty dict if the table list itself cannot be read.
        """
        try:
            tables_columns = {}

            # Get all table names
            table_names = self.inspector.get_table_names()

            # Filter to specific tables if configured
            if self.include_tables:
                table_names = [name for name in table_names if name in self.include_tables]

            logger.info(f"Found {len(table_names)} tables to process")

            for table_name in table_names:
                try:
                    column_list = [self._build_column_info(column) for column in self.inspector.get_columns(table_name)]
                    tables_columns[table_name] = column_list
                    logger.debug(f"Processed table {table_name} with {len(column_list)} columns")
                except Exception as e:
                    logger.error(f"Failed to process table {table_name}: {e}")

            logger.info(f"Successfully processed {len(tables_columns)} tables")
            return tables_columns

        except Exception as e:
            logger.error(f"Failed to get tables and columns from data warehouse: {e}")
            return {}

    def get_table_indexes(self, table_name: str) -> list[dict[str, Any]]:
        """
        Get index information for a specific table.

        Args:
            table_name (str): Name of the table

        Returns:
            List[Dict[str, Any]]: List of index information; empty (with a warning logged) on failure
        """
        try:
            indexes = self.inspector.get_indexes(table_name)
            return indexes
        except Exception as e:
            logger.warning(f"Failed to get indexes for table {table_name}: {e}")
            return []

    def get_foreign_keys(self, table_name: str) -> list[dict[str, Any]]:
        """
        Get foreign key information for a specific table.

        Args:
            table_name (str): Name of the table

        Returns:
            List[Dict[str, Any]]: List of foreign key information; empty (with a warning logged) on failure
        """
        try:
            foreign_keys = self.inspector.get_foreign_keys(table_name)
            return foreign_keys
        except Exception as e:
            logger.warning(f"Failed to get foreign keys for table {table_name}: {e}")
            return []

    def save_to_catalog_store(
        self, catalog_store: CatalogStore, database_name: str | None = None, update: bool = False
    ) -> bool:
        """
        Extract warehouse metadata and save to catalog store.

        For each table this saves the table information (description = table comment)
        and placeholder SQL examples. It then saves one placeholder table-selection example.

        Args:
            catalog_store (CatalogStore): Target catalog store to load data to
            database_name (Optional[str]): Database name in catalog, defaults to 'default'
            update (bool): Intended to sync an existing catalog store with the data warehouse.
                NOTE(review): currently not used by this method.

        Returns:
            bool: True if every table was saved (or no tables were found), False otherwise
        """
        try:
            if database_name is None:
                database_name = "default"

            # Get tables and columns from data warehouse
            tables_columns = self.get_tables_and_columns()

            if not tables_columns:
                logger.warning("No tables found in data warehouse")
                return True

            # Import each table
            success_count = 0
            total_count = len(tables_columns)

            for table_name, columns in tables_columns.items():
                try:
                    table_info = {
                        "description": self._get_table_comment(table_name),
                        "selection_rule": "",
                        "sql_rule": "",
                    }
                    if catalog_store.save_table_information(table_name, table_info, columns, database_name):
                        success_count += 1
                        logger.info(f"Successfully loaded table: {database_name}.{table_name}")
                    else:
                        logger.error(f"Failed to load table: {database_name}.{table_name}")

                    # init null SQL examples
                    catalog_store.save_table_sql_examples(
                        table_name, [{"question": "null", "answer": "null"}], database_name
                    )

                except Exception as e:
                    logger.error(f"Error loading table {table_name}: {e}")

            # init empty table selection examples
            catalog_store.save_table_selection_examples([("", [])])

            logger.info(f"Load completed: {success_count}/{total_count} tables loaded successfully")
            return success_count == total_count

        except Exception as e:
            logger.error(f"Failed to load data warehouse to catalog store: {e}")
            return False
182
+
183
+
184
def load_catalog_from_data_warehouse(catalog_store: CatalogStore) -> bool:
    """
    Load catalog data from data warehouse using SQLAlchemy based on data warehouse config (URI)

    Main entry point for catalog loading. Reads ``uri``, ``include_tables`` and
    ``database_name`` (default ``"default"``) from the store's data warehouse config.
    It then uses the store's SQL engine to reflect metadata into the store.

    Args:
        catalog_store (CatalogStore): Target catalog store

    Returns:
        bool: True if load was successful, False otherwise (errors are logged, not raised)
    """
    # Bound before the try so the error log below still works when reading the config
    # itself fails; otherwise the handler would raise UnboundLocalError.
    database_uri = None
    try:
        data_warehouse_config = catalog_store.get_data_warehouse_config()
        database_uri = data_warehouse_config.get("uri")
        include_tables = data_warehouse_config.get("include_tables")
        database_name = data_warehouse_config.get("database_name", "default")
        engine = catalog_store.get_sql_engine()

        loader = DataCatalogLoader(engine, include_tables)
        return loader.save_to_catalog_store(catalog_store, database_name)

    except Exception as e:
        # NOTE(review): the URI may embed credentials (user:password@host); consider
        # redacting it before logging.
        logger.error(f"Failed to import catalog from data warehouse URI {database_uri}: {e}")
        return False