dao-ai 0.0.34__py3-none-any.whl → 0.0.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/config.py +38 -4
- dao_ai/genie/__init__.py +59 -0
- dao_ai/genie/cache/__init__.py +44 -0
- dao_ai/genie/cache/base.py +122 -0
- dao_ai/genie/cache/lru.py +306 -0
- dao_ai/genie/cache/semantic.py +638 -0
- dao_ai/providers/databricks.py +14 -7
- dao_ai/tools/__init__.py +3 -0
- dao_ai/tools/genie/__init__.py +236 -0
- dao_ai/tools/genie.py +65 -15
- dao_ai/tools/unity_catalog.py +5 -2
- dao_ai/tools/vector_search.py +4 -2
- dao_ai/utils.py +27 -3
- dao_ai-0.0.36.dist-info/METADATA +951 -0
- {dao_ai-0.0.34.dist-info → dao_ai-0.0.36.dist-info}/RECORD +18 -12
- dao_ai-0.0.34.dist-info/METADATA +0 -1169
- {dao_ai-0.0.34.dist-info → dao_ai-0.0.36.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.34.dist-info → dao_ai-0.0.36.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.34.dist-info → dao_ai-0.0.36.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/__init__.py
CHANGED
@@ -1,3 +1,4 @@
+from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
 from dao_ai.hooks.core import create_hooks
 from dao_ai.tools.agent import create_agent_endpoint_tool
 from dao_ai.tools.core import (
@@ -35,7 +36,9 @@ __all__ = [
     "current_time_tool",
     "format_time_tool",
     "is_business_hours_tool",
+    "LRUCacheService",
     "search_tool",
+    "SemanticCacheService",
     "time_difference_tool",
     "time_in_timezone_tool",
     "time_until_tool",
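Note: with this change the cache services are re-exported at the top level of dao_ai.tools, so downstream code can pull them in without reaching into the cache subpackage. A minimal sketch of the new import path:

    from dao_ai.tools import LRUCacheService, SemanticCacheService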
dao_ai/tools/genie/__init__.py
ADDED
@@ -0,0 +1,236 @@
+"""
+Genie tools for natural language queries to databases.
+
+This package provides tools for interacting with Databricks Genie to translate
+natural language questions into SQL queries.
+
+Main exports:
+- create_genie_tool: Factory function to create a Genie tool with optional caching
+
+Cache implementations are available in the genie cache package:
+- dao_ai.genie.cache.lru: LRU (Least Recently Used) cache
+- dao_ai.genie.cache.semantic: Semantic similarity cache using pg_vector
+"""
+
+import json
+import os
+from textwrap import dedent
+from typing import Annotated, Any, Callable
+
+import pandas as pd
+from databricks_ai_bridge.genie import Genie, GenieResponse
+from langchain.tools import tool
+from langchain_core.messages import ToolMessage
+from langchain_core.tools import InjectedToolCallId
+from langgraph.prebuilt import InjectedState
+from langgraph.types import Command
+from loguru import logger
+from pydantic import BaseModel
+
+from dao_ai.config import (
+    AnyVariable,
+    CompositeVariableModel,
+    GenieLRUCacheParametersModel,
+    GenieRoomModel,
+    GenieSemanticCacheParametersModel,
+    value_of,
+)
+from dao_ai.genie import GenieService
+from dao_ai.genie.cache import (
+    CacheResult,
+    GenieServiceBase,
+    LRUCacheService,
+    SemanticCacheService,
+    SQLCacheEntry,
+)
+
+
+class GenieToolInput(BaseModel):
+    """Input schema for Genie tool - only includes user-facing parameters."""
+
+    question: str
+
+
+def _response_to_json(response: GenieResponse) -> str:
+    """Convert GenieResponse to JSON string, handling DataFrame results."""
+    # Convert result to string if it's a DataFrame
+    result: str | pd.DataFrame = response.result
+    if isinstance(result, pd.DataFrame):
+        result = result.to_markdown()
+
+    data: dict[str, Any] = {
+        "result": result,
+        "query": response.query,
+        "description": response.description,
+        "conversation_id": response.conversation_id,
+    }
+    return json.dumps(data)
+
+
+def create_genie_tool(
+    genie_room: GenieRoomModel | dict[str, Any],
+    name: str | None = None,
+    description: str | None = None,
+    persist_conversation: bool = True,
+    truncate_results: bool = False,
+    lru_cache_parameters: GenieLRUCacheParametersModel | dict[str, Any] | None = None,
+    semantic_cache_parameters: GenieSemanticCacheParametersModel
+    | dict[str, Any]
+    | None = None,
+) -> Callable[..., Command]:
+    """
+    Create a tool for interacting with Databricks Genie for natural language queries to databases.
+
+    This factory function generates a tool that leverages Databricks Genie to translate natural
+    language questions into SQL queries and execute them against retail databases. This enables
+    answering questions about inventory, sales, and other structured retail data.
+
+    Args:
+        genie_room: GenieRoomModel or dict containing Genie configuration
+        name: Optional custom name for the tool. If None, uses default "genie_tool"
+        description: Optional custom description for the tool. If None, uses default description
+        persist_conversation: Whether to persist conversation IDs across tool calls for
+            multi-turn conversations within the same Genie space
+        truncate_results: Whether to truncate large query results to fit token limits
+        lru_cache_parameters: Optional LRU cache configuration for SQL query caching
+        semantic_cache_parameters: Optional semantic cache configuration using pg_vector
+            for similarity-based query matching
+
+    Returns:
+        A LangGraph tool that processes natural language queries through Genie
+    """
+    logger.debug("create_genie_tool")
+    logger.debug(f"genie_room type: {type(genie_room)}")
+    logger.debug(f"genie_room: {genie_room}")
+    logger.debug(f"persist_conversation: {persist_conversation}")
+    logger.debug(f"truncate_results: {truncate_results}")
+    logger.debug(f"name: {name}")
+    logger.debug(f"description: {description}")
+    logger.debug(f"genie_room: {genie_room}")
+    logger.debug(f"persist_conversation: {persist_conversation}")
+    logger.debug(f"truncate_results: {truncate_results}")
+    logger.debug(f"lru_cache_parameters: {lru_cache_parameters}")
+    logger.debug(f"semantic_cache_parameters: {semantic_cache_parameters}")
+
+    if isinstance(genie_room, dict):
+        genie_room = GenieRoomModel(**genie_room)
+
+    if isinstance(lru_cache_parameters, dict):
+        lru_cache_parameters = GenieLRUCacheParametersModel(**lru_cache_parameters)
+
+    if isinstance(semantic_cache_parameters, dict):
+        semantic_cache_parameters = GenieSemanticCacheParametersModel(
+            **semantic_cache_parameters
+        )
+
+    space_id: AnyVariable = genie_room.space_id or os.environ.get(
+        "DATABRICKS_GENIE_SPACE_ID"
+    )
+    if isinstance(space_id, dict):
+        space_id = CompositeVariableModel(**space_id)
+    space_id = value_of(space_id)
+
+    default_description: str = dedent("""
+        This tool lets you have a conversation and chat with tabular data about <topic>. You should ask
+        questions about the data and the tool will try to answer them.
+        Please ask simple clear questions that can be answer by sql queries. If you need to do statistics or other forms of testing defer to using another tool.
+        Try to ask for aggregations on the data and ask very simple questions.
+        Prefer to call this tool multiple times rather than asking a complex question.
+        """)
+
+    tool_description: str = (
+        description if description is not None else default_description
+    )
+    tool_name: str = name if name is not None else "genie_tool"
+
+    function_docs = """
+
+Args:
+    question (str): The question to ask to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question.
+
+Returns:
+    GenieResponse: A response object containing the conversation ID and result from Genie."""
+    tool_description = tool_description + function_docs
+
+    genie: Genie = Genie(
+        space_id=space_id,
+        client=genie_room.workspace_client,
+        truncate_results=truncate_results,
+    )
+
+    genie_service: GenieServiceBase = GenieService(genie)
+
+    # Wrap with semantic cache first (checked second due to decorator pattern)
+    if semantic_cache_parameters is not None:
+        genie_service = SemanticCacheService(
+            impl=genie_service,
+            parameters=semantic_cache_parameters,
+            genie_space_id=space_id,
+        )
+
+    # Wrap with LRU cache last (checked first - fast O(1) exact match)
+    if lru_cache_parameters is not None:
+        genie_service = LRUCacheService(
+            impl=genie_service,
+            parameters=lru_cache_parameters,
+        )
+
+    @tool(
+        name_or_callable=tool_name,
+        description=tool_description,
+    )
+    def genie_tool(
+        question: Annotated[str, "The question to ask Genie about your data"],
+        state: Annotated[dict, InjectedState],
+        tool_call_id: Annotated[str, InjectedToolCallId],
+    ) -> Command:
+        """Process a natural language question through Databricks Genie."""
+        # Get existing conversation mapping and retrieve conversation ID for this space
+        conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
+        existing_conversation_id: str | None = conversation_ids.get(space_id)
+        logger.debug(
+            f"Existing conversation ID for space {space_id}: {existing_conversation_id}"
+        )
+
+        response: GenieResponse = genie_service.ask_question(
+            question, conversation_id=existing_conversation_id
+        )
+
+        current_conversation_id: str = response.conversation_id
+        logger.debug(
+            f"Current conversation ID for space {space_id}: {current_conversation_id}"
+        )
+
+        # Update the conversation mapping with the new conversation ID for this space
+
+        update: dict[str, Any] = {
+            "messages": [
+                ToolMessage(_response_to_json(response), tool_call_id=tool_call_id)
+            ],
+        }
+
+        if persist_conversation:
+            updated_conversation_ids: dict[str, str] = conversation_ids.copy()
+            updated_conversation_ids[space_id] = current_conversation_id
+            update["genie_conversation_ids"] = updated_conversation_ids
+
+        return Command(update=update)
+
+    return genie_tool
+
+
+# Re-export cache types for convenience
+__all__ = [
+    # Main tool
+    "create_genie_tool",
+    # Input types
+    "GenieToolInput",
+    # Service base classes
+    "GenieService",
+    "GenieServiceBase",
+    # Cache types (from cache subpackage)
+    "CacheResult",
+    "LRUCacheService",
+    "SemanticCacheService",
+    "SQLCacheEntry",
+]
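The factory above accepts plain dicts for every model argument, so a caller can enable both cache layers without importing the config classes directly. A hedged usage sketch (the cache parameter keys shown are hypothetical placeholders; the real fields live on GenieLRUCacheParametersModel and GenieSemanticCacheParametersModel):

    from dao_ai.tools.genie import create_genie_tool

    # space_id may also come from the DATABRICKS_GENIE_SPACE_ID env var
    genie_tool = create_genie_tool(
        genie_room={"space_id": "your-genie-space-id"},
        name="sales_genie",
        persist_conversation=True,
        lru_cache_parameters={"capacity": 128},                    # hypothetical key
        semantic_cache_parameters={"similarity_threshold": 0.95},  # hypothetical key
    )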
dao_ai/tools/genie.py
CHANGED
@@ -1,3 +1,14 @@
+"""
+Genie tool for natural language queries to databases.
+
+This module provides the tool factory for creating LangGraph tools that
+interact with Databricks Genie.
+
+For the core Genie service and cache implementations, see:
+- dao_ai.genie: GenieService, GenieServiceBase
+- dao_ai.genie.cache: LRUCacheService, SemanticCacheService
+"""
+
 import json
 import os
 from textwrap import dedent
@@ -11,17 +22,24 @@ from langchain_core.tools import InjectedToolCallId
 from langgraph.prebuilt import InjectedState
 from langgraph.types import Command
 from loguru import logger
-from pydantic import BaseModel
+from pydantic import BaseModel
 
-from dao_ai.config import
+from dao_ai.config import (
+    AnyVariable,
+    CompositeVariableModel,
+    GenieLRUCacheParametersModel,
+    GenieRoomModel,
+    GenieSemanticCacheParametersModel,
+    value_of,
+)
+from dao_ai.genie import GenieService, GenieServiceBase
+from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
 
 
 class GenieToolInput(BaseModel):
-    """Input schema for
+    """Input schema for Genie tool - only includes user-facing parameters."""
 
-    question: str
-        description="The question to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question."
-    )
+    question: str
 
 
 def _response_to_json(response: GenieResponse) -> str:
@@ -46,6 +64,10 @@ def create_genie_tool(
     description: str | None = None,
     persist_conversation: bool = True,
     truncate_results: bool = False,
+    lru_cache_parameters: GenieLRUCacheParametersModel | dict[str, Any] | None = None,
+    semantic_cache_parameters: GenieSemanticCacheParametersModel
+    | dict[str, Any]
+    | None = None,
 ) -> Callable[..., Command]:
     """
     Create a tool for interacting with Databricks Genie for natural language queries to databases.
@@ -61,6 +83,9 @@ def create_genie_tool(
         persist_conversation: Whether to persist conversation IDs across tool calls for
             multi-turn conversations within the same Genie space
         truncate_results: Whether to truncate large query results to fit token limits
+        lru_cache_parameters: Optional LRU cache configuration for SQL query caching
+        semantic_cache_parameters: Optional semantic cache configuration using pg_vector
+            for similarity-based query matching
 
     Returns:
         A LangGraph tool that processes natural language queries through Genie
@@ -75,10 +100,20 @@ def create_genie_tool(
     logger.debug(f"genie_room: {genie_room}")
     logger.debug(f"persist_conversation: {persist_conversation}")
     logger.debug(f"truncate_results: {truncate_results}")
+    logger.debug(f"lru_cache_parameters: {lru_cache_parameters}")
+    logger.debug(f"semantic_cache_parameters: {semantic_cache_parameters}")
 
     if isinstance(genie_room, dict):
         genie_room = GenieRoomModel(**genie_room)
 
+    if isinstance(lru_cache_parameters, dict):
+        lru_cache_parameters = GenieLRUCacheParametersModel(**lru_cache_parameters)
+
+    if isinstance(semantic_cache_parameters, dict):
+        semantic_cache_parameters = GenieSemanticCacheParametersModel(
+            **semantic_cache_parameters
+        )
+
     space_id: AnyVariable = genie_room.space_id or os.environ.get(
         "DATABRICKS_GENIE_SPACE_ID"
     )
@@ -108,6 +143,29 @@ Returns:
     GenieResponse: A response object containing the conversation ID and result from Genie."""
     tool_description = tool_description + function_docs
 
+    genie: Genie = Genie(
+        space_id=space_id,
+        client=genie_room.workspace_client,
+        truncate_results=truncate_results,
+    )
+
+    genie_service: GenieServiceBase = GenieService(genie)
+
+    # Wrap with semantic cache first (checked second due to decorator pattern)
+    if semantic_cache_parameters is not None:
+        genie_service = SemanticCacheService(
+            impl=genie_service,
+            parameters=semantic_cache_parameters,
+            genie_space_id=space_id,
+        ).initialize()  # Eagerly initialize to fail fast and create table
+
+    # Wrap with LRU cache last (checked first - fast O(1) exact match)
+    if lru_cache_parameters is not None:
+        genie_service = LRUCacheService(
+            impl=genie_service,
+            parameters=lru_cache_parameters,
+        )
+
     @tool(
         name_or_callable=tool_name,
         description=tool_description,
@@ -117,12 +175,6 @@ GenieResponse: A response object containing the conversation ID and result from
         state: Annotated[dict, InjectedState],
         tool_call_id: Annotated[str, InjectedToolCallId],
     ) -> Command:
-        genie: Genie = Genie(
-            space_id=space_id,
-            client=genie_room.workspace_client,
-            truncate_results=truncate_results,
-        )
-
         """Process a natural language question through Databricks Genie."""
         # Get existing conversation mapping and retrieve conversation ID for this space
         conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
@@ -131,7 +183,7 @@
             f"Existing conversation ID for space {space_id}: {existing_conversation_id}"
         )
 
-        response: GenieResponse =
+        response: GenieResponse = genie_service.ask_question(
            question, conversation_id=existing_conversation_id
         )
 
@@ -153,8 +205,6 @@
             updated_conversation_ids[space_id] = current_conversation_id
             update["genie_conversation_ids"] = updated_conversation_ids
 
-        logger.debug(f"State update: {update}")
-
         return Command(update=update)
 
     return genie_tool
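Because each cache wraps the previous service through its impl argument, the outermost wrapper answers first: LRU exact match, then semantic similarity, then the real Genie call. A sketch of the chain the factory builds when both caches are configured (argument values elided):

    service = GenieService(genie)            # innermost: real Genie call
    service = SemanticCacheService(          # middle: pg_vector similarity lookup
        impl=service,
        parameters=semantic_cache_parameters,
        genie_space_id=space_id,
    ).initialize()                           # fail fast and create the cache table up front
    service = LRUCacheService(               # outermost: O(1) exact match, checked first
        impl=service,
        parameters=lru_cache_parameters,
    )

    response = service.ask_question(question, conversation_id=None)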
dao_ai/tools/unity_catalog.py
CHANGED
@@ -15,6 +15,7 @@ from dao_ai.config import (
     value_of,
 )
 from dao_ai.tools.human_in_the_loop import as_human_in_the_loop
+from dao_ai.utils import normalize_host
 
 
 def create_uc_tools(
@@ -299,9 +300,11 @@ def with_partial_args(
         if "client_secret" not in resolved_args:
             resolved_args["client_secret"] = value_of(sp.client_secret)
 
-    # Normalize host/workspace_host - accept either key
+    # Normalize host/workspace_host - accept either key, ensure https:// scheme
     if "workspace_host" in resolved_args and "host" not in resolved_args:
-        resolved_args["host"] = resolved_args.pop("workspace_host")
+        resolved_args["host"] = normalize_host(resolved_args.pop("workspace_host"))
+    elif "host" in resolved_args:
+        resolved_args["host"] = normalize_host(resolved_args["host"])
 
     # Default host from WorkspaceClient if not provided
     if "host" not in resolved_args:
dao_ai/tools/vector_search.py
CHANGED
@@ -20,6 +20,7 @@ from dao_ai.config import (
     RetrieverModel,
     VectorStoreModel,
 )
+from dao_ai.utils import normalize_host
 
 
 def create_vector_search_tool(
@@ -108,8 +109,9 @@ def create_vector_search_tool(
     # The workspace_client parameter in DatabricksVectorSearch is only used to detect
     # model serving mode - it doesn't pass credentials to VectorSearchClient.
     client_args: dict[str, Any] = {}
-
-
+    databricks_host = normalize_host(os.environ.get("DATABRICKS_HOST"))
+    if databricks_host:
+        client_args["workspace_url"] = databricks_host
     if os.environ.get("DATABRICKS_TOKEN"):
         client_args["personal_access_token"] = os.environ.get("DATABRICKS_TOKEN")
     if os.environ.get("DATABRICKS_CLIENT_ID"):
dao_ai/utils.py
CHANGED
@@ -38,6 +38,30 @@ def normalize_name(name: str) -> str:
     return normalized.strip("_")
 
 
+def normalize_host(host: str | None) -> str | None:
+    """Ensure host URL has https:// scheme.
+
+    The DATABRICKS_HOST environment variable should always include the https://
+    scheme, but some environments (e.g., Databricks Apps infrastructure) may
+    provide the host without it. This function normalizes the host to ensure
+    it has the proper scheme.
+
+    Args:
+        host: The host URL, with or without scheme
+
+    Returns:
+        The host URL with https:// scheme, or None if host is None/empty
+    """
+    if not host:
+        return None
+    host = host.strip()
+    if not host:
+        return None
+    if not host.startswith("http://") and not host.startswith("https://"):
+        return f"https://{host}"
+    return host
+
+
 def get_default_databricks_host() -> str | None:
     """Get the default Databricks workspace host.
 
@@ -46,19 +70,19 @@ def get_default_databricks_host() -> str | None:
     2. WorkspaceClient ambient authentication (e.g., from ~/.databrickscfg)
 
     Returns:
-        The Databricks workspace host URL, or None if not available.
+        The Databricks workspace host URL (with https:// scheme), or None if not available.
     """
     # Try environment variable first
     host: str | None = os.environ.get("DATABRICKS_HOST")
     if host:
-        return host
+        return normalize_host(host)
 
     # Fall back to WorkspaceClient
     try:
         from databricks.sdk import WorkspaceClient
 
         w: WorkspaceClient = WorkspaceClient()
-        return w.config.host
+        return normalize_host(w.config.host)
     except Exception:
         logger.debug("Could not get default Databricks host from WorkspaceClient")
         return None
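normalize_host is a pure string transform, so its contract can be pinned down with a few examples derived from the implementation above:

    assert normalize_host(None) is None
    assert normalize_host("   ") is None
    assert normalize_host("example.cloud.databricks.com") == "https://example.cloud.databricks.com"
    assert normalize_host("https://example.cloud.databricks.com") == "https://example.cloud.databricks.com"
    assert normalize_host("http://localhost:8080") == "http://localhost:8080"  # existing scheme preserved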