dao-ai 0.0.36__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/cli.py +195 -30
- dao_ai/config.py +770 -244
- dao_ai/genie/__init__.py +1 -22
- dao_ai/genie/cache/__init__.py +1 -2
- dao_ai/genie/cache/base.py +20 -70
- dao_ai/genie/cache/core.py +75 -0
- dao_ai/genie/cache/lru.py +44 -21
- dao_ai/genie/cache/semantic.py +390 -109
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +22 -190
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +23 -5
- dao_ai/memory/databricks.py +389 -0
- dao_ai/memory/postgres.py +2 -2
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +778 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +61 -0
- dao_ai/middleware/guardrails.py +415 -0
- dao_ai/middleware/human_in_the_loop.py +228 -0
- dao_ai/middleware/message_validation.py +554 -0
- dao_ai/middleware/summarization.py +192 -0
- dao_ai/models.py +1177 -108
- dao_ai/nodes.py +118 -161
- dao_ai/optimization.py +664 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +287 -0
- dao_ai/orchestration/supervisor.py +264 -0
- dao_ai/orchestration/swarm.py +226 -0
- dao_ai/prompts.py +126 -29
- dao_ai/providers/databricks.py +126 -381
- dao_ai/state.py +139 -21
- dao_ai/tools/__init__.py +8 -5
- dao_ai/tools/core.py +57 -4
- dao_ai/tools/email.py +280 -0
- dao_ai/tools/genie.py +47 -24
- dao_ai/tools/mcp.py +4 -3
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +4 -12
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +1 -1
- dao_ai/tools/unity_catalog.py +8 -6
- dao_ai/tools/vector_search.py +16 -9
- dao_ai/utils.py +72 -8
- dao_ai-0.1.1.dist-info/METADATA +1878 -0
- dao_ai-0.1.1.dist-info/RECORD +62 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/genie/__init__.py +0 -236
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.36.dist-info/METADATA +0 -951
- dao_ai-0.0.36.dist-info/RECORD +0 -47
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/genie.py
CHANGED
@@ -16,10 +16,8 @@ from typing import Annotated, Any, Callable
 
 import pandas as pd
 from databricks_ai_bridge.genie import Genie, GenieResponse
-from langchain.tools import tool
+from langchain.tools import ToolRuntime, tool
 from langchain_core.messages import ToolMessage
-from langchain_core.tools import InjectedToolCallId
-from langgraph.prebuilt import InjectedState
 from langgraph.types import Command
 from loguru import logger
 from pydantic import BaseModel
@@ -33,7 +31,8 @@ from dao_ai.config import (
     value_of,
 )
 from dao_ai.genie import GenieService, GenieServiceBase
-from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
+from dao_ai.genie.cache import CacheResult, LRUCacheService, SemanticCacheService
+from dao_ai.state import AgentState, Context, SessionState
 
 
 class GenieToolInput(BaseModel):
@@ -97,9 +96,6 @@ def create_genie_tool(
     logger.debug(f"truncate_results: {truncate_results}")
     logger.debug(f"name: {name}")
     logger.debug(f"description: {description}")
-    logger.debug(f"genie_room: {genie_room}")
-    logger.debug(f"persist_conversation: {persist_conversation}")
-    logger.debug(f"truncate_results: {truncate_results}")
     logger.debug(f"lru_cache_parameters: {lru_cache_parameters}")
     logger.debug(f"semantic_cache_parameters: {semantic_cache_parameters}")
 
@@ -156,7 +152,7 @@ GenieResponse: A response object containing the conversation ID and result from
         genie_service = SemanticCacheService(
             impl=genie_service,
             parameters=semantic_cache_parameters,
-
+            workspace_client=genie_room.workspace_client,  # Pass workspace client for conversation history
         ).initialize()  # Eagerly initialize to fail fast and create table
 
     # Wrap with LRU cache last (checked first - fast O(1) exact match)
@@ -172,38 +168,65 @@ GenieResponse: A response object containing the conversation ID and result from
     )
     def genie_tool(
         question: Annotated[str, "The question to ask Genie about your data"],
-
-        tool_call_id: Annotated[str, InjectedToolCallId],
+        runtime: ToolRuntime[Context, AgentState],
     ) -> Command:
-        """Process a natural language question through Databricks Genie.
-
-
-
+        """Process a natural language question through Databricks Genie.
+
+        Uses ToolRuntime to access state and context in a type-safe way.
+        """
+        # Access state through runtime
+        state: AgentState = runtime.state
+        tool_call_id: str = runtime.tool_call_id
+
+        # Ensure space_id is a string for state keys
+        space_id_str: str = str(space_id)
+
+        # Get session state (or create new one)
+        session: SessionState = state.get("session", SessionState())
+
+        # Get existing conversation ID from session
+        existing_conversation_id: str | None = session.genie.get_conversation_id(
+            space_id_str
+        )
         logger.debug(
-            f"Existing conversation ID for space {
+            f"Existing conversation ID for space {space_id_str}: {existing_conversation_id}"
         )
 
-
+        # Call ask_question which always returns CacheResult with cache metadata
+        cache_result: CacheResult = genie_service.ask_question(
             question, conversation_id=existing_conversation_id
         )
+        genie_response: GenieResponse = cache_result.response
+        cache_hit: bool = cache_result.cache_hit
+        cache_key: str | None = cache_result.served_by
 
-        current_conversation_id: str =
+        current_conversation_id: str = genie_response.conversation_id
        logger.debug(
-            f"Current conversation ID for space {
+            f"Current conversation ID for space {space_id_str}: {current_conversation_id}, "
+            f"cache_hit: {cache_hit}, cache_key: {cache_key}"
        )
 
-        # Update
-
+        # Update session state with cache information
+        if persist_conversation:
+            session.genie.update_space(
+                space_id=space_id_str,
+                conversation_id=current_conversation_id,
+                cache_hit=cache_hit,
+                cache_key=cache_key,
+                last_query=question,
+            )
+
+        # Build update dict with response and session
         update: dict[str, Any] = {
             "messages": [
-                ToolMessage(
+                ToolMessage(
+                    _response_to_json(genie_response), tool_call_id=tool_call_id
+                )
             ],
         }
 
         if persist_conversation:
-
-            updated_conversation_ids[space_id] = current_conversation_id
-            update["genie_conversation_ids"] = updated_conversation_ids
+            update["session"] = session
 
         return Command(update=update)
 
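The change above swaps the InjectedState/InjectedToolCallId annotations for LangChain 1.x's ToolRuntime, which bundles state and the tool call id into a single injected parameter. A minimal sketch of that pattern, with an illustrative tool name and an echo body standing in for the real Genie call:

from typing import Annotated

from langchain.tools import ToolRuntime, tool
from langchain_core.messages import ToolMessage
from langgraph.types import Command


@tool("ask_genie", description="Ask Genie a question about your data")  # illustrative name
def ask_genie(
    question: Annotated[str, "The question to ask Genie about your data"],
    runtime: ToolRuntime,  # injected at execution time; never exposed to the model
) -> Command:
    state = runtime.state                # replaces Annotated[AgentState, InjectedState]
    tool_call_id = runtime.tool_call_id  # replaces Annotated[str, InjectedToolCallId]
    answer = f"(echo) {question}"        # stand-in for genie_service.ask_question(...)
    return Command(
        update={"messages": [ToolMessage(answer, tool_call_id=tool_call_id)]}
    )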
dao_ai/tools/mcp.py
CHANGED
@@ -14,7 +14,6 @@ from dao_ai.config import (
     McpFunctionModel,
     TransportType,
 )
-from dao_ai.tools.human_in_the_loop import as_human_in_the_loop
 
 
 def create_mcp_tools(
@@ -95,7 +94,8 @@ def create_mcp_tools(
             logger.error(f"MCP tool {mcp_tool.name} failed: {e}")
             raise
 
-
+        # HITL is now handled at middleware level via HumanInTheLoopMiddleware
+        return tool_wrapper
 
     return [_create_tool_wrapper_with_connection(tool) for tool in mcp_tools]
 
@@ -190,6 +190,7 @@ def create_mcp_tools(
             logger.error(f"MCP tool {mcp_tool.name} failed: {e}")
             raise
 
-
+        # HITL is now handled at middleware level via HumanInTheLoopMiddleware
+        return tool_wrapper
 
     return [_create_tool_wrapper(tool) for tool in mcp_tools]
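With the as_human_in_the_loop wrapper removed, approval gates move out of individual tools and into agent middleware. A hedged sketch of that wiring using stock LangChain 1.x middleware (dao_ai ships its own variant in dao_ai/middleware/human_in_the_loop.py, whose exact API this diff doesn't show; the model id and tool name below are placeholders):

from langchain.agents import create_agent
from langchain.agents.middleware import HumanInTheLoopMiddleware

tools = []  # e.g. the list returned by create_mcp_tools(...)
agent = create_agent(
    model="databricks:databricks-claude-sonnet-4",  # placeholder endpoint
    tools=tools,
    middleware=[
        # Pause for human approval before any call to this (placeholder) tool
        HumanInTheLoopMiddleware(interrupt_on={"send_slack_message": True}),
    ],
)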
dao_ai/tools/memory.py
ADDED
@@ -0,0 +1,50 @@
+"""Memory tools for DAO AI."""
+
+from typing import Any
+
+from langchain_core.tools import BaseTool, StructuredTool
+from langmem import create_search_memory_tool as langmem_create_search_memory_tool
+from pydantic import BaseModel, Field
+
+
+def create_search_memory_tool(namespace: tuple[str, ...]) -> BaseTool:
+    """
+    Create a Databricks-compatible search_memory tool.
+
+    The langmem search_memory tool has a 'filter' field with additionalProperties: true
+    in its schema, which Databricks LLM endpoints reject. This function creates a
+    wrapper tool that omits the problematic filter field.
+
+    Args:
+        namespace: The memory namespace tuple
+
+    Returns:
+        A StructuredTool compatible with Databricks
+    """
+    # Get the original tool
+    original_tool = langmem_create_search_memory_tool(namespace=namespace)
+
+    # Create a schema without the problematic filter field
+    class SearchMemoryInput(BaseModel):
+        """Input for search_memory tool."""
+
+        query: str = Field(..., description="The search query")
+        limit: int = Field(default=10, description="Maximum number of results")
+        offset: int = Field(default=0, description="Offset for pagination")
+
+    # Create a wrapper function
+    async def search_memory_wrapper(
+        query: str, limit: int = 10, offset: int = 0
+    ) -> Any:
+        """Search your long-term memories for information relevant to your current context."""
+        return await original_tool.ainvoke(
+            {"query": query, "limit": limit, "offset": offset}
+        )
+
+    # Create the new tool
+    return StructuredTool.from_function(
+        coroutine=search_memory_wrapper,
+        name="search_memory",
+        description="Search your long-term memories for information relevant to your current context.",
+        args_schema=SearchMemoryInput,
+    )
dao_ai/tools/python.py
CHANGED
@@ -7,7 +7,6 @@ from dao_ai.config import (
     FactoryFunctionModel,
     PythonFunctionModel,
 )
-from dao_ai.tools.human_in_the_loop import as_human_in_the_loop
 from dao_ai.utils import load_function
 
 
@@ -25,11 +24,8 @@ def create_factory_tool(
     logger.debug(f"create_factory_tool: {function}")
 
     factory: Callable[..., Any] = load_function(function_name=function.full_name)
-    tool:
-
-        tool=tool,
-        function=function,
-    )
+    tool: RunnableLike = factory(**function.args)
+    # HITL is now handled at middleware level via HumanInTheLoopMiddleware
     return tool
 
 
@@ -51,10 +47,6 @@ def create_python_tool(
         function = function.full_name
 
     # Load the Python function dynamically
-    tool:
-
-    tool = as_human_in_the_loop(
-        tool=tool,
-        function=function,
-    )
+    tool: RunnableLike = load_function(function_name=function)
+    # HITL is now handled at middleware level via HumanInTheLoopMiddleware
     return tool
dao_ai/tools/search.py
ADDED
@@ -0,0 +1,14 @@
+from langchain_community.tools import DuckDuckGoSearchRun
+from langchain_core.runnables.base import RunnableLike
+from loguru import logger
+
+
+def create_search_tool() -> RunnableLike:
+    """
+    Create a DuckDuckGo search tool.
+
+    Returns:
+        RunnableLike: A DuckDuckGo search tool that returns results as a list
+    """
+    logger.debug("Creating DuckDuckGo search tool")
+    return DuckDuckGoSearchRun(output_format="list")
dao_ai/tools/slack.py
CHANGED
@@ -71,7 +71,7 @@ def create_send_slack_message_tool(
     channel_name: Optional[str] = None,
     name: Optional[str] = None,
     description: Optional[str] = None,
-) -> Callable[[str],
+) -> Callable[[str], str]:
     """
     Create a tool that sends a message to a Slack channel.
 
dao_ai/tools/unity_catalog.py
CHANGED
@@ -6,6 +6,7 @@ from databricks_langchain import DatabricksFunctionClient, UCFunctionToolkit
 from langchain_core.runnables.base import RunnableLike
 from langchain_core.tools import StructuredTool
 from loguru import logger
+from unitycatalog.ai.core.base import FunctionExecutionResult
 
 from dao_ai.config import (
     AnyVariable,
@@ -14,7 +15,6 @@ from dao_ai.config import (
     UnityCatalogFunctionModel,
     value_of,
 )
-from dao_ai.tools.human_in_the_loop import as_human_in_the_loop
 from dao_ai.utils import normalize_host
 
 
@@ -65,8 +65,8 @@ def create_uc_tools(
     tools = toolkit.tools or []
     logger.debug(f"Retrieved tools: {tools}")
 
-    #
-    return
+    # HITL is now handled at middleware level via HumanInTheLoopMiddleware
+    return list(tools)
 
 
 def _execute_uc_function(
@@ -87,14 +87,16 @@ def _execute_uc_function(
         f"Calling UC function {function_name} with parameters: {list(all_params.keys())}"
     )
 
-    result = client.execute_function(
+    result: FunctionExecutionResult = client.execute_function(
+        function_name=function_name, parameters=all_params
+    )
 
     # Handle errors and extract result
-    if
+    if result.error:
         logger.error(f"Unity Catalog function error: {result.error}")
         raise RuntimeError(f"Function execution failed: {result.error}")
 
-    result_value: str = result.value if
+    result_value: str = result.value if result.value is not None else str(result)
     logger.debug(f"UC function result: {result_value}")
     return result_value
 
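A sketch of the result handling above, using the same FunctionExecutionResult fields this diff relies on (the UC function name and parameters are placeholders):

from databricks_langchain import DatabricksFunctionClient
from unitycatalog.ai.core.base import FunctionExecutionResult

client = DatabricksFunctionClient()
result: FunctionExecutionResult = client.execute_function(
    function_name="main.default.add_numbers",  # placeholder UC function
    parameters={"a": 1, "b": 2},
)
if result.error:
    raise RuntimeError(f"Function execution failed: {result.error}")
# Fall back to the stringified result when no value is present
value: str = result.value if result.value is not None else str(result)
print(value)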
dao_ai/tools/vector_search.py
CHANGED
@@ -1,5 +1,12 @@
+"""
+Vector search tool for retrieving documents from Databricks Vector Search.
+
+This module provides a tool factory for creating semantic search tools
+using ToolRuntime[Context, AgentState] for type-safe runtime access.
+"""
+
 import os
-from typing import
+from typing import Any, Callable, List, Optional, Sequence
 
 import mlflow
 from databricks.vector_search.reranker import DatabricksReranker
@@ -9,9 +16,8 @@ from databricks_ai_bridge.vector_search_retriever_tool import (
 )
 from databricks_langchain.vectorstores import DatabricksVectorSearch
 from flashrank import Ranker, RerankRequest
+from langchain.tools import ToolRuntime, tool
 from langchain_core.documents import Document
-from langchain_core.tools import InjectedToolCallId, tool
-from langgraph.prebuilt import InjectedState
 from loguru import logger
 from mlflow.entities import SpanType
 
@@ -20,6 +26,7 @@ from dao_ai.config import (
     RetrieverModel,
     VectorStoreModel,
 )
+from dao_ai.state import AgentState, Context
 from dao_ai.utils import normalize_host
 
 
@@ -27,7 +34,7 @@ def create_vector_search_tool(
     retriever: RetrieverModel | dict[str, Any],
     name: Optional[str] = None,
     description: Optional[str] = None,
-) -> Callable:
+) -> Callable[..., list[dict[str, Any]]]:
     """
     Create a Vector Search tool for retrieving documents from a Databricks Vector Search index.
 
@@ -254,8 +261,7 @@ def create_vector_search_tool(
         return reranked_docs
 
     # Create the main vector search tool using @tool decorator
-    #
-    # so Annotated is only needed for injected LangGraph parameters
+    # Uses ToolRuntime[Context, AgentState] for type-safe runtime access
     @tool(
         name_or_callable=name or index_name,
         description=description or "Search for documents using vector similarity",
@@ -264,8 +270,7 @@ def create_vector_search_tool(
     def vector_search_tool(
         query: str,
         filters: Optional[List[FilterItem]] = None,
-
-        tool_call_id: Annotated[str, InjectedToolCallId] = None,
+        runtime: ToolRuntime[Context, AgentState] = None,
     ) -> list[dict[str, Any]]:
         """
         Search for documents using vector similarity with optional reranking.
 
@@ -276,8 +281,10 @@ def create_vector_search_tool(
 
         Both stages are traced in MLflow for observability.
 
+        Uses ToolRuntime[Context, AgentState] for type-safe runtime access.
+
         Returns:
-
+            List of serialized documents with page_content and metadata
         """
         logger.debug(
             f"Vector search tool called: query='{query[:50]}...', reranking={reranker_config is not None}"
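Unlike genie_tool, vector_search_tool declares its runtime with a None default, presumably so the tool can also be invoked without an agent runtime. A minimal sketch of that shape (names are illustrative):

from typing import Any

from langchain.tools import ToolRuntime, tool


@tool("demo_search", description="Look up documents")  # placeholder name
def demo_search(
    query: str,
    runtime: ToolRuntime = None,  # injected when run by an agent; hidden from the model's schema
) -> list[dict[str, Any]]:
    requester = getattr(runtime, "context", None) if runtime else None
    return [{"page_content": f"results for {query}", "metadata": {"requester": str(requester)}}]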
dao_ai/utils.py
CHANGED
@@ -7,6 +7,7 @@ from importlib.metadata import PackageNotFoundError, version
 from pathlib import Path
 from typing import Any, Callable, Sequence
 
+from langchain_core.tools import BaseTool
 from loguru import logger
 
 import dao_ai
@@ -19,7 +20,7 @@ def is_lib_provided(lib_name: str, pip_requirements: Sequence[str]) -> bool:
     )
 
 
-def is_installed():
+def is_installed() -> bool:
     current_file = os.path.abspath(dao_ai.__file__)
     site_packages = [os.path.abspath(path) for path in site.getsitepackages()]
     if site.getusersitepackages():
@@ -157,9 +158,6 @@ def get_installed_packages() -> dict[str, str]:
         f"langchain-tavily=={version('langchain-tavily')}",
         f"langgraph=={version('langgraph')}",
         f"langgraph-checkpoint-postgres=={version('langgraph-checkpoint-postgres')}",
-        f"langgraph-prebuilt=={version('langgraph-prebuilt')}",
-        f"langgraph-supervisor=={version('langgraph-supervisor')}",
-        f"langgraph-swarm=={version('langgraph-swarm')}",
         f"langmem=={version('langmem')}",
         f"loguru=={version('loguru')}",
         f"mcp=={version('mcp')}",
@@ -212,13 +210,13 @@ def load_function(function_name: str) -> Callable[..., Any]:
         module = importlib.import_module(module_path)
 
         # Get the function from the module
-        func = getattr(module, func_name)
+        func: Any = getattr(module, func_name)
 
-        # Verify that the resolved object is callable or is a
+        # Verify that the resolved object is callable or is a LangChain tool
         # In langchain 1.x, StructuredTool objects are not directly callable
         # but have an invoke() method
-        is_callable = callable(func)
-        is_langchain_tool =
+        is_callable: bool = callable(func)
+        is_langchain_tool: bool = isinstance(func, BaseTool)
 
         if not is_callable and not is_langchain_tool:
             raise TypeError(f"Function {func_name} is not callable or invocable.")
@@ -229,6 +227,72 @@ def load_function(function_name: str) -> Callable[..., Any]:
         raise ImportError(f"Failed to import {function_name}: {e}")
 
 
+def type_from_fqn(type_name: str) -> type:
+    """
+    Load a type from a fully qualified name (FQN).
+
+    Dynamically imports and returns a type (class) from a module using its
+    fully qualified name. Useful for loading Pydantic models, dataclasses,
+    or any Python type specified as a string in configuration files.
+
+    Args:
+        type_name: Fully qualified type name in format "module.path.ClassName"
+
+    Returns:
+        The imported type/class
+
+    Raises:
+        ValueError: If the FQN format is invalid
+        ImportError: If the module cannot be imported
+        AttributeError: If the type doesn't exist in the module
+        TypeError: If the resolved object is not a type
+
+    Example:
+        >>> ProductModel = type_from_fqn("my_models.ProductInfo")
+        >>> instance = ProductModel(name="Widget", price=9.99)
+    """
+    logger.debug(f"Loading type: {type_name}")
+
+    try:
+        # Split the FQN into module path and class name
+        parts = type_name.rsplit(".", 1)
+        if len(parts) != 2:
+            raise ValueError(
+                f"Invalid type name '{type_name}'. "
+                "Expected format: 'module.path.ClassName'"
+            )
+
+        module_path, class_name = parts
+
+        # Dynamically import the module
+        try:
+            module = importlib.import_module(module_path)
+        except ModuleNotFoundError as e:
+            raise ImportError(
+                f"Could not import module '{module_path}' for type '{type_name}': {e}"
+            ) from e
+
+        # Get the class from the module
+        if not hasattr(module, class_name):
+            raise AttributeError(
+                f"Module '{module_path}' does not have attribute '{class_name}'"
+            )
+
+        resolved_type = getattr(module, class_name)
+
+        # Verify it's actually a type
+        if not isinstance(resolved_type, type):
+            raise TypeError(
+                f"'{type_name}' resolved to {resolved_type}, which is not a type"
+            )
+
+        return resolved_type
+
+    except (ValueError, ImportError, AttributeError, TypeError) as e:
+        # Provide a detailed error message that includes the original exception
+        raise type(e)(f"Failed to load type '{type_name}': {e}") from e
+
+
 def is_in_model_serving() -> bool:
     """Check if running in Databricks Model Serving environment.
 
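The new type_from_fqn helper resolves a class from a dotted path and re-raises any failure with the FQN in the message. A quick sketch against the standard library:

from dao_ai.utils import type_from_fqn

# Resolve a class by its fully qualified name.
PathType = type_from_fqn("pathlib.Path")
print(PathType("/tmp"))  # /tmp

# Failures are chained and carry the offending FQN:
try:
    type_from_fqn("not_a_real_module.Thing")
except ImportError as err:
    print(err)  # Failed to load type 'not_a_real_module.Thing': ...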