dao-ai 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/apps/handlers.py CHANGED
@@ -14,7 +14,7 @@ from typing import AsyncGenerator
 
 import mlflow
 from dotenv import load_dotenv
-from mlflow.genai.agent_server import invoke, stream
+from mlflow.genai.agent_server import get_request_headers, invoke, stream
 from mlflow.types.responses import (
     ResponsesAgentRequest,
     ResponsesAgentResponse,
@@ -25,6 +25,23 @@ from dao_ai.config import AppConfig
 from dao_ai.logging import configure_logging
 from dao_ai.models import LanggraphResponsesAgent
 
+
+def _inject_headers_into_request(request: ResponsesAgentRequest) -> None:
+    """Inject request headers into custom_inputs for Context propagation.
+
+    Captures headers from the MLflow AgentServer context (where they're available)
+    and injects them into request.custom_inputs.configurable.headers so they
+    flow through to Context and can be used for OBO authentication.
+    """
+    headers: dict[str, str] = get_request_headers()
+    if headers:
+        if request.custom_inputs is None:
+            request.custom_inputs = {}
+        if "configurable" not in request.custom_inputs:
+            request.custom_inputs["configurable"] = {}
+        request.custom_inputs["configurable"]["headers"] = headers
+
+
 # Load environment variables from .env.local if it exists
 load_dotenv(dotenv_path=".env.local", override=True)
 
@@ -61,6 +78,8 @@ async def non_streaming(request: ResponsesAgentRequest) -> ResponsesAgentResponse
     Returns:
         ResponsesAgentResponse with the complete output
     """
+    # Capture headers while in the AgentServer async context (before they're lost)
+    _inject_headers_into_request(request)
     return await _responses_agent.apredict(request)
 
 
@@ -80,5 +99,7 @@ async def streaming(
     Yields:
         ResponsesAgentStreamEvent objects as they are generated
     """
+    # Capture headers while in the AgentServer async context (before they're lost)
+    _inject_headers_into_request(request)
     async for event in _responses_agent.apredict_stream(request):
        yield event
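
Note: the handlers now capture the forwarded HTTP headers at the request boundary, because the AgentServer header context is not available later during tool execution, and stash them under request.custom_inputs["configurable"]["headers"]. Downstream they surface as context.headers on dao_ai.state.Context (used in dao_ai/config.py below). A minimal sketch of that hand-off, assuming Context is essentially a dataclass with a headers field; the real class lives in dao_ai.state and may carry more, and the helper here is hypothetical:

from dataclasses import dataclass, field


@dataclass
class Context:
    # Forwarded headers captured by _inject_headers_into_request above
    headers: dict[str, str] = field(default_factory=dict)


def context_from_custom_inputs(custom_inputs: dict | None) -> Context:
    # Hypothetical helper, not part of the package: shows how the
    # custom_inputs["configurable"]["headers"] payload becomes context.headers.
    configurable = (custom_inputs or {}).get("configurable", {})
    return Context(headers=configurable.get("headers", {}))
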
dao_ai/cli.py CHANGED
@@ -63,16 +63,26 @@ def detect_cloud_provider(profile: Optional[str] = None) -> Optional[str]:
         Cloud provider string ('azure', 'aws', 'gcp') or None if detection fails
     """
     try:
+        import os
         from databricks.sdk import WorkspaceClient
 
+        # Check for environment variables that might override profile
+        if profile and os.environ.get("DATABRICKS_HOST"):
+            logger.warning(
+                f"DATABRICKS_HOST environment variable is set, which may override --profile {profile}"
+            )
+
         # Create workspace client with optional profile
         if profile:
+            logger.debug(f"Creating WorkspaceClient with profile: {profile}")
             w = WorkspaceClient(profile=profile)
         else:
+            logger.debug("Creating WorkspaceClient with default/ambient credentials")
             w = WorkspaceClient()
 
         # Get the workspace URL from config
         host = w.config.host
+        logger.debug(f"WorkspaceClient host: {host}, profile used: {profile}")
         if not host:
             logger.warning("Could not determine workspace URL for cloud detection")
             return None
@@ -1143,7 +1153,7 @@ def run_databricks_command(
     app_config: AppConfig = AppConfig.from_file(config_path) if config_path else None
     normalized_name: str = normalize_name(app_config.app.name) if app_config else None
 
-    # Auto-detect cloud provider if not specified
+    # Auto-detect cloud provider if not specified (used for node_type selection)
     if not cloud:
         cloud = detect_cloud_provider(profile)
         if cloud:
@@ -1156,10 +1166,12 @@ def run_databricks_command(
     if config_path and app_config:
         generate_bundle_from_template(config_path, normalized_name)
 
-    # Use cloud as target (azure, aws, gcp) - can be overridden with explicit --target
+    # Use app-specific cloud target: {app_name}-{cloud}
+    # This ensures each app has unique deployment identity while supporting cloud-specific settings
+    # Can be overridden with explicit --target
     if not target:
-        target = cloud
-        logger.debug(f"Using cloud-based target: {target}")
+        target = f"{normalized_name}-{cloud}"
+        logger.info(f"Using app-specific cloud target: {target}")
 
     # Build databricks command
     # --profile is a global flag, --target is a subcommand flag for 'bundle'
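
Note: with this change the bundle target encodes both the app and the cloud, so two apps in the same workspace no longer collide on a shared "azure"/"aws"/"gcp" target. A worked example of the naming, assuming normalize_name lowercases and hyphenates the configured app name (its exact rules live in dao_ai.cli):

normalized_name = "retail-assistant"   # normalize_name(app_config.app.name), assumed output
cloud = "azure"                        # detect_cloud_provider(profile)

target = f"{normalized_name}-{cloud}"  # -> "retail-assistant-azure"
# 0.1.13 would have used target = "azure", shared by every app in the workspace.
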
dao_ai/config.py CHANGED
@@ -7,6 +7,7 @@ from enum import Enum
 from os import PathLike
 from pathlib import Path
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Iterator,
@@ -18,6 +19,9 @@ from typing import (
     Union,
 )
 
+if TYPE_CHECKING:
+    from dao_ai.state import Context
+
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.credentials_provider import (
     CredentialsStrategy,
@@ -284,8 +288,8 @@ class IsDatabricksResource(ABC, BaseModel):
 
         Authentication priority:
         1. On-Behalf-Of User (on_behalf_of_user=True):
-           - Forwarded headers (Databricks Apps)
-           - ModelServingUserCredentials (Model Serving)
+           - Uses ModelServingUserCredentials (Model Serving)
+           - For Databricks Apps with headers, use workspace_client_from(context)
         2. Service Principal (client_id + client_secret + workspace_host)
         3. PAT (pat + workspace_host)
         4. Ambient/default authentication
@@ -294,37 +298,6 @@ class IsDatabricksResource(ABC, BaseModel):
 
         # Check for OBO first (highest priority)
         if self.on_behalf_of_user:
-            # In Databricks Apps, use forwarded headers for per-user auth
-            from mlflow.genai.agent_server import get_request_headers
-
-            headers = get_request_headers()
-            logger.debug(f"Headers received: {list(headers.keys())}")
-            # Try both lowercase and title-case header names (HTTP headers are case-insensitive)
-            forwarded_token = headers.get("x-forwarded-access-token") or headers.get(
-                "X-Forwarded-Access-Token"
-            )
-
-            if forwarded_token:
-                forwarded_user = headers.get("x-forwarded-user") or headers.get(
-                    "X-Forwarded-User", "unknown"
-                )
-                logger.debug(
-                    f"Creating WorkspaceClient for {self.__class__.__name__} "
-                    f"with OBO using forwarded token from Databricks Apps",
-                    forwarded_user=forwarded_user,
-                )
-                # Use workspace_host if configured, otherwise SDK will auto-detect
-                workspace_host_value: str | None = (
-                    normalize_host(value_of(self.workspace_host))
-                    if self.workspace_host
-                    else None
-                )
-                return WorkspaceClient(
-                    host=workspace_host_value,
-                    token=forwarded_token,
-                    auth_type="pat",
-                )
-
             credentials_strategy: CredentialsStrategy = ModelServingUserCredentials()
             logger.debug(
                 f"Creating WorkspaceClient for {self.__class__.__name__} "
@@ -383,6 +356,57 @@ class IsDatabricksResource(ABC, BaseModel):
         )
         return WorkspaceClient()
 
+    def workspace_client_from(self, context: "Context | None") -> WorkspaceClient:
+        """
+        Get a WorkspaceClient using headers from the provided Context.
+
+        Use this method from tools that have access to ToolRuntime[Context].
+        This allows OBO authentication to work in Databricks Apps where headers
+        are captured at request entry and passed through the Context.
+
+        Args:
+            context: Runtime context containing headers for OBO auth.
+                If None or no headers, falls back to workspace_client property.
+
+        Returns:
+            WorkspaceClient configured with appropriate authentication.
+        """
+        from dao_ai.utils import normalize_host
+
+        logger.trace(f"workspace_client_from called", context=context, on_behalf_of_user=self.on_behalf_of_user)
+
+        # Check if we have headers in context for OBO
+        if context and context.headers and self.on_behalf_of_user:
+            headers = context.headers
+            # Try both lowercase and title-case header names (HTTP headers are case-insensitive)
+            forwarded_token: str = headers.get("x-forwarded-access-token") or headers.get(
+                "X-Forwarded-Access-Token"
+            )
+
+            if forwarded_token:
+                forwarded_user = headers.get("x-forwarded-user") or headers.get(
+                    "X-Forwarded-User", "unknown"
+                )
+                logger.debug(
+                    f"Creating WorkspaceClient for {self.__class__.__name__} "
+                    f"with OBO using forwarded token from Context",
+                    forwarded_user=forwarded_user,
+                )
+                # Use workspace_host if configured, otherwise SDK will auto-detect
+                workspace_host_value: str | None = (
+                    normalize_host(value_of(self.workspace_host))
+                    if self.workspace_host
+                    else None
+                )
+                return WorkspaceClient(
+                    host=workspace_host_value,
+                    token=forwarded_token,
+                    auth_type="pat",
+                )
+
+        # Fall back to existing workspace_client property
+        return self.workspace_client
+
 
 class DeploymentTarget(str, Enum):
     """Target platform for agent deployment."""
dao_ai/tools/genie.py CHANGED
@@ -139,29 +139,53 @@ Returns:
 GenieResponse: A response object containing the conversation ID and result from Genie."""
     tool_description = tool_description + function_docs
 
-    genie: Genie = Genie(
-        space_id=space_id,
-        client=genie_room.workspace_client,
-        truncate_results=truncate_results,
-    )
+    # Cache for genie service - created lazily on first call
+    # This allows us to use workspace_client_from with runtime context for OBO
+    _cached_genie_service: GenieServiceBase | None = None
+
+    def _get_genie_service(context: Context | None) -> GenieServiceBase:
+        """Get or create the Genie service, using context for OBO auth if available."""
+        nonlocal _cached_genie_service
+
+        # Use cached service if available (for non-OBO or after first call)
+        # For OBO, we need fresh workspace client each time to use the user's token
+        if _cached_genie_service is not None and not genie_room.on_behalf_of_user:
+            return _cached_genie_service
+
+        # Get workspace client using context for OBO support
+        from databricks.sdk import WorkspaceClient
 
-    genie_service: GenieServiceBase = GenieService(genie)
-
-    # Wrap with semantic cache first (checked second due to decorator pattern)
-    if semantic_cache_parameters is not None:
-        genie_service = SemanticCacheService(
-            impl=genie_service,
-            parameters=semantic_cache_parameters,
-            workspace_client=genie_room.workspace_client,  # Pass workspace client for conversation history
-        ).initialize()  # Eagerly initialize to fail fast and create table
-
-    # Wrap with LRU cache last (checked first - fast O(1) exact match)
-    if lru_cache_parameters is not None:
-        genie_service = LRUCacheService(
-            impl=genie_service,
-            parameters=lru_cache_parameters,
+        workspace_client: WorkspaceClient = genie_room.workspace_client_from(context)
+
+        genie: Genie = Genie(
+            space_id=space_id,
+            client=workspace_client,
+            truncate_results=truncate_results,
         )
 
+        genie_service: GenieServiceBase = GenieService(genie)
+
+        # Wrap with semantic cache first (checked second due to decorator pattern)
+        if semantic_cache_parameters is not None:
+            genie_service = SemanticCacheService(
+                impl=genie_service,
+                parameters=semantic_cache_parameters,
+                workspace_client=workspace_client,
+            ).initialize()
+
+        # Wrap with LRU cache last (checked first - fast O(1) exact match)
+        if lru_cache_parameters is not None:
+            genie_service = LRUCacheService(
+                impl=genie_service,
+                parameters=lru_cache_parameters,
+            )
+
+        # Cache for non-OBO scenarios
+        if not genie_room.on_behalf_of_user:
+            _cached_genie_service = genie_service
+
+        return genie_service
+
     @tool(
         name_or_callable=tool_name,
         description=tool_description,
@@ -177,6 +201,10 @@ GenieResponse: A response object containing the conversation ID and result from
         # Access state through runtime
         state: AgentState = runtime.state
         tool_call_id: str = runtime.tool_call_id
+        context: Context | None = runtime.context
+
+        # Get genie service with OBO support via context
+        genie_service: GenieServiceBase = _get_genie_service(context)
 
         # Ensure space_id is a string for state keys
         space_id_str: str = str(space_id)
@@ -194,6 +222,14 @@ GenieResponse: A response object containing the conversation ID and result from
             conversation_id=existing_conversation_id,
         )
 
+        # Log the prompt being sent to Genie
+        logger.trace(
+            "Sending prompt to Genie",
+            space_id=space_id_str,
+            conversation_id=existing_conversation_id,
+            prompt=question[:500] + "..." if len(question) > 500 else question,
+        )
+
         # Call ask_question which always returns CacheResult with cache metadata
         cache_result: CacheResult = genie_service.ask_question(
             question, conversation_id=existing_conversation_id
@@ -211,6 +247,22 @@ GenieResponse: A response object containing the conversation ID and result from
             cache_key=cache_key,
         )
 
+        # Log truncated response for debugging
+        result_preview: str = str(genie_response.result)
+        if len(result_preview) > 500:
+            result_preview = result_preview[:500] + "..."
+        logger.trace(
+            "Genie response content",
+            question=question[:100] + "..." if len(question) > 100 else question,
+            query=genie_response.query,
+            description=(
+                genie_response.description[:200] + "..."
+                if genie_response.description and len(genie_response.description) > 200
+                else genie_response.description
+            ),
+            result_preview=result_preview,
+        )
+
         # Update session state with cache information
         if persist_conversation:
             session.genie.update_space(
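
Note: the Genie tool now builds its service lazily inside a closure. The pattern in isolation, as a generic sketch (names here are illustrative, not the module's API): cache the service only when credentials are shared; with on_behalf_of_user the service is rebuilt per call so each request uses the caller's token.

from typing import Callable, TypeVar

T = TypeVar("T")


def lazy_unless_obo(build: Callable[[], T], obo: bool) -> Callable[[], T]:
    cached: T | None = None

    def get() -> T:
        nonlocal cached
        if cached is not None and not obo:
            return cached        # reuse the service built with ambient auth
        service = build()        # OBO: rebuild with the current user's client
        if not obo:
            cached = service     # cache only when credentials are shared
        return service

    return get
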
dao_ai/tools/mcp.py CHANGED
@@ -30,6 +30,7 @@ from dao_ai.config import (
     McpFunctionModel,
     TransportType,
 )
+from dao_ai.state import Context
 
 
 @dataclass
@@ -173,6 +174,7 @@ def _get_auth_resource(function: McpFunctionModel) -> IsDatabricksResource:
 
 def _build_connection_config(
     function: McpFunctionModel,
+    context: Context | None = None,
 ) -> dict[str, Any]:
     """
     Build the connection configuration dictionary for MultiServerMCPClient.
@@ -193,6 +195,7 @@ def _build_connection_config(
 
     Args:
         function: The MCP function model configuration.
+        context: Optional runtime context with headers for OBO auth.
 
     Returns:
         A dictionary containing the transport-specific connection settings.
@@ -205,14 +208,17 @@ def _build_connection_config(
     }
 
     # For HTTP transport, use DatabricksOAuthClientProvider with unified auth
+    from databricks.sdk import WorkspaceClient
     from databricks_mcp import DatabricksOAuthClientProvider
 
     # Get the resource to use for authentication
-    auth_resource = _get_auth_resource(function)
+    auth_resource: IsDatabricksResource = _get_auth_resource(function)
 
-    # Get workspace client from the auth resource
-    workspace_client = auth_resource.workspace_client
-    auth_provider = DatabricksOAuthClientProvider(workspace_client)
+    # Get workspace client from the auth resource with OBO support via context
+    workspace_client: WorkspaceClient = auth_resource.workspace_client_from(context)
+    auth_provider: DatabricksOAuthClientProvider = DatabricksOAuthClientProvider(
+        workspace_client
+    )
 
     # Log which resource is providing auth
     resource_name = (
@@ -509,19 +515,28 @@ async def acreate_mcp_tools(
     def _create_tool_wrapper(mcp_tool: Tool) -> RunnableLike:
         """
         Create a LangChain tool wrapper for an MCP tool.
+
+        Supports OBO authentication via context headers.
         """
+        from langchain.tools import ToolRuntime
 
         @create_tool(
             mcp_tool.name,
             description=mcp_tool.description or f"MCP tool: {mcp_tool.name}",
             args_schema=mcp_tool.inputSchema,
         )
-        async def tool_wrapper(**kwargs: Any) -> str:
+        async def tool_wrapper(
+            runtime: ToolRuntime[Context] = None,
+            **kwargs: Any,
+        ) -> str:
             """Execute MCP tool with fresh session."""
             logger.trace("Invoking MCP tool", tool_name=mcp_tool.name, args=kwargs)
 
-            invocation_client = MultiServerMCPClient(
-                {"mcp_function": _build_connection_config(function)}
+            # Get context for OBO support
+            context: Context | None = runtime.context if runtime else None
+
+            invocation_client: MultiServerMCPClient = MultiServerMCPClient(
+                {"mcp_function": _build_connection_config(function, context)}
             )
 
             try:
@@ -530,7 +545,7 @@ async def acreate_mcp_tools(
                 mcp_tool.name, kwargs
             )
 
-            text_result = _extract_text_content(result)
+            text_result: str = _extract_text_content(result)
 
             logger.trace(
                 "MCP tool completed",
@@ -625,20 +640,28 @@ def create_mcp_tools(
         This wrapper handles:
         - Fresh session creation per invocation (stateless)
         - Content extraction to plain text (avoiding extra fields)
+        - OBO authentication via context headers
         """
+        from langchain.tools import ToolRuntime
 
        @create_tool(
             mcp_tool.name,
             description=mcp_tool.description or f"MCP tool: {mcp_tool.name}",
             args_schema=mcp_tool.inputSchema,
         )
-        async def tool_wrapper(**kwargs: Any) -> str:
+        async def tool_wrapper(
+            runtime: ToolRuntime[Context] = None,
+            **kwargs: Any,
+        ) -> str:
             """Execute MCP tool with fresh session."""
             logger.trace("Invoking MCP tool", tool_name=mcp_tool.name, args=kwargs)
 
-            # Create a fresh client/session for each invocation
-            invocation_client = MultiServerMCPClient(
-                {"mcp_function": _build_connection_config(function)}
+            # Get context for OBO support
+            context: Context | None = runtime.context if runtime else None
+
+            # Create a fresh client/session for each invocation with OBO support
+            invocation_client: MultiServerMCPClient = MultiServerMCPClient(
+                {"mcp_function": _build_connection_config(function, context)}
             )
 
             try:
@@ -648,7 +671,7 @@ def create_mcp_tools(
                 )
 
             # Extract text content, avoiding extra fields
-            text_result = _extract_text_content(result)
+            text_result: str = _extract_text_content(result)
 
             logger.trace(
                 "MCP tool completed",
dao_ai/tools/slack.py CHANGED
@@ -1,11 +1,13 @@
 from typing import Any, Callable, Optional
 
 from databricks.sdk.service.serving import ExternalFunctionRequestHttpMethod
+from langchain.tools import ToolRuntime
 from langchain_core.tools import tool
 from loguru import logger
 from requests import Response
 
 from dao_ai.config import ConnectionModel
+from dao_ai.state import Context
 
 
 def _find_channel_id_by_name(
@@ -129,8 +131,17 @@ def create_send_slack_message_tool(
         name_or_callable=name,
         description=description,
     )
-    def send_slack_message(text: str) -> str:
-        response: Response = connection.workspace_client.serving_endpoints.http_request(
+    def send_slack_message(
+        text: str,
+        runtime: ToolRuntime[Context] = None,
+    ) -> str:
+        from databricks.sdk import WorkspaceClient
+
+        # Get workspace client with OBO support via context
+        context: Context | None = runtime.context if runtime else None
+        workspace_client: WorkspaceClient = connection.workspace_client_from(context)
+
+        response: Response = workspace_client.serving_endpoints.http_request(
             conn=connection.name,
             method=ExternalFunctionRequestHttpMethod.POST,
             path="/api/chat.postMessage",
dao_ai/tools/sql.py CHANGED
@@ -7,10 +7,11 @@ pre-configured SQL statements against a Databricks SQL warehouse.
 
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.service.sql import StatementResponse, StatementState
-from langchain.tools import tool
+from langchain.tools import ToolRuntime, tool
 from loguru import logger
 
 from dao_ai.config import WarehouseModel, value_of
+from dao_ai.state import Context
 
 
 def create_execute_statement_tool(
@@ -63,7 +64,6 @@ def create_execute_statement_tool(
     description = f"Execute a pre-configured SQL query against the {warehouse.name} warehouse and return the results."
 
     warehouse_id: str = value_of(warehouse.warehouse_id)
-    workspace_client: WorkspaceClient = warehouse.workspace_client
 
     logger.debug(
         "Creating SQL execution tool",
@@ -74,7 +74,7 @@
     )
 
     @tool(name_or_callable=name, description=description)
-    def execute_statement_tool() -> str:
+    def execute_statement_tool(runtime: ToolRuntime[Context] = None) -> str:
        """
        Execute the pre-configured SQL statement against the Databricks SQL warehouse.
 
@@ -88,6 +88,10 @@ def create_execute_statement_tool(
            sql_preview=statement[:100] + "..." if len(statement) > 100 else statement,
        )
 
+        # Get workspace client with OBO support via context
+        context: Context | None = runtime.context if runtime else None
+        workspace_client: WorkspaceClient = warehouse.workspace_client_from(context)
+
        try:
            # Execute the SQL statement
            statement_response: StatementResponse = (
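
Note: the meaningful move here is from definition-time to call-time credential resolution. The contrast in miniature, under the assumption that warehouse is any WarehouseModel; execute() is a made-up stand-in for the statement-execution body shown above:

def execute(workspace_client) -> str:
    # Stand-in for the statement-execution logic in the diff above.
    return str(workspace_client)


def make_tool_eager(warehouse):
    workspace_client = warehouse.workspace_client  # one identity for the tool's lifetime

    def run() -> str:
        return execute(workspace_client)

    return run


def make_tool_deferred(warehouse):
    def run(runtime=None) -> str:
        context = runtime.context if runtime else None
        workspace_client = warehouse.workspace_client_from(context)  # identity per call
        return execute(workspace_client)

    return run
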
dao_ai/tools/unity_catalog.py CHANGED
@@ -1,10 +1,11 @@
-from typing import Any, Dict, Optional, Sequence, Set
+from typing import Annotated, Any, Dict, Optional, Sequence, Set
 
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.service.catalog import FunctionInfo, PermissionsChange, Privilege
 from databricks_langchain import DatabricksFunctionClient, UCFunctionToolkit
+from langchain.tools import ToolRuntime
 from langchain_core.runnables.base import RunnableLike
-from langchain_core.tools import StructuredTool
+from langchain_core.tools import InjectedToolArg, StructuredTool
 from loguru import logger
 from pydantic import BaseModel
 from unitycatalog.ai.core.base import FunctionExecutionResult
@@ -15,6 +16,7 @@ from dao_ai.config import (
     UnityCatalogFunctionModel,
     value_of,
 )
+from dao_ai.state import Context
 from dao_ai.utils import normalize_host
 
 
@@ -35,13 +37,11 @@ def create_uc_tools(
         A sequence of BaseTool objects that wrap the specified UC functions
     """
     original_function_model: UnityCatalogFunctionModel | None = None
-    workspace_client: WorkspaceClient | None = None
     function_name: str
 
     if isinstance(function, UnityCatalogFunctionModel):
         original_function_model = function
         function_name = function.resource.full_name
-        workspace_client = function.resource.workspace_client
     else:
         function_name = function
 
@@ -56,6 +56,12 @@ def create_uc_tools(
         # Use with_partial_args directly with UnityCatalogFunctionModel
         tools = [with_partial_args(original_function_model)]
     else:
+        # For standard UC toolkit, we need workspace_client at creation time
+        # Use the resource's workspace_client (will use ambient auth if no OBO)
+        workspace_client: WorkspaceClient | None = None
+        if original_function_model:
+            workspace_client = original_function_model.resource.workspace_client
+
         # Fallback to standard UC toolkit approach
         client: DatabricksFunctionClient = DatabricksFunctionClient(
             client=workspace_client
@@ -356,7 +362,6 @@ def with_partial_args(
     # Get function info from the resource
     function_name: str = uc_function.resource.full_name
     tool_name: str = uc_function.resource.name or function_name.replace(".", "_")
-    workspace_client: WorkspaceClient = uc_function.resource.workspace_client
 
     logger.debug(
         "Creating UC tool with partial args",
@@ -365,7 +370,7 @@ def with_partial_args(
         partial_args=list(resolved_args.keys()),
     )
 
-    # Grant permissions if we have credentials
+    # Grant permissions if we have credentials (using ambient auth for setup)
     if "client_id" in resolved_args:
         client_id: str = resolved_args["client_id"]
         host: Optional[str] = resolved_args.get("host")
@@ -376,14 +381,18 @@ def with_partial_args(
                 "Failed to grant permissions", function_name=function_name, error=str(e)
             )
 
-    # Create the client for function execution using the resource's workspace client
-    client: DatabricksFunctionClient = DatabricksFunctionClient(client=workspace_client)
+    # Get workspace client for schema introspection (uses ambient auth at definition time)
+    # Actual execution will use OBO via context
+    setup_workspace_client: WorkspaceClient = uc_function.resource.workspace_client
+    setup_client: DatabricksFunctionClient = DatabricksFunctionClient(
+        client=setup_workspace_client
+    )
 
     # Try to get the function schema for better tool definition
     schema_model: type[BaseModel]
     tool_description: str
     try:
-        function_info: FunctionInfo = client.get_function(function_name)
+        function_info: FunctionInfo = setup_client.get_function(function_name)
         schema_info = generate_function_input_params_schema(function_info)
         tool_description = (
             function_info.comment or f"Unity Catalog function: {function_name}"
@@ -419,8 +428,21 @@ def with_partial_args(
         tool_description = f"Unity Catalog function: {function_name}"
 
     # Create a wrapper function that calls _execute_uc_function with partial args
-    def uc_function_wrapper(**kwargs) -> str:
+    # Uses InjectedToolArg to ensure runtime is injected but hidden from the LLM
+    def uc_function_wrapper(
+        runtime: Annotated[ToolRuntime[Context], InjectedToolArg] = None,
+        **kwargs: Any,
+    ) -> str:
         """Wrapper function that executes Unity Catalog function with partial args."""
+        # Get workspace client with OBO support via context
+        context: Context | None = runtime.context if runtime else None
+        workspace_client: WorkspaceClient = uc_function.resource.workspace_client_from(
+            context
+        )
+        client: DatabricksFunctionClient = DatabricksFunctionClient(
+            client=workspace_client
+        )
+
         return _execute_uc_function(
             client=client,
             function_name=function_name,
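
Note: uc_function_wrapper relies on Annotated[..., InjectedToolArg] so that runtime never appears in the schema shown to the model. A small sketch of the mechanism mirroring the wrapper above; lookup is a made-up function, and the exact filtering behavior can vary by langchain version:

from typing import Annotated

from langchain.tools import ToolRuntime
from langchain_core.tools import InjectedToolArg

from dao_ai.state import Context


def lookup(
    key: str,                                                          # visible to the LLM
    runtime: Annotated[ToolRuntime[Context], InjectedToolArg] = None,  # injected, hidden
) -> str:
    # Hypothetical example, not part of the package: the Annotated marker tells
    # langchain's schema inference to drop `runtime` from the model-facing
    # tool-call schema while still accepting it at invocation time.
    context = runtime.context if runtime else None
    return f"{key}: headers present={bool(context and context.headers)}"
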
dao_ai/tools/vector_search.py CHANGED
@@ -7,13 +7,14 @@ with dynamic filter schemas based on table columns and FlashRank reranking support
 
 import json
 import os
-from typing import Any, Optional
+from typing import Annotated, Any, Optional
 
 import mlflow
 from databricks.sdk import WorkspaceClient
 from databricks.vector_search.reranker import DatabricksReranker
 from databricks_langchain import DatabricksVectorSearch
 from flashrank import Ranker, RerankRequest
+from langchain.tools import ToolRuntime, tool
 from langchain_core.documents import Document
 from langchain_core.tools import StructuredTool
 from loguru import logger
@@ -27,6 +28,7 @@ from dao_ai.config import (
     VectorStoreModel,
     value_of,
 )
+from dao_ai.state import Context
 from dao_ai.utils import normalize_host
 
 # Create FilterItem model at module level so it can be used in type hints
@@ -299,35 +301,67 @@ def create_vector_search_tool(
         client_args_keys=list(client_args.keys()) if client_args else [],
     )
 
-    # Create DatabricksVectorSearch
-    # Note: text_column should be None for Databricks-managed embeddings
-    # (it's automatically determined from the index)
-    vector_search: DatabricksVectorSearch = DatabricksVectorSearch(
-        index_name=index_name,
-        text_column=None,
-        columns=columns,
-        workspace_client=vector_store.workspace_client,
-        client_args=client_args if client_args else None,
-        primary_key=vector_store.primary_key,
-        doc_uri=vector_store.doc_uri,
-        include_score=True,
-        reranker=(
-            DatabricksReranker(columns_to_rerank=rerank_config.columns)
-            if rerank_config and rerank_config.columns
-            else None
-        ),
-    )
+    # Cache for DatabricksVectorSearch - created lazily for OBO support
+    _cached_vector_search: DatabricksVectorSearch | None = None
+
+    def _get_vector_search(context: Context | None) -> DatabricksVectorSearch:
+        """Get or create DatabricksVectorSearch, using context for OBO auth if available."""
+        nonlocal _cached_vector_search
+
+        # Use cached instance if available and not OBO
+        if _cached_vector_search is not None and not vector_store.on_behalf_of_user:
+            return _cached_vector_search
+
+        # Get workspace client with OBO support via context
+        workspace_client: WorkspaceClient = vector_store.workspace_client_from(context)
+
+        # Create DatabricksVectorSearch
+        # Note: text_column should be None for Databricks-managed embeddings
+        # (it's automatically determined from the index)
+        vs: DatabricksVectorSearch = DatabricksVectorSearch(
+            index_name=index_name,
+            text_column=None,
+            columns=columns,
+            workspace_client=workspace_client,
+            client_args=client_args if client_args else None,
+            primary_key=vector_store.primary_key,
+            doc_uri=vector_store.doc_uri,
+            include_score=True,
+            reranker=(
+                DatabricksReranker(columns_to_rerank=rerank_config.columns)
+                if rerank_config and rerank_config.columns
+                else None
+            ),
+        )
 
-    # Create dynamic input schema
-    input_schema: type[BaseModel] = _create_dynamic_input_schema(
-        index_name, vector_store.workspace_client
-    )
+        # Cache for non-OBO scenarios
+        if not vector_store.on_behalf_of_user:
+            _cached_vector_search = vs
+
+        return vs
+
+    # Determine tool name and description
+    tool_name: str = name or f"vector_search_{vector_store.index.name}"
+    tool_description: str = description or f"Search documents in {index_name}"
 
-    # Define the tool function
+    # Use @tool decorator for proper ToolRuntime injection
+    # The decorator ensures runtime is automatically injected and hidden from the LLM
+    @tool(name_or_callable=tool_name, description=tool_description)
     def vector_search_func(
-        query: str, filters: Optional[list[FilterItem]] = None
+        query: Annotated[str, "The search query to find relevant documents"],
+        filters: Annotated[
+            Optional[list[FilterItem]],
+            "Optional filters to apply to the search results",
+        ] = None,
+        runtime: ToolRuntime[Context] = None,
    ) -> str:
        """Search for relevant documents using vector similarity."""
+        # Get context for OBO support
+        context: Context | None = runtime.context if runtime else None
+
+        # Get vector search instance with OBO support
+        vector_search: DatabricksVectorSearch = _get_vector_search(context)
+
        # Convert FilterItem Pydantic models to dict format for DatabricksVectorSearch
        filters_dict: dict[str, Any] = {}
        if filters:
@@ -379,14 +413,6 @@ def create_vector_search_tool(
         # Return as JSON string
         return json.dumps(serialized_docs)
 
-    # Create the StructuredTool
-    tool: StructuredTool = StructuredTool.from_function(
-        func=vector_search_func,
-        name=name or f"vector_search_{vector_store.index.name}",
-        description=description or f"Search documents in {index_name}",
-        args_schema=input_schema,
-    )
-
-    logger.success("Vector search tool created", name=tool.name, index=index_name)
+    logger.success("Vector search tool created", name=tool_name, index=index_name)
 
-    return tool
+    return vector_search_func
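
Note: the vector search tool drops the hand-built args_schema in favor of letting the @tool decorator derive the schema from type hints, with Annotated metadata supplying per-parameter descriptions. The idea in miniature (demo_search is illustrative, not the module's API):

from typing import Annotated, Optional

from langchain.tools import tool


@tool(name_or_callable="demo_search", description="Find documents")
def demo_search(
    query: Annotated[str, "The search query to find relevant documents"],
    top_k: Annotated[Optional[int], "How many results to return"] = None,
) -> str:
    # The decorator builds the input schema from these annotations,
    # so no separate args_schema model is needed.
    return f"searching for {query!r} (top_k={top_k})"
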
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dao-ai
-Version: 0.1.13
+Version: 0.1.15
 Summary: DAO AI: A modular, multi-agent orchestration framework for complex AI workflows. Supports agent handoff, tool integration, and dynamic configuration via YAML.
 Project-URL: Homepage, https://github.com/natefleming/dao-ai
 Project-URL: Documentation, https://natefleming.github.io/dao-ai
@@ -1,7 +1,7 @@
 dao_ai/__init__.py,sha256=18P98ExEgUaJ1Byw440Ct1ty59v6nxyWtc5S6Uq2m9Q,1062
 dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
-dao_ai/cli.py,sha256=1Ox8qjLKRlrKu2YXozm0lWoeZnDCouECeZSGVPkQgIQ,50923
-dao_ai/config.py,sha256=E2lwWro3A6c3cKLYyHZeqNz2X5vkXgLS8TfDlGL5o9M,129307
+dao_ai/cli.py,sha256=6qwlS07_Tei6iEPXIJ-19cQVnLXd7vJDpuY4Qu0k96E,51634
+dao_ai/config.py,sha256=Bpaj1iDuDarAnRnTMvIYjtfbewjOSfBppZ6Sp3Id0CM,130242
 dao_ai/graph.py,sha256=1-uQlo7iXZQTT3uU8aYu0N5rnhw5_g_2YLwVsAs6M-U,1119
 dao_ai/logging.py,sha256=lYy4BmucCHvwW7aI3YQkQXKJtMvtTnPDu9Hnd7_O4oc,1556
 dao_ai/messages.py,sha256=4ZBzO4iFdktGSLrmhHzFjzMIt2tpaL-aQLHOQJysGnY,6959
@@ -14,7 +14,7 @@ dao_ai/types.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/utils.py,sha256=_Urd7Nj2VzrgPKf3NS4E6vt0lWRhEUddBqWN9BksqeE,11543
 dao_ai/vector_search.py,sha256=8d3xROg9zSIYNXjRRl6rSexsJTlufjRl5Fy1ZA8daKA,4019
 dao_ai/apps/__init__.py,sha256=RLuhZf4gQ4pemwKDz1183aXib8UfaRhwfKvRx68GRlM,661
-dao_ai/apps/handlers.py,sha256=nbJZOgmnHG5icR4Pb56jxIWsm_AGnsURgViMJX2_LTU,2608
+dao_ai/apps/handlers.py,sha256=6-IhhklHSPnS8aqKp155wPaSnYWTU1BSOPwbdWYBkFU,3594
 dao_ai/apps/model_serving.py,sha256=XLt3_0pGSRceMK6YtOrND9Jnh7mKLPCtwjVDLIaptQU,847
 dao_ai/apps/resources.py,sha256=5l6UxfMq6uspOql-HNDyUikfqRAa9eH_TiJHrGgMb6s,40029
 dao_ai/apps/server.py,sha256=neWbVnC2z9f-tJZBnho70FytNDEVOdOM1YngoGc5KHI,1264
@@ -58,18 +58,18 @@ dao_ai/tools/__init__.py,sha256=NfRpAKds_taHbx6gzLPWgtPXve-YpwzkoOAUflwxceM,1734
 dao_ai/tools/agent.py,sha256=plIWALywRjaDSnot13nYehBsrHRpBUpsVZakoGeajOE,1858
 dao_ai/tools/core.py,sha256=bRIN3BZhRQX8-Kpu3HPomliodyskCqjxynQmYbk6Vjs,3783
 dao_ai/tools/email.py,sha256=A3TsCoQgJR7UUWR0g45OPRGDpVoYwctFs1MOZMTt_d4,7389
-dao_ai/tools/genie.py,sha256=4e_5MeAe7kDzHbYeXuNPFbY5z8ci3ouj8l5254CZ2lA,8874
-dao_ai/tools/mcp.py,sha256=tfn-sdKwfNY31RsDFlafdGyN4XlKGfniXG_mO-Meh4E,21030
+dao_ai/tools/genie.py,sha256=b0R51N5D58H1vpOCUCA88ALjLs58KSMn6nl80ap8_c0,11009
+dao_ai/tools/mcp.py,sha256=K1yMQ39UgJ0Q4xhMpNWV3AVNx929w9vxZlLoCq_jrws,22016
 dao_ai/tools/memory.py,sha256=lwObKimAand22Nq3Y63tsv-AXQ5SXUigN9PqRjoWKes,1836
 dao_ai/tools/python.py,sha256=jWFnZPni2sCdtd8D1CqXnZIPHnWkdK27bCJnBXpzhvo,1879
 dao_ai/tools/search.py,sha256=cJ3D9FKr1GAR6xz55dLtRkjtQsI0WRueGt9TPDFpOxc,433
-dao_ai/tools/slack.py,sha256=QpLMXDApjPKyRpEanLp0tOhCp9WXaEBa615p4t0pucs,5040
-dao_ai/tools/sql.py,sha256=tKd1gjpLuKdQDyfmyYYtMiNRHDW6MGRbdEVaeqyB8Ok,7632
+dao_ai/tools/slack.py,sha256=QnMsA7cYD1MnEcqGqqSr6bKIhV0RgDpkyaiPmDqnAts,5433
+dao_ai/tools/sql.py,sha256=FG-Aa0FAUAnhCuZvao1J-y-cMM6bU5eCujNbsYn0xDw,7864
 dao_ai/tools/time.py,sha256=tufJniwivq29y0LIffbgeBTIDE6VgrLpmVf8Qr90qjw,9224
-dao_ai/tools/unity_catalog.py,sha256=AjQfW7bvV8NurqDLIyntYRv2eJuTwNdbvex1L5CRjOk,15534
-dao_ai/tools/vector_search.py,sha256=oe2uBwl2TfeJIXPpwiS6Rmz7wcHczSxNyqS9P3hE6co,14542
-dao_ai-0.1.13.dist-info/METADATA,sha256=xQ1apcAp24Co2FBzFL6Hw5mCqQzsskmMw-br41NSJqk,16830
-dao_ai-0.1.13.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-dao_ai-0.1.13.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
-dao_ai-0.1.13.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
-dao_ai-0.1.13.dist-info/RECORD,,
+dao_ai/tools/unity_catalog.py,sha256=oBlW6pH-Ne08g60QW9wVi_tyeVYDiecuNoxQbIIFmN8,16515
+dao_ai/tools/vector_search.py,sha256=LF_72vlEF6TwUjKVv6nkUetLK766l9Kl6DQQTc9ebJI,15888
+dao_ai-0.1.15.dist-info/METADATA,sha256=-cei_FUcN2BBlatDymERXwoag5tuuoJvdOcBBnsF8qU,16830
+dao_ai-0.1.15.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+dao_ai-0.1.15.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
+dao_ai-0.1.15.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
+dao_ai-0.1.15.dist-info/RECORD,,