dao-ai 0.0.35__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/cli.py +195 -30
  3. dao_ai/config.py +797 -242
  4. dao_ai/genie/__init__.py +38 -0
  5. dao_ai/genie/cache/__init__.py +43 -0
  6. dao_ai/genie/cache/base.py +72 -0
  7. dao_ai/genie/cache/core.py +75 -0
  8. dao_ai/genie/cache/lru.py +329 -0
  9. dao_ai/genie/cache/semantic.py +919 -0
  10. dao_ai/genie/core.py +35 -0
  11. dao_ai/graph.py +27 -253
  12. dao_ai/hooks/__init__.py +9 -6
  13. dao_ai/hooks/core.py +22 -190
  14. dao_ai/memory/__init__.py +10 -0
  15. dao_ai/memory/core.py +23 -5
  16. dao_ai/memory/databricks.py +389 -0
  17. dao_ai/memory/postgres.py +2 -2
  18. dao_ai/messages.py +6 -4
  19. dao_ai/middleware/__init__.py +125 -0
  20. dao_ai/middleware/assertions.py +778 -0
  21. dao_ai/middleware/base.py +50 -0
  22. dao_ai/middleware/core.py +61 -0
  23. dao_ai/middleware/guardrails.py +415 -0
  24. dao_ai/middleware/human_in_the_loop.py +228 -0
  25. dao_ai/middleware/message_validation.py +554 -0
  26. dao_ai/middleware/summarization.py +192 -0
  27. dao_ai/models.py +1177 -108
  28. dao_ai/nodes.py +118 -161
  29. dao_ai/optimization.py +664 -0
  30. dao_ai/orchestration/__init__.py +52 -0
  31. dao_ai/orchestration/core.py +287 -0
  32. dao_ai/orchestration/supervisor.py +264 -0
  33. dao_ai/orchestration/swarm.py +226 -0
  34. dao_ai/prompts.py +126 -29
  35. dao_ai/providers/databricks.py +126 -381
  36. dao_ai/state.py +139 -21
  37. dao_ai/tools/__init__.py +11 -5
  38. dao_ai/tools/core.py +57 -4
  39. dao_ai/tools/email.py +280 -0
  40. dao_ai/tools/genie.py +108 -35
  41. dao_ai/tools/mcp.py +4 -3
  42. dao_ai/tools/memory.py +50 -0
  43. dao_ai/tools/python.py +4 -12
  44. dao_ai/tools/search.py +14 -0
  45. dao_ai/tools/slack.py +1 -1
  46. dao_ai/tools/unity_catalog.py +8 -6
  47. dao_ai/tools/vector_search.py +16 -9
  48. dao_ai/utils.py +72 -8
  49. dao_ai-0.1.0.dist-info/METADATA +1878 -0
  50. dao_ai-0.1.0.dist-info/RECORD +62 -0
  51. dao_ai/chat_models.py +0 -204
  52. dao_ai/guardrails.py +0 -112
  53. dao_ai/tools/human_in_the_loop.py +0 -100
  54. dao_ai-0.0.35.dist-info/METADATA +0 -1169
  55. dao_ai-0.0.35.dist-info/RECORD +0 -41
  56. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
  57. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
  58. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/genie.py CHANGED
@@ -1,3 +1,14 @@
+"""
+Genie tool for natural language queries to databases.
+
+This module provides the tool factory for creating LangGraph tools that
+interact with Databricks Genie.
+
+For the core Genie service and cache implementations, see:
+- dao_ai.genie: GenieService, GenieServiceBase
+- dao_ai.genie.cache: LRUCacheService, SemanticCacheService
+"""
+
 import json
 import os
 from textwrap import dedent
@@ -5,23 +16,29 @@ from typing import Annotated, Any, Callable
 
 import pandas as pd
 from databricks_ai_bridge.genie import Genie, GenieResponse
-from langchain.tools import tool
+from langchain.tools import ToolRuntime, tool
 from langchain_core.messages import ToolMessage
-from langchain_core.tools import InjectedToolCallId
-from langgraph.prebuilt import InjectedState
 from langgraph.types import Command
 from loguru import logger
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 
-from dao_ai.config import AnyVariable, CompositeVariableModel, GenieRoomModel, value_of
+from dao_ai.config import (
+    AnyVariable,
+    CompositeVariableModel,
+    GenieLRUCacheParametersModel,
+    GenieRoomModel,
+    GenieSemanticCacheParametersModel,
+    value_of,
+)
+from dao_ai.genie import GenieService, GenieServiceBase
+from dao_ai.genie.cache import CacheResult, LRUCacheService, SemanticCacheService
+from dao_ai.state import AgentState, Context, SessionState
 
 
 class GenieToolInput(BaseModel):
-    """Input schema for the Genie tool."""
+    """Input schema for Genie tool - only includes user-facing parameters."""
 
-    question: str = Field(
-        description="The question to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question."
-    )
+    question: str
 
 
 def _response_to_json(response: GenieResponse) -> str:
@@ -46,6 +63,10 @@ def create_genie_tool(
     description: str | None = None,
     persist_conversation: bool = True,
     truncate_results: bool = False,
+    lru_cache_parameters: GenieLRUCacheParametersModel | dict[str, Any] | None = None,
+    semantic_cache_parameters: GenieSemanticCacheParametersModel
+    | dict[str, Any]
+    | None = None,
 ) -> Callable[..., Command]:
     """
     Create a tool for interacting with Databricks Genie for natural language queries to databases.
@@ -61,6 +82,9 @@ def create_genie_tool(
         persist_conversation: Whether to persist conversation IDs across tool calls for
             multi-turn conversations within the same Genie space
         truncate_results: Whether to truncate large query results to fit token limits
+        lru_cache_parameters: Optional LRU cache configuration for SQL query caching
+        semantic_cache_parameters: Optional semantic cache configuration using pg_vector
+            for similarity-based query matching
 
     Returns:
         A LangGraph tool that processes natural language queries through Genie
@@ -72,13 +96,20 @@ def create_genie_tool(
     logger.debug(f"truncate_results: {truncate_results}")
     logger.debug(f"name: {name}")
     logger.debug(f"description: {description}")
-    logger.debug(f"genie_room: {genie_room}")
-    logger.debug(f"persist_conversation: {persist_conversation}")
-    logger.debug(f"truncate_results: {truncate_results}")
+    logger.debug(f"lru_cache_parameters: {lru_cache_parameters}")
+    logger.debug(f"semantic_cache_parameters: {semantic_cache_parameters}")
 
     if isinstance(genie_room, dict):
         genie_room = GenieRoomModel(**genie_room)
 
+    if isinstance(lru_cache_parameters, dict):
+        lru_cache_parameters = GenieLRUCacheParametersModel(**lru_cache_parameters)
+
+    if isinstance(semantic_cache_parameters, dict):
+        semantic_cache_parameters = GenieSemanticCacheParametersModel(
+            **semantic_cache_parameters
+        )
+
     space_id: AnyVariable = genie_room.space_id or os.environ.get(
         "DATABRICKS_GENIE_SPACE_ID"
     )
@@ -108,52 +139,94 @@ Returns:
         GenieResponse: A response object containing the conversation ID and result from Genie."""
     tool_description = tool_description + function_docs
 
+    genie: Genie = Genie(
+        space_id=space_id,
+        client=genie_room.workspace_client,
+        truncate_results=truncate_results,
+    )
+
+    genie_service: GenieServiceBase = GenieService(genie)
+
+    # Wrap with semantic cache first (checked second due to decorator pattern)
+    if semantic_cache_parameters is not None:
+        genie_service = SemanticCacheService(
+            impl=genie_service,
+            parameters=semantic_cache_parameters,
+            workspace_client=genie_room.workspace_client,  # Pass workspace client for conversation history
+        ).initialize()  # Eagerly initialize to fail fast and create table
+
+    # Wrap with LRU cache last (checked first - fast O(1) exact match)
+    if lru_cache_parameters is not None:
+        genie_service = LRUCacheService(
+            impl=genie_service,
+            parameters=lru_cache_parameters,
+        )
+
     @tool(
         name_or_callable=tool_name,
         description=tool_description,
     )
     def genie_tool(
         question: Annotated[str, "The question to ask Genie about your data"],
-        state: Annotated[dict, InjectedState],
-        tool_call_id: Annotated[str, InjectedToolCallId],
+        runtime: ToolRuntime[Context, AgentState],
     ) -> Command:
-        genie: Genie = Genie(
-            space_id=space_id,
-            client=genie_room.workspace_client,
-            truncate_results=truncate_results,
-        )
+        """Process a natural language question through Databricks Genie.
+
+        Uses ToolRuntime to access state and context in a type-safe way.
+        """
+        # Access state through runtime
+        state: AgentState = runtime.state
+        tool_call_id: str = runtime.tool_call_id
+
+        # Ensure space_id is a string for state keys
+        space_id_str: str = str(space_id)
 
-        """Process a natural language question through Databricks Genie."""
-        # Get existing conversation mapping and retrieve conversation ID for this space
-        conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
-        existing_conversation_id: str | None = conversation_ids.get(space_id)
+        # Get session state (or create new one)
+        session: SessionState = state.get("session", SessionState())
+
+        # Get existing conversation ID from session
+        existing_conversation_id: str | None = session.genie.get_conversation_id(
+            space_id_str
+        )
         logger.debug(
-            f"Existing conversation ID for space {space_id}: {existing_conversation_id}"
+            f"Existing conversation ID for space {space_id_str}: {existing_conversation_id}"
        )
 
-        response: GenieResponse = genie.ask_question(
+        # Call ask_question which always returns CacheResult with cache metadata
+        cache_result: CacheResult = genie_service.ask_question(
            question, conversation_id=existing_conversation_id
        )
+        genie_response: GenieResponse = cache_result.response
+        cache_hit: bool = cache_result.cache_hit
+        cache_key: str | None = cache_result.served_by
 
-        current_conversation_id: str = response.conversation_id
+        current_conversation_id: str = genie_response.conversation_id
        logger.debug(
-            f"Current conversation ID for space {space_id}: {current_conversation_id}"
+            f"Current conversation ID for space {space_id_str}: {current_conversation_id}, "
+            f"cache_hit: {cache_hit}, cache_key: {cache_key}"
        )
 
-        # Update the conversation mapping with the new conversation ID for this space
-
+        # Update session state with cache information
+        if persist_conversation:
+            session.genie.update_space(
+                space_id=space_id_str,
+                conversation_id=current_conversation_id,
+                cache_hit=cache_hit,
+                cache_key=cache_key,
+                last_query=question,
+            )
+
+        # Build update dict with response and session
        update: dict[str, Any] = {
            "messages": [
-                ToolMessage(_response_to_json(response), tool_call_id=tool_call_id)
+                ToolMessage(
+                    _response_to_json(genie_response), tool_call_id=tool_call_id
+                )
            ],
        }
 
        if persist_conversation:
-            updated_conversation_ids: dict[str, str] = conversation_ids.copy()
-            updated_conversation_ids[space_id] = current_conversation_id
-            update["genie_conversation_ids"] = updated_conversation_ids
-
-            logger.debug(f"State update: {update}")
+            update["session"] = session
 
        return Command(update=update)
 
dao_ai/tools/mcp.py CHANGED
@@ -14,7 +14,6 @@ from dao_ai.config (
     McpFunctionModel,
     TransportType,
 )
-from dao_ai.tools.human_in_the_loop import as_human_in_the_loop
 
 
 def create_mcp_tools(
@@ -95,7 +94,8 @@ def create_mcp_tools(
                 logger.error(f"MCP tool {mcp_tool.name} failed: {e}")
                 raise
 
-        return as_human_in_the_loop(tool_wrapper, function)
+        # HITL is now handled at middleware level via HumanInTheLoopMiddleware
+        return tool_wrapper
 
     return [_create_tool_wrapper_with_connection(tool) for tool in mcp_tools]
 
@@ -190,6 +190,7 @@ def create_mcp_tools(
                 logger.error(f"MCP tool {mcp_tool.name} failed: {e}")
                 raise
 
-        return as_human_in_the_loop(tool_wrapper, function)
+        # HITL is now handled at middleware level via HumanInTheLoopMiddleware
+        return tool_wrapper
 
     return [_create_tool_wrapper(tool) for tool in mcp_tools]
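This change repeats across mcp.py, python.py, and unity_catalog.py below: tools no longer wrap themselves in `as_human_in_the_loop`; approval is applied once, centrally, by `HumanInTheLoopMiddleware` (dao_ai/middleware/human_in_the_loop.py in the file index). That middleware's API is not shown in this diff, so the following is only a hypothetical sketch of why a middleware-level gate makes per-tool wrapping unnecessary:

```python
from typing import Any, Callable

ToolFn = Callable[..., Any]


class IllustrativeHitlGate:
    """Hypothetical stand-in for middleware-level approval.

    The real HumanInTheLoopMiddleware lives in
    dao_ai/middleware/human_in_the_loop.py and is not shown in this diff.
    """

    def __init__(self, requires_approval: set[str]) -> None:
        self.requires_approval = requires_approval  # tool names to gate

    def wrap(self, tool_name: str, call: ToolFn) -> ToolFn:
        def gated(*args: Any, **kwargs: Any) -> Any:
            if tool_name in self.requires_approval:
                # A real middleware would pause the graph (e.g. via an
                # interrupt) rather than read stdin; input() keeps the
                # sketch runnable.
                if input(f"Approve call to {tool_name}? [y/N] ").lower() != "y":
                    raise PermissionError(f"{tool_name} rejected by reviewer")
            return call(*args, **kwargs)

        return gated
```

The practical upshot is that every tool factory in this release returns a plain tool, and the HITL policy is configured in one place instead of being threaded through each factory.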
dao_ai/tools/memory.py ADDED
@@ -0,0 +1,50 @@
+"""Memory tools for DAO AI."""
+
+from typing import Any
+
+from langchain_core.tools import BaseTool, StructuredTool
+from langmem import create_search_memory_tool as langmem_create_search_memory_tool
+from pydantic import BaseModel, Field
+
+
+def create_search_memory_tool(namespace: tuple[str, ...]) -> BaseTool:
+    """
+    Create a Databricks-compatible search_memory tool.
+
+    The langmem search_memory tool has a 'filter' field with additionalProperties: true
+    in its schema, which Databricks LLM endpoints reject. This function creates a
+    wrapper tool that omits the problematic filter field.
+
+    Args:
+        namespace: The memory namespace tuple
+
+    Returns:
+        A StructuredTool compatible with Databricks
+    """
+    # Get the original tool
+    original_tool = langmem_create_search_memory_tool(namespace=namespace)
+
+    # Create a schema without the problematic filter field
+    class SearchMemoryInput(BaseModel):
+        """Input for search_memory tool."""
+
+        query: str = Field(..., description="The search query")
+        limit: int = Field(default=10, description="Maximum number of results")
+        offset: int = Field(default=0, description="Offset for pagination")
+
+    # Create a wrapper function
+    async def search_memory_wrapper(
+        query: str, limit: int = 10, offset: int = 0
+    ) -> Any:
+        """Search your long-term memories for information relevant to your current context."""
+        return await original_tool.ainvoke(
+            {"query": query, "limit": limit, "offset": offset}
+        )
+
+    # Create the new tool
+    return StructuredTool.from_function(
+        coroutine=search_memory_wrapper,
+        name="search_memory",
+        description="Search your long-term memories for information relevant to your current context.",
+        args_schema=SearchMemoryInput,
+    )
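A hedged usage sketch of the new wrapper (the namespace here is hypothetical; real values come from your memory configuration, and invocation requires a configured memory store):

```python
from dao_ai.tools.memory import create_search_memory_tool

# Hypothetical namespace; real values come from your memory configuration.
tool = create_search_memory_tool(namespace=("memories", "{user_id}"))

# The generated schema now exposes only query/limit/offset; the open-ended
# 'filter' object (additionalProperties: true) that Databricks endpoints
# reject is gone.
print(tool.args_schema.model_json_schema())

# The wrapper is async-only (it delegates to the langmem tool's ainvoke)
# and must run inside a graph with a store available, e.g.:
#   await tool.ainvoke({"query": "favorite color", "limit": 5})
```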
dao_ai/tools/python.py CHANGED
@@ -7,7 +7,6 @@ from dao_ai.config (
     FactoryFunctionModel,
     PythonFunctionModel,
 )
-from dao_ai.tools.human_in_the_loop import as_human_in_the_loop
 from dao_ai.utils import load_function
 
 
@@ -25,11 +24,8 @@ def create_factory_tool(
     logger.debug(f"create_factory_tool: {function}")
 
     factory: Callable[..., Any] = load_function(function_name=function.full_name)
-    tool: Callable[..., Any] = factory(**function.args)
-    tool = as_human_in_the_loop(
-        tool=tool,
-        function=function,
-    )
+    tool: RunnableLike = factory(**function.args)
+    # HITL is now handled at middleware level via HumanInTheLoopMiddleware
     return tool
 
 
@@ -51,10 +47,6 @@ def create_python_tool(
         function = function.full_name
 
     # Load the Python function dynamically
-    tool: Callable[..., Any] = load_function(function_name=function)
-
-    tool = as_human_in_the_loop(
-        tool=tool,
-        function=function,
-    )
+    tool: RunnableLike = load_function(function_name=function)
+    # HITL is now handled at middleware level via HumanInTheLoopMiddleware
     return tool
dao_ai/tools/search.py ADDED
@@ -0,0 +1,14 @@
+from langchain_community.tools import DuckDuckGoSearchRun
+from langchain_core.runnables.base import RunnableLike
+from loguru import logger
+
+
+def create_search_tool() -> RunnableLike:
+    """
+    Create a DuckDuckGo search tool.
+
+    Returns:
+        RunnableLike: A DuckDuckGo search tool that returns results as a list
+    """
+    logger.debug("Creating DuckDuckGo search tool")
+    return DuckDuckGoSearchRun(output_format="list")
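A hedged usage sketch (requires the duckduckgo-search dependency; the exact keys of each result dict come from langchain_community's list output format and may vary):

```python
from dao_ai.tools.search import create_search_tool

search = create_search_tool()

# With output_format="list", results come back as a list of dicts rather
# than one concatenated string.
results = search.invoke("Databricks Genie semantic cache")
for item in results:
    print(item.get("title"), item.get("link"))
```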
dao_ai/tools/slack.py CHANGED
@@ -71,7 +71,7 @@ def create_send_slack_message_tool(
     channel_name: Optional[str] = None,
     name: Optional[str] = None,
     description: Optional[str] = None,
-) -> Callable[[str], Any]:
+) -> Callable[[str], str]:
     """
     Create a tool that sends a message to a Slack channel.
 
dao_ai/tools/unity_catalog.py CHANGED
@@ -6,6 +6,7 @@ from databricks_langchain import DatabricksFunctionClient, UCFunctionToolkit
 from langchain_core.runnables.base import RunnableLike
 from langchain_core.tools import StructuredTool
 from loguru import logger
+from unitycatalog.ai.core.base import FunctionExecutionResult
 
 from dao_ai.config import (
     AnyVariable,
@@ -14,7 +15,6 @@ from dao_ai.config (
     UnityCatalogFunctionModel,
     value_of,
 )
-from dao_ai.tools.human_in_the_loop import as_human_in_the_loop
 from dao_ai.utils import normalize_host
 
 
@@ -65,8 +65,8 @@ def create_uc_tools(
     tools = toolkit.tools or []
     logger.debug(f"Retrieved tools: {tools}")
 
-    # Apply human-in-the-loop wrapper to all tools and return
-    return [as_human_in_the_loop(tool=tool, function=function_name) for tool in tools]
+    # HITL is now handled at middleware level via HumanInTheLoopMiddleware
+    return list(tools)
 
 
 def _execute_uc_function(
@@ -87,14 +87,16 @@ def _execute_uc_function(
         f"Calling UC function {function_name} with parameters: {list(all_params.keys())}"
     )
 
-    result = client.execute_function(function_name=function_name, parameters=all_params)
+    result: FunctionExecutionResult = client.execute_function(
+        function_name=function_name, parameters=all_params
+    )
 
     # Handle errors and extract result
-    if hasattr(result, "error") and result.error:
+    if result.error:
         logger.error(f"Unity Catalog function error: {result.error}")
         raise RuntimeError(f"Function execution failed: {result.error}")
 
-    result_value: str = result.value if hasattr(result, "value") else str(result)
+    result_value: str = result.value if result.value is not None else str(result)
     logger.debug(f"UC function result: {result_value}")
     return result_value
 
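The typed `FunctionExecutionResult` replaces `hasattr` probing with an explicit contract. A hedged sketch of that contract in isolation (the `error`/`value` field names come from the diff; constructing the result directly is an assumption for illustration only, since in practice it comes from `client.execute_function(...)`):

```python
from unitycatalog.ai.core.base import FunctionExecutionResult


def extract_value(result: FunctionExecutionResult) -> str:
    # Same contract the diff adopts: error wins, then value, then repr fallback.
    if result.error:
        raise RuntimeError(f"Function execution failed: {result.error}")
    return result.value if result.value is not None else str(result)


# Hypothetical direct construction, for illustration only.
print(extract_value(FunctionExecutionResult(format="SCALAR", value="42")))
```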
dao_ai/tools/vector_search.py CHANGED
@@ -1,5 +1,12 @@
+"""
+Vector search tool for retrieving documents from Databricks Vector Search.
+
+This module provides a tool factory for creating semantic search tools
+using ToolRuntime[Context, AgentState] for type-safe runtime access.
+"""
+
 import os
-from typing import Annotated, Any, Callable, List, Optional, Sequence
+from typing import Any, Callable, List, Optional, Sequence
 
 import mlflow
 from databricks.vector_search.reranker import DatabricksReranker
@@ -9,9 +16,8 @@ from databricks_ai_bridge.vector_search_retriever_tool (
 )
 from databricks_langchain.vectorstores import DatabricksVectorSearch
 from flashrank import Ranker, RerankRequest
+from langchain.tools import ToolRuntime, tool
 from langchain_core.documents import Document
-from langchain_core.tools import InjectedToolCallId, tool
-from langgraph.prebuilt import InjectedState
 from loguru import logger
 from mlflow.entities import SpanType
 
@@ -20,6 +26,7 @@ from dao_ai.config (
     RetrieverModel,
     VectorStoreModel,
 )
+from dao_ai.state import AgentState, Context
 from dao_ai.utils import normalize_host
 
 
@@ -27,7 +34,7 @@ def create_vector_search_tool(
     retriever: RetrieverModel | dict[str, Any],
     name: Optional[str] = None,
     description: Optional[str] = None,
-) -> Callable:
+) -> Callable[..., list[dict[str, Any]]]:
     """
     Create a Vector Search tool for retrieving documents from a Databricks Vector Search index.
 
@@ -254,8 +261,7 @@ def create_vector_search_tool(
         return reranked_docs
 
     # Create the main vector search tool using @tool decorator
-    # Note: args_schema provides descriptions for query and filters,
-    # so Annotated is only needed for injected LangGraph parameters
+    # Uses ToolRuntime[Context, AgentState] for type-safe runtime access
    @tool(
        name_or_callable=name or index_name,
        description=description or "Search for documents using vector similarity",
@@ -264,8 +270,7 @@ def create_vector_search_tool(
    def vector_search_tool(
        query: str,
        filters: Optional[List[FilterItem]] = None,
-        state: Annotated[dict, InjectedState] = None,
-        tool_call_id: Annotated[str, InjectedToolCallId] = None,
+        runtime: ToolRuntime[Context, AgentState] = None,
    ) -> list[dict[str, Any]]:
        """
        Search for documents using vector similarity with optional reranking.
@@ -276,8 +281,10 @@ def create_vector_search_tool(
 
        Both stages are traced in MLflow for observability.
 
+        Uses ToolRuntime[Context, AgentState] for type-safe runtime access.
+
        Returns:
-            Command with ToolMessage containing the retrieved documents
+            List of serialized documents with page_content and metadata
        """
        logger.debug(
            f"Vector search tool called: query='{query[:50]}...', reranking={reranker_config is not None}"
dao_ai/utils.py CHANGED
@@ -7,6 +7,7 @@ from importlib.metadata import PackageNotFoundError, version
 from pathlib import Path
 from typing import Any, Callable, Sequence
 
+from langchain_core.tools import BaseTool
 from loguru import logger
 
 import dao_ai
@@ -19,7 +20,7 @@ def is_lib_provided(lib_name: str, pip_requirements: Sequence[str]) -> bool:
     )
 
 
-def is_installed():
+def is_installed() -> bool:
     current_file = os.path.abspath(dao_ai.__file__)
     site_packages = [os.path.abspath(path) for path in site.getsitepackages()]
     if site.getusersitepackages():
@@ -157,9 +158,6 @@ def get_installed_packages() -> dict[str, str]:
         f"langchain-tavily=={version('langchain-tavily')}",
         f"langgraph=={version('langgraph')}",
         f"langgraph-checkpoint-postgres=={version('langgraph-checkpoint-postgres')}",
-        f"langgraph-prebuilt=={version('langgraph-prebuilt')}",
-        f"langgraph-supervisor=={version('langgraph-supervisor')}",
-        f"langgraph-swarm=={version('langgraph-swarm')}",
         f"langmem=={version('langmem')}",
         f"loguru=={version('loguru')}",
         f"mcp=={version('mcp')}",
@@ -212,13 +210,13 @@ def load_function(function_name: str) -> Callable[..., Any]:
         module = importlib.import_module(module_path)
 
         # Get the function from the module
-        func = getattr(module, func_name)
+        func: Any = getattr(module, func_name)
 
-        # Verify that the resolved object is callable or is a langchain tool
+        # Verify that the resolved object is callable or is a LangChain tool
         # In langchain 1.x, StructuredTool objects are not directly callable
         # but have an invoke() method
-        is_callable = callable(func)
-        is_langchain_tool = hasattr(func, "invoke") and hasattr(func, "name")
+        is_callable: bool = callable(func)
+        is_langchain_tool: bool = isinstance(func, BaseTool)
 
         if not is_callable and not is_langchain_tool:
             raise TypeError(f"Function {func_name} is not callable or invocable.")
@@ -229,6 +227,72 @@ def load_function(function_name: str) -> Callable[..., Any]:
         raise ImportError(f"Failed to import {function_name}: {e}")
 
 
+def type_from_fqn(type_name: str) -> type:
+    """
+    Load a type from a fully qualified name (FQN).
+
+    Dynamically imports and returns a type (class) from a module using its
+    fully qualified name. Useful for loading Pydantic models, dataclasses,
+    or any Python type specified as a string in configuration files.
+
+    Args:
+        type_name: Fully qualified type name in format "module.path.ClassName"
+
+    Returns:
+        The imported type/class
+
+    Raises:
+        ValueError: If the FQN format is invalid
+        ImportError: If the module cannot be imported
+        AttributeError: If the type doesn't exist in the module
+        TypeError: If the resolved object is not a type
+
+    Example:
+        >>> ProductModel = type_from_fqn("my_models.ProductInfo")
+        >>> instance = ProductModel(name="Widget", price=9.99)
+    """
+    logger.debug(f"Loading type: {type_name}")
+
+    try:
+        # Split the FQN into module path and class name
+        parts = type_name.rsplit(".", 1)
+        if len(parts) != 2:
+            raise ValueError(
+                f"Invalid type name '{type_name}'. "
+                "Expected format: 'module.path.ClassName'"
+            )
+
+        module_path, class_name = parts
+
+        # Dynamically import the module
+        try:
+            module = importlib.import_module(module_path)
+        except ModuleNotFoundError as e:
+            raise ImportError(
+                f"Could not import module '{module_path}' for type '{type_name}': {e}"
+            ) from e
+
+        # Get the class from the module
+        if not hasattr(module, class_name):
+            raise AttributeError(
+                f"Module '{module_path}' does not have attribute '{class_name}'"
+            )
+
+        resolved_type = getattr(module, class_name)
+
+        # Verify it's actually a type
+        if not isinstance(resolved_type, type):
+            raise TypeError(
+                f"'{type_name}' resolved to {resolved_type}, which is not a type"
+            )
+
+        return resolved_type
+
+    except (ValueError, ImportError, AttributeError, TypeError) as e:
+        # Provide a detailed error message that includes the original exception
+        raise type(e)(f"Failed to load type '{type_name}': {e}") from e
+
+
 def is_in_model_serving() -> bool:
     """Check if running in Databricks Model Serving environment.
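Since `type_from_fqn` is new in this release, here is a short usage sketch against a stdlib type (runnable as-is; the config-file framing is just the motivating use case from the docstring):

```python
from dao_ai.utils import type_from_fqn

# Resolve a type from a string, as a YAML config entry might specify it.
OrderedDict = type_from_fqn("collections.OrderedDict")
assert isinstance(OrderedDict, type)

# Failure modes keep their exception type but gain context:
try:
    type_from_fqn("collections")  # no '.' separator -> ValueError
except ValueError as e:
    print(e)  # Failed to load type 'collections': Invalid type name ...

try:
    type_from_fqn("collections.NoSuchClass")  # missing attribute -> AttributeError
except AttributeError as e:
    print(e)
```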