dao-ai 0.0.35__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/cli.py +195 -30
- dao_ai/config.py +797 -242
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +75 -0
- dao_ai/genie/cache/lru.py +329 -0
- dao_ai/genie/cache/semantic.py +919 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +22 -190
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +23 -5
- dao_ai/memory/databricks.py +389 -0
- dao_ai/memory/postgres.py +2 -2
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +778 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +61 -0
- dao_ai/middleware/guardrails.py +415 -0
- dao_ai/middleware/human_in_the_loop.py +228 -0
- dao_ai/middleware/message_validation.py +554 -0
- dao_ai/middleware/summarization.py +192 -0
- dao_ai/models.py +1177 -108
- dao_ai/nodes.py +118 -161
- dao_ai/optimization.py +664 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +287 -0
- dao_ai/orchestration/supervisor.py +264 -0
- dao_ai/orchestration/swarm.py +226 -0
- dao_ai/prompts.py +126 -29
- dao_ai/providers/databricks.py +126 -381
- dao_ai/state.py +139 -21
- dao_ai/tools/__init__.py +11 -5
- dao_ai/tools/core.py +57 -4
- dao_ai/tools/email.py +280 -0
- dao_ai/tools/genie.py +108 -35
- dao_ai/tools/mcp.py +4 -3
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +4 -12
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +1 -1
- dao_ai/tools/unity_catalog.py +8 -6
- dao_ai/tools/vector_search.py +16 -9
- dao_ai/utils.py +72 -8
- dao_ai-0.1.0.dist-info/METADATA +1878 -0
- dao_ai-0.1.0.dist-info/RECORD +62 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.35.dist-info/METADATA +0 -1169
- dao_ai-0.0.35.dist-info/RECORD +0 -41
- {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/genie.py
CHANGED
```diff
@@ -1,3 +1,14 @@
+"""
+Genie tool for natural language queries to databases.
+
+This module provides the tool factory for creating LangGraph tools that
+interact with Databricks Genie.
+
+For the core Genie service and cache implementations, see:
+- dao_ai.genie: GenieService, GenieServiceBase
+- dao_ai.genie.cache: LRUCacheService, SemanticCacheService
+"""
+
 import json
 import os
 from textwrap import dedent
@@ -5,23 +16,29 @@ from typing import Annotated, Any, Callable
 
 import pandas as pd
 from databricks_ai_bridge.genie import Genie, GenieResponse
-from langchain.tools import tool
+from langchain.tools import ToolRuntime, tool
 from langchain_core.messages import ToolMessage
-from langchain_core.tools import InjectedToolCallId
-from langgraph.prebuilt import InjectedState
 from langgraph.types import Command
 from loguru import logger
-from pydantic import BaseModel
+from pydantic import BaseModel
 
-from dao_ai.config import
+from dao_ai.config import (
+    AnyVariable,
+    CompositeVariableModel,
+    GenieLRUCacheParametersModel,
+    GenieRoomModel,
+    GenieSemanticCacheParametersModel,
+    value_of,
+)
+from dao_ai.genie import GenieService, GenieServiceBase
+from dao_ai.genie.cache import CacheResult, LRUCacheService, SemanticCacheService
+from dao_ai.state import AgentState, Context, SessionState
 
 
 class GenieToolInput(BaseModel):
-    """Input schema for
+    """Input schema for Genie tool - only includes user-facing parameters."""
 
-    question: str
-        description="The question to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question."
-    )
+    question: str
 
 
 def _response_to_json(response: GenieResponse) -> str:
@@ -46,6 +63,10 @@ def create_genie_tool(
     description: str | None = None,
     persist_conversation: bool = True,
     truncate_results: bool = False,
+    lru_cache_parameters: GenieLRUCacheParametersModel | dict[str, Any] | None = None,
+    semantic_cache_parameters: GenieSemanticCacheParametersModel
+    | dict[str, Any]
+    | None = None,
 ) -> Callable[..., Command]:
     """
     Create a tool for interacting with Databricks Genie for natural language queries to databases.
@@ -61,6 +82,9 @@
         persist_conversation: Whether to persist conversation IDs across tool calls for
             multi-turn conversations within the same Genie space
         truncate_results: Whether to truncate large query results to fit token limits
+        lru_cache_parameters: Optional LRU cache configuration for SQL query caching
+        semantic_cache_parameters: Optional semantic cache configuration using pg_vector
+            for similarity-based query matching
 
     Returns:
         A LangGraph tool that processes natural language queries through Genie
@@ -72,13 +96,20 @@
     logger.debug(f"truncate_results: {truncate_results}")
     logger.debug(f"name: {name}")
     logger.debug(f"description: {description}")
-    logger.debug(f"
-    logger.debug(f"
-    logger.debug(f"truncate_results: {truncate_results}")
+    logger.debug(f"lru_cache_parameters: {lru_cache_parameters}")
+    logger.debug(f"semantic_cache_parameters: {semantic_cache_parameters}")
 
     if isinstance(genie_room, dict):
         genie_room = GenieRoomModel(**genie_room)
 
+    if isinstance(lru_cache_parameters, dict):
+        lru_cache_parameters = GenieLRUCacheParametersModel(**lru_cache_parameters)
+
+    if isinstance(semantic_cache_parameters, dict):
+        semantic_cache_parameters = GenieSemanticCacheParametersModel(
+            **semantic_cache_parameters
+        )
+
     space_id: AnyVariable = genie_room.space_id or os.environ.get(
         "DATABRICKS_GENIE_SPACE_ID"
     )
@@ -108,52 +139,94 @@ Returns:
     GenieResponse: A response object containing the conversation ID and result from Genie."""
     tool_description = tool_description + function_docs
 
+    genie: Genie = Genie(
+        space_id=space_id,
+        client=genie_room.workspace_client,
+        truncate_results=truncate_results,
+    )
+
+    genie_service: GenieServiceBase = GenieService(genie)
+
+    # Wrap with semantic cache first (checked second due to decorator pattern)
+    if semantic_cache_parameters is not None:
+        genie_service = SemanticCacheService(
+            impl=genie_service,
+            parameters=semantic_cache_parameters,
+            workspace_client=genie_room.workspace_client,  # Pass workspace client for conversation history
+        ).initialize()  # Eagerly initialize to fail fast and create table
+
+    # Wrap with LRU cache last (checked first - fast O(1) exact match)
+    if lru_cache_parameters is not None:
+        genie_service = LRUCacheService(
+            impl=genie_service,
+            parameters=lru_cache_parameters,
+        )
+
     @tool(
         name_or_callable=tool_name,
         description=tool_description,
     )
     def genie_tool(
         question: Annotated[str, "The question to ask Genie about your data"],
-
-        tool_call_id: Annotated[str, InjectedToolCallId],
+        runtime: ToolRuntime[Context, AgentState],
     ) -> Command:
-
-
-
-
-
+        """Process a natural language question through Databricks Genie.
+
+        Uses ToolRuntime to access state and context in a type-safe way.
+        """
+        # Access state through runtime
+        state: AgentState = runtime.state
+        tool_call_id: str = runtime.tool_call_id
+
+        # Ensure space_id is a string for state keys
+        space_id_str: str = str(space_id)
 
-
-
-
-
+        # Get session state (or create new one)
+        session: SessionState = state.get("session", SessionState())
+
+        # Get existing conversation ID from session
+        existing_conversation_id: str | None = session.genie.get_conversation_id(
+            space_id_str
+        )
         logger.debug(
-            f"Existing conversation ID for space {
+            f"Existing conversation ID for space {space_id_str}: {existing_conversation_id}"
         )
 
-
+        # Call ask_question which always returns CacheResult with cache metadata
+        cache_result: CacheResult = genie_service.ask_question(
            question, conversation_id=existing_conversation_id
        )
+        genie_response: GenieResponse = cache_result.response
+        cache_hit: bool = cache_result.cache_hit
+        cache_key: str | None = cache_result.served_by
 
-        current_conversation_id: str =
+        current_conversation_id: str = genie_response.conversation_id
        logger.debug(
-            f"Current conversation ID for space {
+            f"Current conversation ID for space {space_id_str}: {current_conversation_id}, "
+            f"cache_hit: {cache_hit}, cache_key: {cache_key}"
        )
 
-        # Update
-
+        # Update session state with cache information
+        if persist_conversation:
+            session.genie.update_space(
+                space_id=space_id_str,
+                conversation_id=current_conversation_id,
+                cache_hit=cache_hit,
+                cache_key=cache_key,
+                last_query=question,
+            )
+
+        # Build update dict with response and session
        update: dict[str, Any] = {
            "messages": [
-                ToolMessage(
+                ToolMessage(
+                    _response_to_json(genie_response), tool_call_id=tool_call_id
+                )
            ],
        }
 
        if persist_conversation:
-
-            updated_conversation_ids[space_id] = current_conversation_id
-            update["genie_conversation_ids"] = updated_conversation_ids
-
-        logger.debug(f"State update: {update}")
+            update["session"] = session
 
        return Command(update=update)
```
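The cache layering above composes like nested decorators: the LRU wrapper is outermost and answers exact repeats in O(1), falling through to the semantic (pg_vector) layer, and only then to Genie itself. A minimal sketch of the same stack outside the tool factory, assuming both parameter models accept default construction (the space ID and question are placeholders):

```python
from databricks_ai_bridge.genie import Genie

from dao_ai.config import (
    GenieLRUCacheParametersModel,
    GenieSemanticCacheParametersModel,
)
from dao_ai.genie import GenieService, GenieServiceBase
from dao_ai.genie.cache import CacheResult, LRUCacheService, SemanticCacheService

# Innermost layer: real Genie calls.
service: GenieServiceBase = GenieService(Genie(space_id="YOUR_SPACE_ID"))

# Middle layer: semantic (pg_vector similarity) cache, checked second.
service = SemanticCacheService(
    impl=service,
    parameters=GenieSemanticCacheParametersModel(),  # assumed default-constructible
).initialize()  # fail fast and create the backing table up front

# Outermost layer: exact-match LRU cache, checked first.
service = LRUCacheService(
    impl=service,
    parameters=GenieLRUCacheParametersModel(),  # assumed default-constructible
)

result: CacheResult = service.ask_question("How many orders shipped last week?")
print(result.cache_hit, result.served_by)  # result.response holds the GenieResponse
```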
dao_ai/tools/mcp.py
CHANGED
```diff
@@ -14,7 +14,6 @@ from dao_ai.config import (
     McpFunctionModel,
     TransportType,
 )
-from dao_ai.tools.human_in_the_loop import as_human_in_the_loop
 
 
 def create_mcp_tools(
@@ -95,7 +94,8 @@ def create_mcp_tools(
             logger.error(f"MCP tool {mcp_tool.name} failed: {e}")
             raise
 
-
+        # HITL is now handled at middleware level via HumanInTheLoopMiddleware
+        return tool_wrapper
 
     return [_create_tool_wrapper_with_connection(tool) for tool in mcp_tools]
@@ -190,6 +190,7 @@ def create_mcp_tools(
             logger.error(f"MCP tool {mcp_tool.name} failed: {e}")
             raise
 
-
+        # HITL is now handled at middleware level via HumanInTheLoopMiddleware
+        return tool_wrapper
 
     return [_create_tool_wrapper(tool) for tool in mcp_tools]
```
dao_ai/tools/memory.py
ADDED
```python
"""Memory tools for DAO AI."""

from typing import Any

from langchain_core.tools import BaseTool, StructuredTool
from langmem import create_search_memory_tool as langmem_create_search_memory_tool
from pydantic import BaseModel, Field


def create_search_memory_tool(namespace: tuple[str, ...]) -> BaseTool:
    """
    Create a Databricks-compatible search_memory tool.

    The langmem search_memory tool has a 'filter' field with additionalProperties: true
    in its schema, which Databricks LLM endpoints reject. This function creates a
    wrapper tool that omits the problematic filter field.

    Args:
        namespace: The memory namespace tuple

    Returns:
        A StructuredTool compatible with Databricks
    """
    # Get the original tool
    original_tool = langmem_create_search_memory_tool(namespace=namespace)

    # Create a schema without the problematic filter field
    class SearchMemoryInput(BaseModel):
        """Input for search_memory tool."""

        query: str = Field(..., description="The search query")
        limit: int = Field(default=10, description="Maximum number of results")
        offset: int = Field(default=0, description="Offset for pagination")

    # Create a wrapper function
    async def search_memory_wrapper(
        query: str, limit: int = 10, offset: int = 0
    ) -> Any:
        """Search your long-term memories for information relevant to your current context."""
        return await original_tool.ainvoke(
            {"query": query, "limit": limit, "offset": offset}
        )

    # Create the new tool
    return StructuredTool.from_function(
        coroutine=search_memory_wrapper,
        name="search_memory",
        description="Search your long-term memories for information relevant to your current context.",
        args_schema=SearchMemoryInput,
    )
```
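Because the wrapper only narrows the input schema and delegates to langmem, usage is unchanged apart from the dropped `filter` argument. A usage sketch (the namespace is illustrative, and langmem still needs a configured memory store at runtime):

```python
import asyncio

from dao_ai.tools.memory import create_search_memory_tool

# Templated segments like "{user_id}" are resolved by langmem at call time.
tool = create_search_memory_tool(namespace=("memories", "{user_id}"))


async def main() -> None:
    # Only query/limit/offset are exposed; the rejected 'filter' field is gone.
    results = await tool.ainvoke({"query": "favorite color", "limit": 5})
    print(results)


asyncio.run(main())
```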
dao_ai/tools/python.py
CHANGED
```diff
@@ -7,7 +7,6 @@ from dao_ai.config import (
     FactoryFunctionModel,
     PythonFunctionModel,
 )
-from dao_ai.tools.human_in_the_loop import as_human_in_the_loop
 from dao_ai.utils import load_function
 
 
@@ -25,11 +24,8 @@ def create_factory_tool(
     logger.debug(f"create_factory_tool: {function}")
 
     factory: Callable[..., Any] = load_function(function_name=function.full_name)
-    tool:
-
-        tool=tool,
-        function=function,
-    )
+    tool: RunnableLike = factory(**function.args)
+    # HITL is now handled at middleware level via HumanInTheLoopMiddleware
     return tool
 
 
@@ -51,10 +47,6 @@ def create_python_tool(
         function = function.full_name
 
     # Load the Python function dynamically
-    tool:
-
-    tool = as_human_in_the_loop(
-        tool=tool,
-        function=function,
-    )
+    tool: RunnableLike = load_function(function_name=function)
+    # HITL is now handled at middleware level via HumanInTheLoopMiddleware
    return tool
```
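After this change, `create_factory_tool` only resolves the factory by its fully qualified name and calls it with the configured arguments. A hypothetical factory it could load (the module path, tool, and `units` argument are invented for illustration):

```python
# my_pkg/tools.py -- referenced in config as my_pkg.tools.create_weather_tool
from langchain_core.runnables.base import RunnableLike
from langchain_core.tools import tool


def create_weather_tool(units: str = "metric") -> RunnableLike:
    """Factory: create_factory_tool calls this as factory(**function.args)."""

    @tool
    def get_weather(city: str) -> str:
        """Look up the current weather for a city."""
        return f"(stub) weather for {city} in {units} units"

    return get_weather
```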
dao_ai/tools/search.py
ADDED
```python
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.runnables.base import RunnableLike
from loguru import logger


def create_search_tool() -> RunnableLike:
    """
    Create a DuckDuckGo search tool.

    Returns:
        RunnableLike: A DuckDuckGo search tool that returns results as a list
    """
    logger.debug("Creating DuckDuckGo search tool")
    return DuckDuckGoSearchRun(output_format="list")
```
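A usage sketch (the query is illustrative); with `output_format="list"` the tool returns structured results rather than one concatenated string:

```python
from dao_ai.tools.search import create_search_tool

search = create_search_tool()
results = search.invoke("Databricks Vector Search reranking")
# Each entry is a dict with fields like 'title', 'snippet', and 'link'.
print(results[:2])
```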
dao_ai/tools/slack.py
CHANGED
```diff
@@ -71,7 +71,7 @@ def create_send_slack_message_tool(
     channel_name: Optional[str] = None,
     name: Optional[str] = None,
     description: Optional[str] = None,
-) -> Callable[[str],
+) -> Callable[[str], str]:
     """
     Create a tool that sends a message to a Slack channel.
```
dao_ai/tools/unity_catalog.py
CHANGED
```diff
@@ -6,6 +6,7 @@ from databricks_langchain import DatabricksFunctionClient, UCFunctionToolkit
 from langchain_core.runnables.base import RunnableLike
 from langchain_core.tools import StructuredTool
 from loguru import logger
+from unitycatalog.ai.core.base import FunctionExecutionResult
 
 from dao_ai.config import (
     AnyVariable,
@@ -14,7 +15,6 @@ from dao_ai.config import (
     UnityCatalogFunctionModel,
     value_of,
 )
-from dao_ai.tools.human_in_the_loop import as_human_in_the_loop
 from dao_ai.utils import normalize_host
 
 
@@ -65,8 +65,8 @@ def create_uc_tools(
     tools = toolkit.tools or []
     logger.debug(f"Retrieved tools: {tools}")
 
-    #
-    return
+    # HITL is now handled at middleware level via HumanInTheLoopMiddleware
+    return list(tools)
 
 
 def _execute_uc_function(
@@ -87,14 +87,16 @@ def _execute_uc_function(
         f"Calling UC function {function_name} with parameters: {list(all_params.keys())}"
     )
 
-    result = client.execute_function(
+    result: FunctionExecutionResult = client.execute_function(
+        function_name=function_name, parameters=all_params
+    )
 
     # Handle errors and extract result
-    if
+    if result.error:
         logger.error(f"Unity Catalog function error: {result.error}")
         raise RuntimeError(f"Function execution failed: {result.error}")
 
-    result_value: str = result.value if
+    result_value: str = result.value if result.value is not None else str(result)
     logger.debug(f"UC function result: {result_value}")
     return result_value
```
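The typed `FunctionExecutionResult` makes the error contract explicit. The same execute-and-check pattern in isolation (the UC function name and parameters are hypothetical, and the client assumes a configured Databricks environment):

```python
from databricks_langchain import DatabricksFunctionClient
from unitycatalog.ai.core.base import FunctionExecutionResult

client = DatabricksFunctionClient()
result: FunctionExecutionResult = client.execute_function(
    function_name="main.default.lookup_sku",  # hypothetical UC function
    parameters={"sku": "ABC-123"},
)
if result.error:
    # Surface UC-side failures instead of returning an empty value.
    raise RuntimeError(f"Function execution failed: {result.error}")
value: str = result.value if result.value is not None else str(result)
print(value)
```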
dao_ai/tools/vector_search.py
CHANGED
```diff
@@ -1,5 +1,12 @@
+"""
+Vector search tool for retrieving documents from Databricks Vector Search.
+
+This module provides a tool factory for creating semantic search tools
+using ToolRuntime[Context, AgentState] for type-safe runtime access.
+"""
+
 import os
-from typing import
+from typing import Any, Callable, List, Optional, Sequence
 
 import mlflow
 from databricks.vector_search.reranker import DatabricksReranker
@@ -9,9 +16,8 @@ from databricks_ai_bridge.vector_search_retriever_tool import (
 )
 from databricks_langchain.vectorstores import DatabricksVectorSearch
 from flashrank import Ranker, RerankRequest
+from langchain.tools import ToolRuntime, tool
 from langchain_core.documents import Document
-from langchain_core.tools import InjectedToolCallId, tool
-from langgraph.prebuilt import InjectedState
 from loguru import logger
 from mlflow.entities import SpanType
 
@@ -20,6 +26,7 @@ from dao_ai.config import (
     RetrieverModel,
     VectorStoreModel,
 )
+from dao_ai.state import AgentState, Context
 from dao_ai.utils import normalize_host
 
 
@@ -27,7 +34,7 @@ def create_vector_search_tool(
     retriever: RetrieverModel | dict[str, Any],
     name: Optional[str] = None,
     description: Optional[str] = None,
-) -> Callable:
+) -> Callable[..., list[dict[str, Any]]]:
     """
     Create a Vector Search tool for retrieving documents from a Databricks Vector Search index.
 
@@ -254,8 +261,7 @@ def create_vector_search_tool(
         return reranked_docs
 
     # Create the main vector search tool using @tool decorator
-    #
-    # so Annotated is only needed for injected LangGraph parameters
+    # Uses ToolRuntime[Context, AgentState] for type-safe runtime access
     @tool(
         name_or_callable=name or index_name,
         description=description or "Search for documents using vector similarity",
@@ -264,8 +270,7 @@ def create_vector_search_tool(
     def vector_search_tool(
         query: str,
         filters: Optional[List[FilterItem]] = None,
-
-        tool_call_id: Annotated[str, InjectedToolCallId] = None,
+        runtime: ToolRuntime[Context, AgentState] = None,
     ) -> list[dict[str, Any]]:
         """
         Search for documents using vector similarity with optional reranking.
@@ -276,8 +281,10 @@ def create_vector_search_tool(
 
         Both stages are traced in MLflow for observability.
 
+        Uses ToolRuntime[Context, AgentState] for type-safe runtime access.
+
         Returns:
-
+            List of serialized documents with page_content and metadata
         """
         logger.debug(
             f"Vector search tool called: query='{query[:50]}...', reranking={reranker_config is not None}"
```
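genie.py above and this file migrate the same way: the `InjectedState`/`InjectedToolCallId` annotations become a single injected `ToolRuntime`. A minimal sketch of the pattern, using only the accessors the diffs themselves show (the tool body is illustrative):

```python
from langchain.tools import ToolRuntime, tool

from dao_ai.state import AgentState, Context


@tool
def echo_tool(query: str, runtime: ToolRuntime[Context, AgentState] = None) -> str:
    """Illustrative tool showing runtime access replacing Annotated injection."""
    state = runtime.state                # formerly Annotated[AgentState, InjectedState]
    tool_call_id = runtime.tool_call_id  # formerly Annotated[str, InjectedToolCallId]
    return f"{query} (call {tool_call_id}, {len(state.get('messages', []))} messages)"
```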
dao_ai/utils.py
CHANGED
```diff
@@ -7,6 +7,7 @@ from importlib.metadata import PackageNotFoundError, version
 from pathlib import Path
 from typing import Any, Callable, Sequence
 
+from langchain_core.tools import BaseTool
 from loguru import logger
 
 import dao_ai
@@ -19,7 +20,7 @@ def is_lib_provided(lib_name: str, pip_requirements: Sequence[str]) -> bool:
     )
 
 
-def is_installed():
+def is_installed() -> bool:
     current_file = os.path.abspath(dao_ai.__file__)
     site_packages = [os.path.abspath(path) for path in site.getsitepackages()]
     if site.getusersitepackages():
@@ -157,9 +158,6 @@ def get_installed_packages() -> dict[str, str]:
         f"langchain-tavily=={version('langchain-tavily')}",
         f"langgraph=={version('langgraph')}",
         f"langgraph-checkpoint-postgres=={version('langgraph-checkpoint-postgres')}",
-        f"langgraph-prebuilt=={version('langgraph-prebuilt')}",
-        f"langgraph-supervisor=={version('langgraph-supervisor')}",
-        f"langgraph-swarm=={version('langgraph-swarm')}",
         f"langmem=={version('langmem')}",
         f"loguru=={version('loguru')}",
         f"mcp=={version('mcp')}",
@@ -212,13 +210,13 @@ def load_function(function_name: str) -> Callable[..., Any]:
         module = importlib.import_module(module_path)
 
         # Get the function from the module
-        func = getattr(module, func_name)
+        func: Any = getattr(module, func_name)
 
-        # Verify that the resolved object is callable or is a
+        # Verify that the resolved object is callable or is a LangChain tool
         # In langchain 1.x, StructuredTool objects are not directly callable
         # but have an invoke() method
-        is_callable = callable(func)
-        is_langchain_tool =
+        is_callable: bool = callable(func)
+        is_langchain_tool: bool = isinstance(func, BaseTool)
 
         if not is_callable and not is_langchain_tool:
             raise TypeError(f"Function {func_name} is not callable or invocable.")
@@ -229,6 +227,72 @@ def load_function(function_name: str) -> Callable[..., Any]:
         raise ImportError(f"Failed to import {function_name}: {e}")
 
 
+def type_from_fqn(type_name: str) -> type:
+    """
+    Load a type from a fully qualified name (FQN).
+
+    Dynamically imports and returns a type (class) from a module using its
+    fully qualified name. Useful for loading Pydantic models, dataclasses,
+    or any Python type specified as a string in configuration files.
+
+    Args:
+        type_name: Fully qualified type name in format "module.path.ClassName"
+
+    Returns:
+        The imported type/class
+
+    Raises:
+        ValueError: If the FQN format is invalid
+        ImportError: If the module cannot be imported
+        AttributeError: If the type doesn't exist in the module
+        TypeError: If the resolved object is not a type
+
+    Example:
+        >>> ProductModel = type_from_fqn("my_models.ProductInfo")
+        >>> instance = ProductModel(name="Widget", price=9.99)
+    """
+    logger.debug(f"Loading type: {type_name}")
+
+    try:
+        # Split the FQN into module path and class name
+        parts = type_name.rsplit(".", 1)
+        if len(parts) != 2:
+            raise ValueError(
+                f"Invalid type name '{type_name}'. "
+                "Expected format: 'module.path.ClassName'"
+            )
+
+        module_path, class_name = parts
+
+        # Dynamically import the module
+        try:
+            module = importlib.import_module(module_path)
+        except ModuleNotFoundError as e:
+            raise ImportError(
+                f"Could not import module '{module_path}' for type '{type_name}': {e}"
+            ) from e
+
+        # Get the class from the module
+        if not hasattr(module, class_name):
+            raise AttributeError(
+                f"Module '{module_path}' does not have attribute '{class_name}'"
+            )
+
+        resolved_type = getattr(module, class_name)
+
+        # Verify it's actually a type
+        if not isinstance(resolved_type, type):
+            raise TypeError(
+                f"'{type_name}' resolved to {resolved_type}, which is not a type"
+            )
+
+        return resolved_type
+
+    except (ValueError, ImportError, AttributeError, TypeError) as e:
+        # Provide a detailed error message that includes the original exception
+        raise type(e)(f"Failed to load type '{type_name}': {e}") from e
+
+
 def is_in_model_serving() -> bool:
     """Check if running in Databricks Model Serving environment.
```