dao-ai 0.0.34__py3-none-any.whl → 0.0.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/tools/__init__.py CHANGED
@@ -1,3 +1,4 @@
+from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
 from dao_ai.hooks.core import create_hooks
 from dao_ai.tools.agent import create_agent_endpoint_tool
 from dao_ai.tools.core import (
@@ -35,7 +36,9 @@ __all__ = [
     "current_time_tool",
     "format_time_tool",
     "is_business_hours_tool",
+    "LRUCacheService",
     "search_tool",
+    "SemanticCacheService",
     "time_difference_tool",
     "time_in_timezone_tool",
     "time_until_tool",
@@ -0,0 +1,236 @@
+"""
+Genie tools for natural language queries to databases.
+
+This package provides tools for interacting with Databricks Genie to translate
+natural language questions into SQL queries.
+
+Main exports:
+- create_genie_tool: Factory function to create a Genie tool with optional caching
+
+Cache implementations are available in the genie cache package:
+- dao_ai.genie.cache.lru: LRU (Least Recently Used) cache
+- dao_ai.genie.cache.semantic: Semantic similarity cache using pg_vector
+"""
+
+import json
+import os
+from textwrap import dedent
+from typing import Annotated, Any, Callable
+
+import pandas as pd
+from databricks_ai_bridge.genie import Genie, GenieResponse
+from langchain.tools import tool
+from langchain_core.messages import ToolMessage
+from langchain_core.tools import InjectedToolCallId
+from langgraph.prebuilt import InjectedState
+from langgraph.types import Command
+from loguru import logger
+from pydantic import BaseModel
+
+from dao_ai.config import (
+    AnyVariable,
+    CompositeVariableModel,
+    GenieLRUCacheParametersModel,
+    GenieRoomModel,
+    GenieSemanticCacheParametersModel,
+    value_of,
+)
+from dao_ai.genie import GenieService
+from dao_ai.genie.cache import (
+    CacheResult,
+    GenieServiceBase,
+    LRUCacheService,
+    SemanticCacheService,
+    SQLCacheEntry,
+)
+
+
+class GenieToolInput(BaseModel):
+    """Input schema for Genie tool - only includes user-facing parameters."""
+
+    question: str
+
+
+def _response_to_json(response: GenieResponse) -> str:
+    """Convert GenieResponse to JSON string, handling DataFrame results."""
+    # Convert result to string if it's a DataFrame
+    result: str | pd.DataFrame = response.result
+    if isinstance(result, pd.DataFrame):
+        result = result.to_markdown()
+
+    data: dict[str, Any] = {
+        "result": result,
+        "query": response.query,
+        "description": response.description,
+        "conversation_id": response.conversation_id,
+    }
+    return json.dumps(data)
+
+
+def create_genie_tool(
+    genie_room: GenieRoomModel | dict[str, Any],
+    name: str | None = None,
+    description: str | None = None,
+    persist_conversation: bool = True,
+    truncate_results: bool = False,
+    lru_cache_parameters: GenieLRUCacheParametersModel | dict[str, Any] | None = None,
+    semantic_cache_parameters: GenieSemanticCacheParametersModel
+    | dict[str, Any]
+    | None = None,
+) -> Callable[..., Command]:
+    """
+    Create a tool for interacting with Databricks Genie for natural language queries to databases.
+
+    This factory function generates a tool that leverages Databricks Genie to translate natural
+    language questions into SQL queries and execute them against retail databases. This enables
+    answering questions about inventory, sales, and other structured retail data.
+
+    Args:
+        genie_room: GenieRoomModel or dict containing Genie configuration
+        name: Optional custom name for the tool. If None, uses default "genie_tool"
+        description: Optional custom description for the tool. If None, uses default description
+        persist_conversation: Whether to persist conversation IDs across tool calls for
+            multi-turn conversations within the same Genie space
+        truncate_results: Whether to truncate large query results to fit token limits
+        lru_cache_parameters: Optional LRU cache configuration for SQL query caching
+        semantic_cache_parameters: Optional semantic cache configuration using pg_vector
+            for similarity-based query matching
+
+    Returns:
+        A LangGraph tool that processes natural language queries through Genie
+    """
+    logger.debug("create_genie_tool")
+    logger.debug(f"genie_room type: {type(genie_room)}")
+    logger.debug(f"genie_room: {genie_room}")
+    logger.debug(f"persist_conversation: {persist_conversation}")
+    logger.debug(f"truncate_results: {truncate_results}")
+    logger.debug(f"name: {name}")
+    logger.debug(f"description: {description}")
+    logger.debug(f"lru_cache_parameters: {lru_cache_parameters}")
+    logger.debug(f"semantic_cache_parameters: {semantic_cache_parameters}")
+
+    if isinstance(genie_room, dict):
+        genie_room = GenieRoomModel(**genie_room)
+
+    if isinstance(lru_cache_parameters, dict):
+        lru_cache_parameters = GenieLRUCacheParametersModel(**lru_cache_parameters)
+
+    if isinstance(semantic_cache_parameters, dict):
+        semantic_cache_parameters = GenieSemanticCacheParametersModel(
+            **semantic_cache_parameters
+        )
+
+    space_id: AnyVariable = genie_room.space_id or os.environ.get(
+        "DATABRICKS_GENIE_SPACE_ID"
+    )
+    if isinstance(space_id, dict):
+        space_id = CompositeVariableModel(**space_id)
+    space_id = value_of(space_id)
+
+    default_description: str = dedent("""
+        This tool lets you have a conversation and chat with tabular data about <topic>. You should ask
+        questions about the data and the tool will try to answer them.
+        Please ask simple, clear questions that can be answered by SQL queries. If you need to do statistics or other forms of testing, defer to another tool.
+        Try to ask for aggregations on the data and ask very simple questions.
+        Prefer to call this tool multiple times rather than asking a complex question.
+        """)
+
+    tool_description: str = (
+        description if description is not None else default_description
+    )
+    tool_name: str = name if name is not None else "genie_tool"
+
+    function_docs = """
+
+Args:
+    question (str): The question to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question.
+
+Returns:
+    GenieResponse: A response object containing the conversation ID and result from Genie."""
+    tool_description = tool_description + function_docs
+
+    genie: Genie = Genie(
+        space_id=space_id,
+        client=genie_room.workspace_client,
+        truncate_results=truncate_results,
+    )
+
+    genie_service: GenieServiceBase = GenieService(genie)
+
+    # Wrap with semantic cache first (checked second due to decorator pattern)
+    if semantic_cache_parameters is not None:
+        genie_service = SemanticCacheService(
+            impl=genie_service,
+            parameters=semantic_cache_parameters,
+            genie_space_id=space_id,
+        )
+
+    # Wrap with LRU cache last (checked first - fast O(1) exact match)
+    if lru_cache_parameters is not None:
+        genie_service = LRUCacheService(
+            impl=genie_service,
+            parameters=lru_cache_parameters,
+        )
+
+    @tool(
+        name_or_callable=tool_name,
+        description=tool_description,
+    )
+    def genie_tool(
+        question: Annotated[str, "The question to ask Genie about your data"],
+        state: Annotated[dict, InjectedState],
+        tool_call_id: Annotated[str, InjectedToolCallId],
+    ) -> Command:
+        """Process a natural language question through Databricks Genie."""
+        # Get existing conversation mapping and retrieve conversation ID for this space
+        conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
+        existing_conversation_id: str | None = conversation_ids.get(space_id)
+        logger.debug(
+            f"Existing conversation ID for space {space_id}: {existing_conversation_id}"
+        )
+
+        response: GenieResponse = genie_service.ask_question(
+            question, conversation_id=existing_conversation_id
+        )
+
+        current_conversation_id: str = response.conversation_id
+        logger.debug(
+            f"Current conversation ID for space {space_id}: {current_conversation_id}"
+        )
+
+        # Update the conversation mapping with the new conversation ID for this space
+        update: dict[str, Any] = {
+            "messages": [
+                ToolMessage(_response_to_json(response), tool_call_id=tool_call_id)
+            ],
+        }
+
+        if persist_conversation:
+            updated_conversation_ids: dict[str, str] = conversation_ids.copy()
+            updated_conversation_ids[space_id] = current_conversation_id
+            update["genie_conversation_ids"] = updated_conversation_ids
+
+        return Command(update=update)
+
+    return genie_tool
+
+
+# Re-export cache types for convenience
+__all__ = [
+    # Main tool
+    "create_genie_tool",
+    # Input types
+    "GenieToolInput",
+    # Service base classes
+    "GenieService",
+    "GenieServiceBase",
+    # Cache types (from cache subpackage)
+    "CacheResult",
+    "LRUCacheService",
+    "SemanticCacheService",
+    "SQLCacheEntry",
+]
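
For orientation, here is a minimal usage sketch of the new factory with both cache layers enabled. It is not part of the diff: the `GenieRoomModel` field and cache-parameter keys shown (`space_id`, `capacity`, `similarity_threshold`) are assumptions for illustration only; the real fields are defined by the models in `dao_ai.config`, which this diff does not show.

```python
# Hypothetical usage sketch - parameter keys are illustrative, not confirmed by the diff.
from dao_ai.tools.genie import create_genie_tool

genie_tool = create_genie_tool(
    genie_room={"space_id": "your-genie-space-id"},  # coerced via GenieRoomModel(**dict)
    name="sales_genie",
    persist_conversation=True,
    lru_cache_parameters={"capacity": 128},                     # hypothetical field name
    semantic_cache_parameters={"similarity_threshold": 0.95},   # hypothetical field name
)
```

The result is a LangGraph tool; it can be added to an agent's tool list like any other tool.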
dao_ai/tools/genie.py CHANGED
@@ -1,3 +1,14 @@
+"""
+Genie tool for natural language queries to databases.
+
+This module provides the tool factory for creating LangGraph tools that
+interact with Databricks Genie.
+
+For the core Genie service and cache implementations, see:
+- dao_ai.genie: GenieService, GenieServiceBase
+- dao_ai.genie.cache: LRUCacheService, SemanticCacheService
+"""
+
 import json
 import os
 from textwrap import dedent
@@ -11,17 +22,24 @@ from langchain_core.tools import InjectedToolCallId
 from langgraph.prebuilt import InjectedState
 from langgraph.types import Command
 from loguru import logger
-from pydantic import BaseModel, Field
+from pydantic import BaseModel

-from dao_ai.config import AnyVariable, CompositeVariableModel, GenieRoomModel, value_of
+from dao_ai.config import (
+    AnyVariable,
+    CompositeVariableModel,
+    GenieLRUCacheParametersModel,
+    GenieRoomModel,
+    GenieSemanticCacheParametersModel,
+    value_of,
+)
+from dao_ai.genie import GenieService, GenieServiceBase
+from dao_ai.genie.cache import LRUCacheService, SemanticCacheService


 class GenieToolInput(BaseModel):
-    """Input schema for the Genie tool."""
+    """Input schema for Genie tool - only includes user-facing parameters."""

-    question: str = Field(
-        description="The question to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question."
-    )
+    question: str


 def _response_to_json(response: GenieResponse) -> str:
@@ -46,6 +64,10 @@ def create_genie_tool(
     description: str | None = None,
     persist_conversation: bool = True,
     truncate_results: bool = False,
+    lru_cache_parameters: GenieLRUCacheParametersModel | dict[str, Any] | None = None,
+    semantic_cache_parameters: GenieSemanticCacheParametersModel
+    | dict[str, Any]
+    | None = None,
 ) -> Callable[..., Command]:
     """
     Create a tool for interacting with Databricks Genie for natural language queries to databases.
@@ -61,6 +83,9 @@ def create_genie_tool(
         persist_conversation: Whether to persist conversation IDs across tool calls for
             multi-turn conversations within the same Genie space
         truncate_results: Whether to truncate large query results to fit token limits
+        lru_cache_parameters: Optional LRU cache configuration for SQL query caching
+        semantic_cache_parameters: Optional semantic cache configuration using pg_vector
+            for similarity-based query matching

     Returns:
         A LangGraph tool that processes natural language queries through Genie
@@ -75,10 +100,20 @@
     logger.debug(f"genie_room: {genie_room}")
     logger.debug(f"persist_conversation: {persist_conversation}")
     logger.debug(f"truncate_results: {truncate_results}")
+    logger.debug(f"lru_cache_parameters: {lru_cache_parameters}")
+    logger.debug(f"semantic_cache_parameters: {semantic_cache_parameters}")

     if isinstance(genie_room, dict):
         genie_room = GenieRoomModel(**genie_room)

+    if isinstance(lru_cache_parameters, dict):
+        lru_cache_parameters = GenieLRUCacheParametersModel(**lru_cache_parameters)
+
+    if isinstance(semantic_cache_parameters, dict):
+        semantic_cache_parameters = GenieSemanticCacheParametersModel(
+            **semantic_cache_parameters
+        )
+
     space_id: AnyVariable = genie_room.space_id or os.environ.get(
         "DATABRICKS_GENIE_SPACE_ID"
     )
@@ -108,6 +143,29 @@ Returns:
     GenieResponse: A response object containing the conversation ID and result from Genie."""
     tool_description = tool_description + function_docs

+    genie: Genie = Genie(
+        space_id=space_id,
+        client=genie_room.workspace_client,
+        truncate_results=truncate_results,
+    )
+
+    genie_service: GenieServiceBase = GenieService(genie)
+
+    # Wrap with semantic cache first (checked second due to decorator pattern)
+    if semantic_cache_parameters is not None:
+        genie_service = SemanticCacheService(
+            impl=genie_service,
+            parameters=semantic_cache_parameters,
+            genie_space_id=space_id,
+        ).initialize()  # Eagerly initialize to fail fast and create table
+
+    # Wrap with LRU cache last (checked first - fast O(1) exact match)
+    if lru_cache_parameters is not None:
+        genie_service = LRUCacheService(
+            impl=genie_service,
+            parameters=lru_cache_parameters,
+        )
+
     @tool(
         name_or_callable=tool_name,
         description=tool_description,
@@ -117,12 +175,6 @@ GenieResponse: A response object containing the conversation ID and result from
         state: Annotated[dict, InjectedState],
         tool_call_id: Annotated[str, InjectedToolCallId],
     ) -> Command:
-        genie: Genie = Genie(
-            space_id=space_id,
-            client=genie_room.workspace_client,
-            truncate_results=truncate_results,
-        )
-
         """Process a natural language question through Databricks Genie."""
         # Get existing conversation mapping and retrieve conversation ID for this space
         conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
@@ -131,7 +183,7 @@ GenieResponse: A response object containing the conversation ID and result from
             f"Existing conversation ID for space {space_id}: {existing_conversation_id}"
         )

-        response: GenieResponse = genie.ask_question(
+        response: GenieResponse = genie_service.ask_question(
             question, conversation_id=existing_conversation_id
         )

@@ -153,8 +205,6 @@ GenieResponse: A response object containing the conversation ID and result from
             updated_conversation_ids[space_id] = current_conversation_id
             update["genie_conversation_ids"] = updated_conversation_ids

-        logger.debug(f"State update: {update}")
-
         return Command(update=update)

     return genie_tool
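
The two wrapping comments in this hunk describe a decorator chain: `LRUCacheService` wraps `SemanticCacheService`, which wraps the base `GenieService`, so the outermost (LRU) layer is consulted first and Genie only runs when both caches miss. The sketch below illustrates that layering with deliberately simplified stand-in classes; the real services in `dao_ai.genie.cache` have richer interfaces (parameter models, `CacheResult`, pg_vector similarity lookups).

```python
# Simplified stand-ins to show the wrapping order; not the real dao_ai classes.
class FakeGenie:
    def ask_question(self, question, conversation_id=None):
        print("Genie executes:", question)
        return f"rows for {question!r}"


class FakeCache:
    def __init__(self, impl, label):
        self.impl = impl    # the wrapped (inner) service
        self.label = label
        self.store = {}     # exact-match store; the real semantic cache matches by similarity

    def ask_question(self, question, conversation_id=None):
        if question in self.store:
            print(self.label, "hit")
            return self.store[question]
        print(self.label, "miss")
        result = self.impl.ask_question(question, conversation_id)
        self.store[question] = result
        return result


# Same order as the diff: semantic wraps Genie first, then LRU wraps semantic.
service = FakeCache(FakeCache(FakeGenie(), "semantic"), "lru")
service.ask_question("total sales by region")  # lru miss, semantic miss, Genie runs
service.ask_question("total sales by region")  # lru hit; Genie never called
```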
@@ -15,6 +15,7 @@ from dao_ai.config import (
     value_of,
 )
 from dao_ai.tools.human_in_the_loop import as_human_in_the_loop
+from dao_ai.utils import normalize_host


 def create_uc_tools(
@@ -299,9 +300,11 @@ def with_partial_args(
         if "client_secret" not in resolved_args:
             resolved_args["client_secret"] = value_of(sp.client_secret)

-        # Normalize host/workspace_host - accept either key
+        # Normalize host/workspace_host - accept either key, ensure https:// scheme
         if "workspace_host" in resolved_args and "host" not in resolved_args:
-            resolved_args["host"] = resolved_args.pop("workspace_host")
+            resolved_args["host"] = normalize_host(resolved_args.pop("workspace_host"))
+        elif "host" in resolved_args:
+            resolved_args["host"] = normalize_host(resolved_args["host"])

         # Default host from WorkspaceClient if not provided
         if "host" not in resolved_args:
@@ -20,6 +20,7 @@ from dao_ai.config import (
     RetrieverModel,
     VectorStoreModel,
 )
+from dao_ai.utils import normalize_host


 def create_vector_search_tool(
@@ -108,8 +109,9 @@ def create_vector_search_tool(
     # The workspace_client parameter in DatabricksVectorSearch is only used to detect
     # model serving mode - it doesn't pass credentials to VectorSearchClient.
     client_args: dict[str, Any] = {}
-    if os.environ.get("DATABRICKS_HOST"):
-        client_args["workspace_url"] = os.environ.get("DATABRICKS_HOST")
+    databricks_host = normalize_host(os.environ.get("DATABRICKS_HOST"))
+    if databricks_host:
+        client_args["workspace_url"] = databricks_host
     if os.environ.get("DATABRICKS_TOKEN"):
         client_args["personal_access_token"] = os.environ.get("DATABRICKS_TOKEN")
     if os.environ.get("DATABRICKS_CLIENT_ID"):
dao_ai/utils.py CHANGED
@@ -38,6 +38,30 @@ def normalize_name(name: str) -> str:
     return normalized.strip("_")


+def normalize_host(host: str | None) -> str | None:
+    """Ensure host URL has https:// scheme.
+
+    The DATABRICKS_HOST environment variable should always include the https://
+    scheme, but some environments (e.g., Databricks Apps infrastructure) may
+    provide the host without it. This function normalizes the host to ensure
+    it has the proper scheme.
+
+    Args:
+        host: The host URL, with or without scheme
+
+    Returns:
+        The host URL with https:// scheme, or None if host is None/empty
+    """
+    if not host:
+        return None
+    host = host.strip()
+    if not host:
+        return None
+    if not host.startswith("http://") and not host.startswith("https://"):
+        return f"https://{host}"
+    return host
+
+
 def get_default_databricks_host() -> str | None:
     """Get the default Databricks workspace host.

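
The behavior of `normalize_host` follows directly from the implementation above; note that an existing scheme is preserved as-is, so an explicit `http://` is not upgraded to `https://`:

```python
from dao_ai.utils import normalize_host

assert normalize_host(None) is None
assert normalize_host("   ") is None  # whitespace-only -> None
assert normalize_host("my-ws.cloud.databricks.com") == "https://my-ws.cloud.databricks.com"
assert normalize_host("https://my-ws.cloud.databricks.com") == "https://my-ws.cloud.databricks.com"
assert normalize_host("http://localhost:8080") == "http://localhost:8080"  # scheme kept
```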
@@ -46,19 +70,19 @@ def get_default_databricks_host() -> str | None:
     2. WorkspaceClient ambient authentication (e.g., from ~/.databrickscfg)

     Returns:
-        The Databricks workspace host URL, or None if not available.
+        The Databricks workspace host URL (with https:// scheme), or None if not available.
     """
     # Try environment variable first
     host: str | None = os.environ.get("DATABRICKS_HOST")
     if host:
-        return host
+        return normalize_host(host)

     # Fall back to WorkspaceClient
     try:
         from databricks.sdk import WorkspaceClient

         w: WorkspaceClient = WorkspaceClient()
-        return w.config.host
+        return normalize_host(w.config.host)
     except Exception:
         logger.debug("Could not get default Databricks host from WorkspaceClient")
         return None
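
Taken together, `get_default_databricks_host` now always returns a scheme-qualified URL regardless of which source supplied it. A short sketch of the resolution order (the hostname is illustrative):

```python
import os

from dao_ai.utils import get_default_databricks_host

# 1. The environment variable wins, normalized to include the scheme.
os.environ["DATABRICKS_HOST"] = "adb-123.azuredatabricks.net"
print(get_default_databricks_host())  # https://adb-123.azuredatabricks.net

# 2. Without the variable, it falls back to WorkspaceClient ambient auth
#    (e.g., ~/.databrickscfg) and returns None if that also fails.
del os.environ["DATABRICKS_HOST"]
print(get_default_databricks_host())
```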