dao-ai 0.1.12-py3-none-any.whl → 0.1.14-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/apps/handlers.py CHANGED
@@ -14,7 +14,7 @@ from typing import AsyncGenerator
 
 import mlflow
 from dotenv import load_dotenv
-from mlflow.genai.agent_server import invoke, stream
+from mlflow.genai.agent_server import get_request_headers, invoke, stream
 from mlflow.types.responses import (
     ResponsesAgentRequest,
     ResponsesAgentResponse,
@@ -25,6 +25,23 @@ from dao_ai.config import AppConfig
 from dao_ai.logging import configure_logging
 from dao_ai.models import LanggraphResponsesAgent
 
+
+def _inject_headers_into_request(request: ResponsesAgentRequest) -> None:
+    """Inject request headers into custom_inputs for Context propagation.
+
+    Captures headers from the MLflow AgentServer context (where they're available)
+    and injects them into request.custom_inputs.configurable.headers so they
+    flow through to Context and can be used for OBO authentication.
+    """
+    headers: dict[str, str] = get_request_headers()
+    if headers:
+        if request.custom_inputs is None:
+            request.custom_inputs = {}
+        if "configurable" not in request.custom_inputs:
+            request.custom_inputs["configurable"] = {}
+        request.custom_inputs["configurable"]["headers"] = headers
+
+
 # Load environment variables from .env.local if it exists
 load_dotenv(dotenv_path=".env.local", override=True)
 
@@ -61,6 +78,8 @@ async def non_streaming(request: ResponsesAgentRequest) -> ResponsesAgentRespons
     Returns:
         ResponsesAgentResponse with the complete output
     """
+    # Capture headers while in the AgentServer async context (before they're lost)
+    _inject_headers_into_request(request)
     return await _responses_agent.apredict(request)
 
 
@@ -80,5 +99,7 @@ async def streaming(
     Yields:
         ResponsesAgentStreamEvent objects as they are generated
     """
+    # Capture headers while in the AgentServer async context (before they're lost)
+    _inject_headers_into_request(request)
     async for event in _responses_agent.apredict_stream(request):
         yield event
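As a note on the change above: a minimal, self-contained sketch of the same injection logic applied to a plain object. `SimpleNamespace` stands in for `ResponsesAgentRequest`; the header names follow the Databricks Apps forwarding convention and the values are fake.

```python
# Illustrative only: mirrors _inject_headers_into_request on a stand-in object.
from types import SimpleNamespace

request = SimpleNamespace(custom_inputs=None)  # stand-in for ResponsesAgentRequest
headers = {
    "x-forwarded-access-token": "<per-user token>",   # fake value
    "x-forwarded-user": "jane.doe@example.com",       # fake value
}

if headers:
    if request.custom_inputs is None:
        request.custom_inputs = {}
    # Nest headers under custom_inputs["configurable"]["headers"]
    request.custom_inputs.setdefault("configurable", {})["headers"] = headers

assert request.custom_inputs["configurable"]["headers"]["x-forwarded-user"] == (
    "jane.doe@example.com"
)
```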
dao_ai/config.py CHANGED
@@ -7,6 +7,7 @@ from enum import Enum
 from os import PathLike
 from pathlib import Path
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Iterator,
@@ -18,6 +19,9 @@ from typing import (
     Union,
 )
 
+if TYPE_CHECKING:
+    from dao_ai.state import Context
+
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.credentials_provider import (
     CredentialsStrategy,
@@ -284,8 +288,8 @@ class IsDatabricksResource(ABC, BaseModel):
 
         Authentication priority:
         1. On-Behalf-Of User (on_behalf_of_user=True):
-           - Forwarded headers (Databricks Apps)
-           - ModelServingUserCredentials (Model Serving)
+           - Uses ModelServingUserCredentials (Model Serving)
+           - For Databricks Apps with headers, use workspace_client_from(context)
         2. Service Principal (client_id + client_secret + workspace_host)
         3. PAT (pat + workspace_host)
         4. Ambient/default authentication
@@ -294,36 +298,6 @@
 
         # Check for OBO first (highest priority)
         if self.on_behalf_of_user:
-            # NEW: In Databricks Apps, use forwarded headers for per-user auth
-            try:
-                from mlflow.genai.agent_server import get_request_headers
-
-                headers = get_request_headers()
-                forwarded_token = headers.get("x-forwarded-access-token")
-
-                if forwarded_token:
-                    forwarded_user = headers.get("x-forwarded-user", "unknown")
-                    logger.debug(
-                        f"Creating WorkspaceClient for {self.__class__.__name__} "
-                        f"with OBO using forwarded token from Databricks Apps",
-                        forwarded_user=forwarded_user,
-                    )
-                    # Use workspace_host if configured, otherwise SDK will auto-detect
-                    workspace_host_value: str | None = (
-                        normalize_host(value_of(self.workspace_host))
-                        if self.workspace_host
-                        else None
-                    )
-                    return WorkspaceClient(
-                        host=workspace_host_value,
-                        token=forwarded_token,
-                        auth_type="pat",
-                    )
-            except (ImportError, LookupError):
-                # mlflow not available or headers not set - fall through to Model Serving
-                pass
-
-            # Fall back to Model Serving OBO (existing behavior)
             credentials_strategy: CredentialsStrategy = ModelServingUserCredentials()
             logger.debug(
                 f"Creating WorkspaceClient for {self.__class__.__name__} "
@@ -382,6 +356,55 @@
             )
         return WorkspaceClient()
 
+    def workspace_client_from(self, context: "Context | None") -> WorkspaceClient:
+        """
+        Get a WorkspaceClient using headers from the provided Context.
+
+        Use this method from tools that have access to ToolRuntime[Context].
+        This allows OBO authentication to work in Databricks Apps where headers
+        are captured at request entry and passed through the Context.
+
+        Args:
+            context: Runtime context containing headers for OBO auth.
+                If None or no headers, falls back to workspace_client property.
+
+        Returns:
+            WorkspaceClient configured with appropriate authentication.
+        """
+        from dao_ai.utils import normalize_host
+
+        # Check if we have headers in context for OBO
+        if context and context.headers and self.on_behalf_of_user:
+            headers = context.headers
+            # Try both lowercase and title-case header names (HTTP headers are case-insensitive)
+            forwarded_token = headers.get("x-forwarded-access-token") or headers.get(
+                "X-Forwarded-Access-Token"
+            )
+
+            if forwarded_token:
+                forwarded_user = headers.get("x-forwarded-user") or headers.get(
+                    "X-Forwarded-User", "unknown"
+                )
+                logger.debug(
+                    f"Creating WorkspaceClient for {self.__class__.__name__} "
+                    f"with OBO using forwarded token from Context",
+                    forwarded_user=forwarded_user,
+                )
+                # Use workspace_host if configured, otherwise SDK will auto-detect
+                workspace_host_value: str | None = (
+                    normalize_host(value_of(self.workspace_host))
+                    if self.workspace_host
+                    else None
+                )
+                return WorkspaceClient(
+                    host=workspace_host_value,
+                    token=forwarded_token,
+                    auth_type="pat",
+                )
+
+        # Fall back to existing workspace_client property
+        return self.workspace_client
+
 
 class DeploymentTarget(str, Enum):
     """Target platform for agent deployment."""
dao_ai/tools/genie.py CHANGED
@@ -139,29 +139,53 @@ Returns:
     GenieResponse: A response object containing the conversation ID and result from Genie."""
     tool_description = tool_description + function_docs
 
-    genie: Genie = Genie(
-        space_id=space_id,
-        client=genie_room.workspace_client,
-        truncate_results=truncate_results,
-    )
+    # Cache for genie service - created lazily on first call
+    # This allows us to use workspace_client_from with runtime context for OBO
+    _cached_genie_service: GenieServiceBase | None = None
+
+    def _get_genie_service(context: Context | None) -> GenieServiceBase:
+        """Get or create the Genie service, using context for OBO auth if available."""
+        nonlocal _cached_genie_service
+
+        # Use cached service if available (for non-OBO or after first call)
+        # For OBO, we need fresh workspace client each time to use the user's token
+        if _cached_genie_service is not None and not genie_room.on_behalf_of_user:
+            return _cached_genie_service
+
+        # Get workspace client using context for OBO support
+        from databricks.sdk import WorkspaceClient
 
-    genie_service: GenieServiceBase = GenieService(genie)
-
-    # Wrap with semantic cache first (checked second due to decorator pattern)
-    if semantic_cache_parameters is not None:
-        genie_service = SemanticCacheService(
-            impl=genie_service,
-            parameters=semantic_cache_parameters,
-            workspace_client=genie_room.workspace_client,  # Pass workspace client for conversation history
-        ).initialize()  # Eagerly initialize to fail fast and create table
-
-    # Wrap with LRU cache last (checked first - fast O(1) exact match)
-    if lru_cache_parameters is not None:
-        genie_service = LRUCacheService(
-            impl=genie_service,
-            parameters=lru_cache_parameters,
+        workspace_client: WorkspaceClient = genie_room.workspace_client_from(context)
+
+        genie: Genie = Genie(
+            space_id=space_id,
+            client=workspace_client,
+            truncate_results=truncate_results,
         )
 
+        genie_service: GenieServiceBase = GenieService(genie)
+
+        # Wrap with semantic cache first (checked second due to decorator pattern)
+        if semantic_cache_parameters is not None:
+            genie_service = SemanticCacheService(
+                impl=genie_service,
+                parameters=semantic_cache_parameters,
+                workspace_client=workspace_client,
+            ).initialize()
+
+        # Wrap with LRU cache last (checked first - fast O(1) exact match)
+        if lru_cache_parameters is not None:
+            genie_service = LRUCacheService(
+                impl=genie_service,
+                parameters=lru_cache_parameters,
+            )
+
+        # Cache for non-OBO scenarios
+        if not genie_room.on_behalf_of_user:
+            _cached_genie_service = genie_service
+
+        return genie_service
+
     @tool(
         name_or_callable=tool_name,
         description=tool_description,
@@ -177,6 +201,10 @@ GenieResponse: A response object containing the conversation ID and result from
         # Access state through runtime
         state: AgentState = runtime.state
         tool_call_id: str = runtime.tool_call_id
+        context: Context | None = runtime.context
+
+        # Get genie service with OBO support via context
+        genie_service: GenieServiceBase = _get_genie_service(context)
 
         # Ensure space_id is a string for state keys
         space_id_str: str = str(space_id)
@@ -194,6 +222,14 @@ GenieResponse: A response object containing the conversation ID and result from
                 conversation_id=existing_conversation_id,
             )
 
+        # Log the prompt being sent to Genie
+        logger.trace(
+            "Sending prompt to Genie",
+            space_id=space_id_str,
+            conversation_id=existing_conversation_id,
+            prompt=question[:500] + "..." if len(question) > 500 else question,
+        )
+
         # Call ask_question which always returns CacheResult with cache metadata
         cache_result: CacheResult = genie_service.ask_question(
             question, conversation_id=existing_conversation_id
@@ -211,6 +247,22 @@ GenieResponse: A response object containing the conversation ID and result from
             cache_key=cache_key,
         )
 
+        # Log truncated response for debugging
+        result_preview: str = str(genie_response.result)
+        if len(result_preview) > 500:
+            result_preview = result_preview[:500] + "..."
+        logger.trace(
+            "Genie response content",
+            question=question[:100] + "..." if len(question) > 100 else question,
+            query=genie_response.query,
+            description=(
+                genie_response.description[:200] + "..."
+                if genie_response.description and len(genie_response.description) > 200
+                else genie_response.description
+            ),
+            result_preview=result_preview,
+        )
+
         # Update session state with cache information
         if persist_conversation:
             session.genie.update_space(
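The lazy construction above is an instance of a general closure-cache pattern: reuse the expensive service when credentials are shared, rebuild it per call when each caller brings their own token. A generic, package-independent sketch:

```python
from typing import Any, Callable, Optional


def make_service_getter(
    build: Callable[[Any], Any], per_user: bool
) -> Callable[[Any], Any]:
    """Return a getter that caches build(...) only when auth is shared."""
    cached: Optional[Any] = None

    def get_service(context: Any) -> Any:
        nonlocal cached
        if cached is not None and not per_user:
            return cached  # shared credentials: reuse is safe
        service = build(context)  # per-user auth: fresh client every call
        if not per_user:
            cached = service
        return service

    return get_service
```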
dao_ai/tools/mcp.py CHANGED
@@ -30,6 +30,7 @@ from dao_ai.config import (
     McpFunctionModel,
     TransportType,
 )
+from dao_ai.state import Context
 
 
 @dataclass
@@ -173,6 +174,7 @@ def _get_auth_resource(function: McpFunctionModel) -> IsDatabricksResource:
 
 def _build_connection_config(
     function: McpFunctionModel,
+    context: Context | None = None,
 ) -> dict[str, Any]:
     """
     Build the connection configuration dictionary for MultiServerMCPClient.
@@ -193,6 +195,7 @@
 
     Args:
         function: The MCP function model configuration.
+        context: Optional runtime context with headers for OBO auth.
 
     Returns:
         A dictionary containing the transport-specific connection settings.
@@ -205,14 +208,17 @@
         }
 
     # For HTTP transport, use DatabricksOAuthClientProvider with unified auth
+    from databricks.sdk import WorkspaceClient
     from databricks_mcp import DatabricksOAuthClientProvider
 
     # Get the resource to use for authentication
-    auth_resource = _get_auth_resource(function)
+    auth_resource: IsDatabricksResource = _get_auth_resource(function)
 
-    # Get workspace client from the auth resource
-    workspace_client = auth_resource.workspace_client
-    auth_provider = DatabricksOAuthClientProvider(workspace_client)
+    # Get workspace client from the auth resource with OBO support via context
+    workspace_client: WorkspaceClient = auth_resource.workspace_client_from(context)
+    auth_provider: DatabricksOAuthClientProvider = DatabricksOAuthClientProvider(
+        workspace_client
+    )
 
     # Log which resource is providing auth
     resource_name = (
@@ -509,19 +515,28 @@ async def acreate_mcp_tools(
     def _create_tool_wrapper(mcp_tool: Tool) -> RunnableLike:
         """
         Create a LangChain tool wrapper for an MCP tool.
+
+        Supports OBO authentication via context headers.
         """
+        from langchain.tools import ToolRuntime
 
         @create_tool(
             mcp_tool.name,
             description=mcp_tool.description or f"MCP tool: {mcp_tool.name}",
             args_schema=mcp_tool.inputSchema,
         )
-        async def tool_wrapper(**kwargs: Any) -> str:
+        async def tool_wrapper(
+            runtime: ToolRuntime[Context] = None,
+            **kwargs: Any,
+        ) -> str:
             """Execute MCP tool with fresh session."""
             logger.trace("Invoking MCP tool", tool_name=mcp_tool.name, args=kwargs)
 
-            invocation_client = MultiServerMCPClient(
-                {"mcp_function": _build_connection_config(function)}
+            # Get context for OBO support
+            context: Context | None = runtime.context if runtime else None
+
+            invocation_client: MultiServerMCPClient = MultiServerMCPClient(
+                {"mcp_function": _build_connection_config(function, context)}
             )
 
             try:
@@ -530,7 +545,7 @@
                     mcp_tool.name, kwargs
                 )
 
-                text_result = _extract_text_content(result)
+                text_result: str = _extract_text_content(result)
 
                 logger.trace(
                     "MCP tool completed",
@@ -625,20 +640,28 @@ def create_mcp_tools(
         This wrapper handles:
         - Fresh session creation per invocation (stateless)
         - Content extraction to plain text (avoiding extra fields)
+        - OBO authentication via context headers
         """
+        from langchain.tools import ToolRuntime
 
         @create_tool(
             mcp_tool.name,
             description=mcp_tool.description or f"MCP tool: {mcp_tool.name}",
             args_schema=mcp_tool.inputSchema,
         )
-        async def tool_wrapper(**kwargs: Any) -> str:
+        async def tool_wrapper(
+            runtime: ToolRuntime[Context] = None,
+            **kwargs: Any,
+        ) -> str:
             """Execute MCP tool with fresh session."""
             logger.trace("Invoking MCP tool", tool_name=mcp_tool.name, args=kwargs)
 
-            # Create a fresh client/session for each invocation
-            invocation_client = MultiServerMCPClient(
-                {"mcp_function": _build_connection_config(function)}
+            # Get context for OBO support
+            context: Context | None = runtime.context if runtime else None
+
+            # Create a fresh client/session for each invocation with OBO support
+            invocation_client: MultiServerMCPClient = MultiServerMCPClient(
+                {"mcp_function": _build_connection_config(function, context)}
             )
 
             try:
@@ -648,7 +671,7 @@
                 )
 
                 # Extract text content, avoiding extra fields
-                text_result = _extract_text_content(result)
+                text_result: str = _extract_text_content(result)
 
                 logger.trace(
                     "MCP tool completed",
dao_ai/tools/slack.py CHANGED
@@ -1,11 +1,13 @@
 from typing import Any, Callable, Optional
 
 from databricks.sdk.service.serving import ExternalFunctionRequestHttpMethod
+from langchain.tools import ToolRuntime
 from langchain_core.tools import tool
 from loguru import logger
 from requests import Response
 
 from dao_ai.config import ConnectionModel
+from dao_ai.state import Context
 
 
 def _find_channel_id_by_name(
@@ -129,8 +131,17 @@ def create_send_slack_message_tool(
         name_or_callable=name,
         description=description,
     )
-    def send_slack_message(text: str) -> str:
-        response: Response = connection.workspace_client.serving_endpoints.http_request(
+    def send_slack_message(
+        text: str,
+        runtime: ToolRuntime[Context] = None,
+    ) -> str:
+        from databricks.sdk import WorkspaceClient
+
+        # Get workspace client with OBO support via context
+        context: Context | None = runtime.context if runtime else None
+        workspace_client: WorkspaceClient = connection.workspace_client_from(context)
+
+        response: Response = workspace_client.serving_endpoints.http_request(
             conn=connection.name,
             method=ExternalFunctionRequestHttpMethod.POST,
             path="/api/chat.postMessage",
dao_ai/tools/sql.py CHANGED
@@ -7,10 +7,11 @@ pre-configured SQL statements against a Databricks SQL warehouse.
 
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.service.sql import StatementResponse, StatementState
-from langchain.tools import tool
+from langchain.tools import ToolRuntime, tool
 from loguru import logger
 
 from dao_ai.config import WarehouseModel, value_of
+from dao_ai.state import Context
 
 
 def create_execute_statement_tool(
@@ -63,7 +64,6 @@
     description = f"Execute a pre-configured SQL query against the {warehouse.name} warehouse and return the results."
 
     warehouse_id: str = value_of(warehouse.warehouse_id)
-    workspace_client: WorkspaceClient = warehouse.workspace_client
 
     logger.debug(
         "Creating SQL execution tool",
@@ -74,7 +74,7 @@
     )
 
     @tool(name_or_callable=name, description=description)
-    def execute_statement_tool() -> str:
+    def execute_statement_tool(runtime: ToolRuntime[Context] = None) -> str:
         """
         Execute the pre-configured SQL statement against the Databricks SQL warehouse.
 
@@ -88,6 +88,10 @@
             sql_preview=statement[:100] + "..." if len(statement) > 100 else statement,
         )
 
+        # Get workspace client with OBO support via context
+        context: Context | None = runtime.context if runtime else None
+        workspace_client: WorkspaceClient = warehouse.workspace_client_from(context)
+
         try:
             # Execute the SQL statement
             statement_response: StatementResponse = (
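The truncated call above continues into the Databricks SDK's statement execution API; a hedged sketch of the typical shape (the statement and warehouse ID are placeholders, and this mirrors rather than reproduces the package's code):

```python
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.sql import StatementResponse, StatementState

w = WorkspaceClient()  # in the tool: warehouse.workspace_client_from(context)
resp: StatementResponse = w.statement_execution.execute_statement(
    statement="SELECT 1",           # placeholder statement
    warehouse_id="<warehouse id>",  # placeholder warehouse
)
if resp.status and resp.status.state == StatementState.SUCCEEDED:
    rows = resp.result.data_array if resp.result else []
```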
dao_ai/tools/unity_catalog.py CHANGED
@@ -1,10 +1,11 @@
-from typing import Any, Dict, Optional, Sequence, Set
+from typing import Annotated, Any, Dict, Optional, Sequence, Set
 
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.service.catalog import FunctionInfo, PermissionsChange, Privilege
 from databricks_langchain import DatabricksFunctionClient, UCFunctionToolkit
+from langchain.tools import ToolRuntime
 from langchain_core.runnables.base import RunnableLike
-from langchain_core.tools import StructuredTool
+from langchain_core.tools import InjectedToolArg, StructuredTool
 from loguru import logger
 from pydantic import BaseModel
 from unitycatalog.ai.core.base import FunctionExecutionResult
@@ -15,6 +16,7 @@ from dao_ai.config import (
     UnityCatalogFunctionModel,
     value_of,
 )
+from dao_ai.state import Context
 from dao_ai.utils import normalize_host
 
 
@@ -35,13 +37,11 @@ def create_uc_tools(
         A sequence of BaseTool objects that wrap the specified UC functions
     """
     original_function_model: UnityCatalogFunctionModel | None = None
-    workspace_client: WorkspaceClient | None = None
     function_name: str
 
     if isinstance(function, UnityCatalogFunctionModel):
         original_function_model = function
         function_name = function.resource.full_name
-        workspace_client = function.resource.workspace_client
     else:
         function_name = function
 
@@ -56,6 +56,12 @@
         # Use with_partial_args directly with UnityCatalogFunctionModel
         tools = [with_partial_args(original_function_model)]
     else:
+        # For standard UC toolkit, we need workspace_client at creation time
+        # Use the resource's workspace_client (will use ambient auth if no OBO)
+        workspace_client: WorkspaceClient | None = None
+        if original_function_model:
+            workspace_client = original_function_model.resource.workspace_client
+
         # Fallback to standard UC toolkit approach
         client: DatabricksFunctionClient = DatabricksFunctionClient(
             client=workspace_client
@@ -356,7 +362,6 @@ def with_partial_args(
     # Get function info from the resource
     function_name: str = uc_function.resource.full_name
     tool_name: str = uc_function.resource.name or function_name.replace(".", "_")
-    workspace_client: WorkspaceClient = uc_function.resource.workspace_client
 
     logger.debug(
         "Creating UC tool with partial args",
@@ -365,7 +370,7 @@
         partial_args=list(resolved_args.keys()),
    )
 
-    # Grant permissions if we have credentials
+    # Grant permissions if we have credentials (using ambient auth for setup)
     if "client_id" in resolved_args:
         client_id: str = resolved_args["client_id"]
         host: Optional[str] = resolved_args.get("host")
@@ -376,14 +381,18 @@
             "Failed to grant permissions", function_name=function_name, error=str(e)
         )
 
-    # Create the client for function execution using the resource's workspace client
-    client: DatabricksFunctionClient = DatabricksFunctionClient(client=workspace_client)
+    # Get workspace client for schema introspection (uses ambient auth at definition time)
+    # Actual execution will use OBO via context
+    setup_workspace_client: WorkspaceClient = uc_function.resource.workspace_client
+    setup_client: DatabricksFunctionClient = DatabricksFunctionClient(
+        client=setup_workspace_client
+    )
 
     # Try to get the function schema for better tool definition
     schema_model: type[BaseModel]
     tool_description: str
     try:
-        function_info: FunctionInfo = client.get_function(function_name)
+        function_info: FunctionInfo = setup_client.get_function(function_name)
         schema_info = generate_function_input_params_schema(function_info)
         tool_description = (
             function_info.comment or f"Unity Catalog function: {function_name}"
@@ -419,8 +428,21 @@
         tool_description = f"Unity Catalog function: {function_name}"
 
     # Create a wrapper function that calls _execute_uc_function with partial args
-    def uc_function_wrapper(**kwargs) -> str:
+    # Uses InjectedToolArg to ensure runtime is injected but hidden from the LLM
+    def uc_function_wrapper(
+        runtime: Annotated[ToolRuntime[Context], InjectedToolArg] = None,
+        **kwargs: Any,
+    ) -> str:
         """Wrapper function that executes Unity Catalog function with partial args."""
+        # Get workspace client with OBO support via context
+        context: Context | None = runtime.context if runtime else None
+        workspace_client: WorkspaceClient = uc_function.resource.workspace_client_from(
+            context
+        )
+        client: DatabricksFunctionClient = DatabricksFunctionClient(
+            client=workspace_client
+        )
+
         return _execute_uc_function(
             client=client,
             function_name=function_name,
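The `InjectedToolArg` annotation used above is LangChain's mechanism for keeping a parameter out of the schema the model sees while still accepting it at invocation time. A minimal standalone sketch (not from the package):

```python
from typing import Annotated

from langchain_core.tools import InjectedToolArg, tool


@tool
def lookup(query: str, user_token: Annotated[str, InjectedToolArg] = "") -> str:
    """Look up a record; user_token is supplied by the framework, not the model."""
    return f"query={query}, token_present={bool(user_token)}"


# The injected argument is excluded from the model-facing schema but can still
# be passed when invoking programmatically:
print(lookup.invoke({"query": "orders", "user_token": "secret"}))
```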
dao_ai/tools/vector_search.py CHANGED
@@ -7,13 +7,14 @@ with dynamic filter schemas based on table columns and FlashRank reranking suppo
 
 import json
 import os
-from typing import Any, Optional
+from typing import Annotated, Any, Optional
 
 import mlflow
 from databricks.sdk import WorkspaceClient
 from databricks.vector_search.reranker import DatabricksReranker
 from databricks_langchain import DatabricksVectorSearch
 from flashrank import Ranker, RerankRequest
+from langchain.tools import ToolRuntime, tool
 from langchain_core.documents import Document
 from langchain_core.tools import StructuredTool
 from loguru import logger
@@ -27,6 +28,7 @@ from dao_ai.config import (
     VectorStoreModel,
     value_of,
 )
+from dao_ai.state import Context
 from dao_ai.utils import normalize_host
 
 # Create FilterItem model at module level so it can be used in type hints
@@ -299,35 +301,67 @@ def create_vector_search_tool(
         client_args_keys=list(client_args.keys()) if client_args else [],
     )
 
-    # Create DatabricksVectorSearch
-    # Note: text_column should be None for Databricks-managed embeddings
-    # (it's automatically determined from the index)
-    vector_search: DatabricksVectorSearch = DatabricksVectorSearch(
-        index_name=index_name,
-        text_column=None,
-        columns=columns,
-        workspace_client=vector_store.workspace_client,
-        client_args=client_args if client_args else None,
-        primary_key=vector_store.primary_key,
-        doc_uri=vector_store.doc_uri,
-        include_score=True,
-        reranker=(
-            DatabricksReranker(columns_to_rerank=rerank_config.columns)
-            if rerank_config and rerank_config.columns
-            else None
-        ),
-    )
+    # Cache for DatabricksVectorSearch - created lazily for OBO support
+    _cached_vector_search: DatabricksVectorSearch | None = None
+
+    def _get_vector_search(context: Context | None) -> DatabricksVectorSearch:
+        """Get or create DatabricksVectorSearch, using context for OBO auth if available."""
+        nonlocal _cached_vector_search
+
+        # Use cached instance if available and not OBO
+        if _cached_vector_search is not None and not vector_store.on_behalf_of_user:
+            return _cached_vector_search
+
+        # Get workspace client with OBO support via context
+        workspace_client: WorkspaceClient = vector_store.workspace_client_from(context)
+
+        # Create DatabricksVectorSearch
+        # Note: text_column should be None for Databricks-managed embeddings
+        # (it's automatically determined from the index)
+        vs: DatabricksVectorSearch = DatabricksVectorSearch(
+            index_name=index_name,
+            text_column=None,
+            columns=columns,
+            workspace_client=workspace_client,
+            client_args=client_args if client_args else None,
+            primary_key=vector_store.primary_key,
+            doc_uri=vector_store.doc_uri,
+            include_score=True,
+            reranker=(
+                DatabricksReranker(columns_to_rerank=rerank_config.columns)
+                if rerank_config and rerank_config.columns
+                else None
+            ),
+        )
 
-    # Create dynamic input schema
-    input_schema: type[BaseModel] = _create_dynamic_input_schema(
-        index_name, vector_store.workspace_client
-    )
+        # Cache for non-OBO scenarios
+        if not vector_store.on_behalf_of_user:
+            _cached_vector_search = vs
+
+        return vs
+
+    # Determine tool name and description
+    tool_name: str = name or f"vector_search_{vector_store.index.name}"
+    tool_description: str = description or f"Search documents in {index_name}"
 
-    # Define the tool function
+    # Use @tool decorator for proper ToolRuntime injection
+    # The decorator ensures runtime is automatically injected and hidden from the LLM
+    @tool(name_or_callable=tool_name, description=tool_description)
     def vector_search_func(
-        query: str, filters: Optional[list[FilterItem]] = None
+        query: Annotated[str, "The search query to find relevant documents"],
+        filters: Annotated[
+            Optional[list[FilterItem]],
+            "Optional filters to apply to the search results",
+        ] = None,
+        runtime: ToolRuntime[Context] = None,
     ) -> str:
         """Search for relevant documents using vector similarity."""
+        # Get context for OBO support
+        context: Context | None = runtime.context if runtime else None
+
+        # Get vector search instance with OBO support
+        vector_search: DatabricksVectorSearch = _get_vector_search(context)
+
         # Convert FilterItem Pydantic models to dict format for DatabricksVectorSearch
         filters_dict: dict[str, Any] = {}
         if filters:
@@ -379,14 +413,6 @@ def create_vector_search_tool(
         # Return as JSON string
         return json.dumps(serialized_docs)
 
-    # Create the StructuredTool
-    tool: StructuredTool = StructuredTool.from_function(
-        func=vector_search_func,
-        name=name or f"vector_search_{vector_store.index.name}",
-        description=description or f"Search documents in {index_name}",
-        args_schema=input_schema,
-    )
-
-    logger.success("Vector search tool created", name=tool.name, index=index_name)
+    logger.success("Vector search tool created", name=tool_name, index=index_name)
 
-    return tool
+    return vector_search_func
dao_ai-0.1.14.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dao-ai
-Version: 0.1.12
+Version: 0.1.14
 Summary: DAO AI: A modular, multi-agent orchestration framework for complex AI workflows. Supports agent handoff, tool integration, and dynamic configuration via YAML.
 Project-URL: Homepage, https://github.com/natefleming/dao-ai
 Project-URL: Documentation, https://natefleming.github.io/dao-ai
@@ -125,7 +125,7 @@ DAO AI Builder generates valid YAML configurations that work seamlessly with thi
 - **[Architecture](docs/architecture.md)** - Understand how DAO works under the hood
 
 ### Core Concepts
-- **[Key Capabilities](docs/key-capabilities.md)** - Explore 14 powerful features for production agents
+- **[Key Capabilities](docs/key-capabilities.md)** - Explore 15 powerful features for production agents
 - **[Configuration Reference](docs/configuration-reference.md)** - Complete YAML configuration guide
 - **[Examples](docs/examples.md)** - Ready-to-use example configurations
 
@@ -148,7 +148,7 @@ Before you begin, you'll need:
 - **Python 3.11 or newer** installed on your computer ([download here](https://www.python.org/downloads/))
 - **A Databricks workspace** (ask your IT team or see [Databricks docs](https://docs.databricks.com/))
 - Access to **Unity Catalog** (your organization's data catalog)
-- **Model Serving** enabled (for deploying AI agents)
+- **Model Serving** or **Databricks Apps** enabled (for deploying AI agents)
 - *Optional*: Vector Search, Genie (for advanced features)
 
 **Not sure if you have access?** Your Databricks administrator can grant you permissions.
@@ -345,6 +345,7 @@ DAO provides powerful capabilities for building production-ready AI agents:
 
 | Feature | Description |
 |---------|-------------|
+| **Dual Deployment Targets** | Deploy to Databricks Model Serving or Databricks Apps with a single config |
 | **Multi-Tool Support** | Python functions, Unity Catalog, MCP, Agent Endpoints |
 | **On-Behalf-Of User** | Per-user permissions and governance |
 | **Advanced Caching** | Two-tier (LRU + Semantic) caching for cost optimization |
dao_ai-0.1.14.dist-info/RECORD CHANGED
@@ -1,7 +1,7 @@
 dao_ai/__init__.py,sha256=18P98ExEgUaJ1Byw440Ct1ty59v6nxyWtc5S6Uq2m9Q,1062
 dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
 dao_ai/cli.py,sha256=1Ox8qjLKRlrKu2YXozm0lWoeZnDCouECeZSGVPkQgIQ,50923
-dao_ai/config.py,sha256=9G_JiPbr_ihUCaqYPvnMbzLKtyppXTjraQfVOxnqeBA,129323
+dao_ai/config.py,sha256=7MDuX7xGSyDuBpdFZbKNDUPuTiuVe9onnUEGFtDI0jc,130123
 dao_ai/graph.py,sha256=1-uQlo7iXZQTT3uU8aYu0N5rnhw5_g_2YLwVsAs6M-U,1119
 dao_ai/logging.py,sha256=lYy4BmucCHvwW7aI3YQkQXKJtMvtTnPDu9Hnd7_O4oc,1556
 dao_ai/messages.py,sha256=4ZBzO4iFdktGSLrmhHzFjzMIt2tpaL-aQLHOQJysGnY,6959
@@ -14,7 +14,7 @@ dao_ai/types.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/utils.py,sha256=_Urd7Nj2VzrgPKf3NS4E6vt0lWRhEUddBqWN9BksqeE,11543
 dao_ai/vector_search.py,sha256=8d3xROg9zSIYNXjRRl6rSexsJTlufjRl5Fy1ZA8daKA,4019
 dao_ai/apps/__init__.py,sha256=RLuhZf4gQ4pemwKDz1183aXib8UfaRhwfKvRx68GRlM,661
-dao_ai/apps/handlers.py,sha256=nbJZOgmnHG5icR4Pb56jxIWsm_AGnsURgViMJX2_LTU,2608
+dao_ai/apps/handlers.py,sha256=6-IhhklHSPnS8aqKp155wPaSnYWTU1BSOPwbdWYBkFU,3594
 dao_ai/apps/model_serving.py,sha256=XLt3_0pGSRceMK6YtOrND9Jnh7mKLPCtwjVDLIaptQU,847
 dao_ai/apps/resources.py,sha256=5l6UxfMq6uspOql-HNDyUikfqRAa9eH_TiJHrGgMb6s,40029
 dao_ai/apps/server.py,sha256=neWbVnC2z9f-tJZBnho70FytNDEVOdOM1YngoGc5KHI,1264
@@ -58,18 +58,18 @@ dao_ai/tools/__init__.py,sha256=NfRpAKds_taHbx6gzLPWgtPXve-YpwzkoOAUflwxceM,1734
 dao_ai/tools/agent.py,sha256=plIWALywRjaDSnot13nYehBsrHRpBUpsVZakoGeajOE,1858
 dao_ai/tools/core.py,sha256=bRIN3BZhRQX8-Kpu3HPomliodyskCqjxynQmYbk6Vjs,3783
 dao_ai/tools/email.py,sha256=A3TsCoQgJR7UUWR0g45OPRGDpVoYwctFs1MOZMTt_d4,7389
-dao_ai/tools/genie.py,sha256=4e_5MeAe7kDzHbYeXuNPFbY5z8ci3ouj8l5254CZ2lA,8874
-dao_ai/tools/mcp.py,sha256=tfn-sdKwfNY31RsDFlafdGyN4XlKGfniXG_mO-Meh4E,21030
+dao_ai/tools/genie.py,sha256=b0R51N5D58H1vpOCUCA88ALjLs58KSMn6nl80ap8_c0,11009
+dao_ai/tools/mcp.py,sha256=K1yMQ39UgJ0Q4xhMpNWV3AVNx929w9vxZlLoCq_jrws,22016
 dao_ai/tools/memory.py,sha256=lwObKimAand22Nq3Y63tsv-AXQ5SXUigN9PqRjoWKes,1836
 dao_ai/tools/python.py,sha256=jWFnZPni2sCdtd8D1CqXnZIPHnWkdK27bCJnBXpzhvo,1879
 dao_ai/tools/search.py,sha256=cJ3D9FKr1GAR6xz55dLtRkjtQsI0WRueGt9TPDFpOxc,433
-dao_ai/tools/slack.py,sha256=QpLMXDApjPKyRpEanLp0tOhCp9WXaEBa615p4t0pucs,5040
-dao_ai/tools/sql.py,sha256=tKd1gjpLuKdQDyfmyYYtMiNRHDW6MGRbdEVaeqyB8Ok,7632
+dao_ai/tools/slack.py,sha256=QnMsA7cYD1MnEcqGqqSr6bKIhV0RgDpkyaiPmDqnAts,5433
+dao_ai/tools/sql.py,sha256=FG-Aa0FAUAnhCuZvao1J-y-cMM6bU5eCujNbsYn0xDw,7864
 dao_ai/tools/time.py,sha256=tufJniwivq29y0LIffbgeBTIDE6VgrLpmVf8Qr90qjw,9224
-dao_ai/tools/unity_catalog.py,sha256=AjQfW7bvV8NurqDLIyntYRv2eJuTwNdbvex1L5CRjOk,15534
-dao_ai/tools/vector_search.py,sha256=oe2uBwl2TfeJIXPpwiS6Rmz7wcHczSxNyqS9P3hE6co,14542
-dao_ai-0.1.12.dist-info/METADATA,sha256=BhkwtDjbzohpk86ICfQP2qAeNLsvo9kBbgwzpnB_WZQ,16698
-dao_ai-0.1.12.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-dao_ai-0.1.12.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
-dao_ai-0.1.12.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
-dao_ai-0.1.12.dist-info/RECORD,,
+dao_ai/tools/unity_catalog.py,sha256=oBlW6pH-Ne08g60QW9wVi_tyeVYDiecuNoxQbIIFmN8,16515
+dao_ai/tools/vector_search.py,sha256=LF_72vlEF6TwUjKVv6nkUetLK766l9Kl6DQQTc9ebJI,15888
+dao_ai-0.1.14.dist-info/METADATA,sha256=3cgCatKya02uIxRs9fP-P2R_GbV3DfrQ-_JsknH0kkg,16830
+dao_ai-0.1.14.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+dao_ai-0.1.14.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
+dao_ai-0.1.14.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
+dao_ai-0.1.14.dist-info/RECORD,,