dao-ai 0.1.2__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. dao_ai/apps/__init__.py +24 -0
  2. dao_ai/apps/handlers.py +105 -0
  3. dao_ai/apps/model_serving.py +29 -0
  4. dao_ai/apps/resources.py +1122 -0
  5. dao_ai/apps/server.py +39 -0
  6. dao_ai/cli.py +546 -37
  7. dao_ai/config.py +1179 -139
  8. dao_ai/evaluation.py +543 -0
  9. dao_ai/genie/__init__.py +55 -7
  10. dao_ai/genie/cache/__init__.py +34 -7
  11. dao_ai/genie/cache/base.py +143 -2
  12. dao_ai/genie/cache/context_aware/__init__.py +31 -0
  13. dao_ai/genie/cache/context_aware/base.py +1151 -0
  14. dao_ai/genie/cache/context_aware/in_memory.py +609 -0
  15. dao_ai/genie/cache/context_aware/persistent.py +802 -0
  16. dao_ai/genie/cache/context_aware/postgres.py +1166 -0
  17. dao_ai/genie/cache/core.py +1 -1
  18. dao_ai/genie/cache/lru.py +257 -75
  19. dao_ai/genie/cache/optimization.py +890 -0
  20. dao_ai/genie/core.py +235 -11
  21. dao_ai/memory/postgres.py +175 -39
  22. dao_ai/middleware/__init__.py +38 -0
  23. dao_ai/middleware/assertions.py +3 -3
  24. dao_ai/middleware/context_editing.py +230 -0
  25. dao_ai/middleware/core.py +4 -4
  26. dao_ai/middleware/guardrails.py +3 -3
  27. dao_ai/middleware/human_in_the_loop.py +3 -2
  28. dao_ai/middleware/message_validation.py +4 -4
  29. dao_ai/middleware/model_call_limit.py +77 -0
  30. dao_ai/middleware/model_retry.py +121 -0
  31. dao_ai/middleware/pii.py +157 -0
  32. dao_ai/middleware/summarization.py +1 -1
  33. dao_ai/middleware/tool_call_limit.py +210 -0
  34. dao_ai/middleware/tool_retry.py +174 -0
  35. dao_ai/middleware/tool_selector.py +129 -0
  36. dao_ai/models.py +327 -370
  37. dao_ai/nodes.py +9 -16
  38. dao_ai/orchestration/core.py +33 -9
  39. dao_ai/orchestration/supervisor.py +29 -13
  40. dao_ai/orchestration/swarm.py +6 -1
  41. dao_ai/{prompts.py → prompts/__init__.py} +12 -61
  42. dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
  43. dao_ai/prompts/instruction_reranker.yaml +14 -0
  44. dao_ai/prompts/router.yaml +37 -0
  45. dao_ai/prompts/verifier.yaml +46 -0
  46. dao_ai/providers/base.py +28 -2
  47. dao_ai/providers/databricks.py +363 -33
  48. dao_ai/state.py +1 -0
  49. dao_ai/tools/__init__.py +5 -3
  50. dao_ai/tools/genie.py +103 -26
  51. dao_ai/tools/instructed_retriever.py +366 -0
  52. dao_ai/tools/instruction_reranker.py +202 -0
  53. dao_ai/tools/mcp.py +539 -97
  54. dao_ai/tools/router.py +89 -0
  55. dao_ai/tools/slack.py +13 -2
  56. dao_ai/tools/sql.py +7 -3
  57. dao_ai/tools/unity_catalog.py +32 -10
  58. dao_ai/tools/vector_search.py +493 -160
  59. dao_ai/tools/verifier.py +159 -0
  60. dao_ai/utils.py +182 -2
  61. dao_ai/vector_search.py +46 -1
  62. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
  63. dao_ai-0.1.20.dist-info/RECORD +89 -0
  64. dao_ai/agent_as_code.py +0 -22
  65. dao_ai/genie/cache/semantic.py +0 -970
  66. dao_ai-0.1.2.dist-info/RECORD +0 -64
  67. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
  68. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
  69. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/mcp.py CHANGED
@@ -7,10 +7,16 @@ MCP SDK and langchain-mcp-adapters library.
7
7
  For compatibility with Databricks APIs, we use manual tool wrappers
8
8
  that give us full control over the response format.
9
9
 
10
+ Public API:
11
+ - list_mcp_tools(): List available tools from an MCP server (for discovery/UI)
12
+ - create_mcp_tools(): Create LangChain tools for agent execution
13
+
10
14
  Reference: https://docs.langchain.com/oss/python/langchain/mcp
11
15
  """
12
16
 
13
17
  import asyncio
18
+ import fnmatch
19
+ from dataclasses import dataclass
14
20
  from typing import Any, Sequence
15
21
 
16
22
  from langchain_core.runnables.base import RunnableLike
@@ -20,20 +26,187 @@ from loguru import logger
20
26
  from mcp.types import CallToolResult, TextContent, Tool
21
27
 
22
28
  from dao_ai.config import (
29
+ IsDatabricksResource,
23
30
  McpFunctionModel,
24
31
  TransportType,
25
- value_of,
26
32
  )
33
+ from dao_ai.state import Context
34
+
35
+
36
+ @dataclass
37
+ class MCPToolInfo:
38
+ """
39
+ Information about an MCP tool for display and selection.
40
+
41
+ This is a simplified representation of an MCP tool that contains
42
+ only the information needed for UI display and tool selection.
43
+ It's designed to be easily serializable for use in web UIs.
44
+
45
+ Attributes:
46
+ name: The unique identifier/name of the tool
47
+ description: Human-readable description of what the tool does
48
+ input_schema: JSON Schema describing the tool's input parameters
49
+ """
50
+
51
+ name: str
52
+ description: str | None
53
+ input_schema: dict[str, Any]
54
+
55
+ def to_dict(self) -> dict[str, Any]:
56
+ """Convert to dictionary for JSON serialization."""
57
+ return {
58
+ "name": self.name,
59
+ "description": self.description,
60
+ "input_schema": self.input_schema,
61
+ }
62
+
63
+
64
+ def _matches_pattern(tool_name: str, patterns: list[str]) -> bool:
65
+ """
66
+ Check if tool name matches any of the provided patterns.
67
+
68
+ Supports glob patterns:
69
+ - * matches any characters
70
+ - ? matches single character
71
+ - [abc] matches any char in set
72
+ - [!abc] matches any char NOT in set
73
+
74
+ Args:
75
+ tool_name: Name of the tool to check
76
+ patterns: List of exact names or glob patterns
77
+
78
+ Returns:
79
+ True if tool name matches any pattern
80
+
81
+ Examples:
82
+ >>> _matches_pattern("query_sales", ["query_*"])
83
+ True
84
+ >>> _matches_pattern("list_tables", ["query_*"])
85
+ False
86
+ >>> _matches_pattern("tool_a", ["tool_?"])
87
+ True
88
+ """
89
+ for pattern in patterns:
90
+ if fnmatch.fnmatch(tool_name, pattern):
91
+ return True
92
+ return False
93
+
94
+
95
+ def _should_include_tool(
96
+ tool_name: str,
97
+ include_tools: list[str] | None,
98
+ exclude_tools: list[str] | None,
99
+ ) -> bool:
100
+ """
101
+ Determine if a tool should be included based on include/exclude filters.
102
+
103
+ Logic:
104
+ 1. If exclude_tools specified and tool matches: EXCLUDE (highest priority)
105
+ 2. If include_tools specified and tool matches: INCLUDE
106
+ 3. If include_tools specified and tool doesn't match: EXCLUDE
107
+ 4. If no filters specified: INCLUDE (default)
108
+
109
+ Args:
110
+ tool_name: Name of the tool
111
+ include_tools: Optional list of tools/patterns to include
112
+ exclude_tools: Optional list of tools/patterns to exclude
113
+
114
+ Returns:
115
+ True if tool should be included
116
+
117
+ Examples:
118
+ >>> _should_include_tool("query_sales", ["query_*"], None)
119
+ True
120
+ >>> _should_include_tool("drop_table", None, ["drop_*"])
121
+ False
122
+ >>> _should_include_tool("query_sales", ["query_*"], ["*_sales"])
123
+ False # exclude takes precedence
124
+ """
125
+ # Exclude has highest priority
126
+ if exclude_tools and _matches_pattern(tool_name, exclude_tools):
127
+ logger.debug("Tool excluded by exclude_tools", tool_name=tool_name)
128
+ return False
129
+
130
+ # If include list exists, tool must match it
131
+ if include_tools:
132
+ if _matches_pattern(tool_name, include_tools):
133
+ logger.debug("Tool included by include_tools", tool_name=tool_name)
134
+ return True
135
+ else:
136
+ logger.debug(
137
+ "Tool not in include_tools",
138
+ tool_name=tool_name,
139
+ include_patterns=include_tools,
140
+ )
141
+ return False
142
+
143
+ # Default: include all tools
144
+ return True
145
+
146
+
147
+ def _has_auth_configured(resource: IsDatabricksResource) -> bool:
148
+ """Check if a resource has explicit authentication configured."""
149
+ return bool(
150
+ resource.on_behalf_of_user
151
+ or resource.service_principal
152
+ or resource.client_id
153
+ or resource.pat
154
+ )
155
+
156
+
157
+ def _get_auth_resource(function: McpFunctionModel) -> IsDatabricksResource:
158
+ """
159
+ Get the IsDatabricksResource to use for authentication.
160
+
161
+ Follows a priority hierarchy:
162
+ 1. Nested resource with explicit auth (app, connection, genie_room, vector_search)
163
+ 2. McpFunctionModel itself (which also inherits from IsDatabricksResource)
164
+
165
+ Only uses a nested resource if it has authentication configured.
166
+ Otherwise falls back to McpFunctionModel which may have credentials set at the tool level.
167
+
168
+ Returns the resource whose workspace_client should be used for authentication.
169
+ """
170
+ # Check each possible resource source - only use if it has auth configured
171
+ if function.app and _has_auth_configured(function.app):
172
+ return function.app
173
+ if function.connection and _has_auth_configured(function.connection):
174
+ return function.connection
175
+ if function.genie_room and _has_auth_configured(function.genie_room):
176
+ return function.genie_room
177
+ if function.vector_search and _has_auth_configured(function.vector_search):
178
+ return function.vector_search
179
+ # SchemaModel (functions) doesn't have auth - always fall through
180
+
181
+ # Fall back to McpFunctionModel itself (it inherits from IsDatabricksResource)
182
+ # This allows credentials to be set at the tool level
183
+ return function
27
184
 
28
185
 
29
186
  def _build_connection_config(
30
187
  function: McpFunctionModel,
188
+ context: Context | None = None,
31
189
  ) -> dict[str, Any]:
32
190
  """
33
191
  Build the connection configuration dictionary for MultiServerMCPClient.
34
192
 
193
+ Authentication Strategy:
194
+ -----------------------
195
+ For HTTP transport, authentication is handled consistently using
196
+ DatabricksOAuthClientProvider with the workspace_client from the appropriate
197
+ IsDatabricksResource. The auth resource is selected in this priority:
198
+
199
+ 1. Nested resource (app, connection, genie_room, vector_search) if it has auth
200
+ 2. McpFunctionModel itself (inherits from IsDatabricksResource)
201
+
202
+ This approach ensures:
203
+ - Consistent auth handling across all MCP sources
204
+ - Automatic token refresh for long-running connections
205
+ - Support for OBO, service principal, PAT, and ambient auth
206
+
35
207
  Args:
36
208
  function: The MCP function model configuration.
209
+ context: Optional runtime context with headers for OBO auth.
37
210
 
38
211
  Returns:
39
212
  A dictionary containing the transport-specific connection settings.
@@ -45,52 +218,33 @@ def _build_connection_config(
45
218
  "transport": function.transport.value,
46
219
  }
47
220
 
48
- # For HTTP transport with UC Connection, use DatabricksOAuthClientProvider
49
- if function.connection:
50
- from databricks_mcp import DatabricksOAuthClientProvider
51
-
52
- workspace_client = function.connection.workspace_client
53
- auth_provider = DatabricksOAuthClientProvider(workspace_client)
54
-
55
- logger.trace(
56
- "Using DatabricksOAuthClientProvider for authentication",
57
- connection_name=function.connection.name,
58
- )
59
-
60
- return {
61
- "url": function.mcp_url,
62
- "transport": "http",
63
- "auth": auth_provider,
64
- }
65
-
66
- # For HTTP transport with headers-based authentication
67
- headers: dict[str, str] = {
68
- key: str(value_of(val)) for key, val in function.headers.items()
69
- }
70
-
71
- if "Authorization" not in headers:
72
- logger.trace("Generating fresh authentication token")
73
-
74
- from dao_ai.providers.databricks import DatabricksProvider
75
-
76
- try:
77
- provider = DatabricksProvider(
78
- workspace_host=value_of(function.workspace_host),
79
- client_id=value_of(function.client_id),
80
- client_secret=value_of(function.client_secret),
81
- pat=value_of(function.pat),
82
- )
83
- headers["Authorization"] = f"Bearer {provider.create_token()}"
84
- logger.trace("Generated fresh authentication token")
85
- except Exception as e:
86
- logger.error("Failed to create fresh token", error=str(e))
87
- else:
88
- logger.trace("Using existing authentication token")
221
+ # For HTTP transport, use DatabricksOAuthClientProvider with unified auth
222
+ from databricks.sdk import WorkspaceClient
223
+ from databricks_mcp import DatabricksOAuthClientProvider
224
+
225
+ # Get the resource to use for authentication
226
+ auth_resource: IsDatabricksResource = _get_auth_resource(function)
227
+
228
+ # Get workspace client from the auth resource with OBO support via context
229
+ workspace_client: WorkspaceClient = auth_resource.workspace_client_from(context)
230
+ auth_provider: DatabricksOAuthClientProvider = DatabricksOAuthClientProvider(
231
+ workspace_client
232
+ )
233
+
234
+ # Log which resource is providing auth
235
+ resource_name = (
236
+ getattr(auth_resource, "name", None) or auth_resource.__class__.__name__
237
+ )
238
+ logger.trace(
239
+ "Using DatabricksOAuthClientProvider for authentication",
240
+ auth_resource=resource_name,
241
+ resource_type=auth_resource.__class__.__name__,
242
+ )
89
243
 
90
244
  return {
91
245
  "url": function.mcp_url,
92
246
  "transport": "http",
93
- "headers": headers,
247
+ "auth": auth_provider,
94
248
  }
95
249
 
96
250
 
@@ -124,69 +278,29 @@ def _extract_text_content(result: CallToolResult) -> str:
124
278
  return "\n".join(text_parts)
125
279
 
126
280
 
127
- def create_mcp_tools(
128
- function: McpFunctionModel,
129
- ) -> Sequence[RunnableLike]:
281
+ async def _afetch_tools_from_server(function: McpFunctionModel) -> list[Tool]:
130
282
  """
131
- Create tools for invoking Databricks MCP functions.
132
-
133
- Supports both direct MCP connections and UC Connection-based MCP access.
134
- Uses manual tool wrappers to ensure response format compatibility with
135
- Databricks APIs (which reject extra fields in tool results).
283
+ Async version: Fetch raw MCP tools from the server.
136
284
 
137
- Based on: https://docs.databricks.com/aws/en/generative-ai/mcp/external-mcp
285
+ This is the primary async implementation that handles the actual MCP connection
286
+ and tool listing. It's used by both alist_mcp_tools() and acreate_mcp_tools().
138
287
 
139
288
  Args:
140
289
  function: The MCP function model configuration.
141
290
 
142
291
  Returns:
143
- A sequence of LangChain tools that can be used by agents.
144
- """
145
- mcp_url = function.mcp_url
146
- logger.debug("Creating MCP tools", mcp_url=mcp_url)
292
+ List of raw MCP Tool objects from the server.
147
293
 
294
+ Raises:
295
+ RuntimeError: If connection to MCP server fails.
296
+ """
148
297
  connection_config = _build_connection_config(function)
149
-
150
- if function.connection:
151
- logger.debug(
152
- "Using UC Connection for MCP",
153
- connection_name=function.connection.name,
154
- mcp_url=mcp_url,
155
- )
156
- else:
157
- logger.debug(
158
- "Using direct connection for MCP",
159
- transport=function.transport,
160
- mcp_url=mcp_url,
161
- )
162
-
163
- # Create client to list available tools
164
298
  client = MultiServerMCPClient({"mcp_function": connection_config})
165
299
 
166
- async def _list_tools() -> list[Tool]:
167
- """List available MCP tools from the server."""
300
+ try:
168
301
  async with client.session("mcp_function") as session:
169
302
  result = await session.list_tools()
170
303
  return result.tools if hasattr(result, "tools") else list(result)
171
-
172
- try:
173
- mcp_tools: list[Tool] = asyncio.run(_list_tools())
174
-
175
- # Log discovered tools
176
- logger.info(
177
- "Discovered MCP tools",
178
- tools_count=len(mcp_tools),
179
- mcp_url=mcp_url,
180
- )
181
- for mcp_tool in mcp_tools:
182
- logger.debug(
183
- "MCP tool discovered",
184
- tool_name=mcp_tool.name,
185
- tool_description=(
186
- mcp_tool.description[:100] if mcp_tool.description else None
187
- ),
188
- )
189
-
190
304
  except Exception as e:
191
305
  if function.connection:
192
306
  logger.error(
@@ -210,6 +324,326 @@ def create_mcp_tools(
210
324
  f"and URL '{function.url}': {e}"
211
325
  ) from e
212
326
 
327
+
328
def _fetch_tools_from_server(function: McpFunctionModel) -> list[Tool]:
    """
    Blocking counterpart of ``_afetch_tools_from_server``.

    Drives the async fetch to completion with ``asyncio.run``. Code that is
    already running inside an event loop should await
    ``_afetch_tools_from_server()`` instead.

    Args:
        function: The MCP function model configuration.

    Returns:
        List of raw MCP Tool objects from the server.

    Raises:
        RuntimeError: If connection to MCP server fails.
    """
    fetch_coroutine = _afetch_tools_from_server(function)
    return asyncio.run(fetch_coroutine)
344
+
345
+
346
async def alist_mcp_tools(
    function: McpFunctionModel,
    apply_filters: bool = True,
) -> list[MCPToolInfo]:
    """
    Discover the tools exposed by an MCP server (async entry point).

    Fetches the raw tool list, optionally narrows it with the function's
    ``include_tools`` / ``exclude_tools`` patterns, and converts each remaining
    tool to an ``MCPToolInfo``. Synchronous callers should use
    ``list_mcp_tools()``.

    Args:
        function: The MCP function model configuration.
        apply_filters: Whether to apply include_tools/exclude_tools filters.

    Returns:
        List of MCPToolInfo objects describing available tools.

    Raises:
        RuntimeError: If connection to MCP server fails.
    """
    server_url = function.mcp_url
    logger.debug(
        "Listing MCP tools (async)", mcp_url=server_url, apply_filters=apply_filters
    )

    # Record how the server is being reached.
    if not function.connection:
        logger.debug(
            "Using direct connection for MCP",
            transport=function.transport,
            mcp_url=server_url,
        )
    else:
        logger.debug(
            "Using UC Connection for MCP",
            connection_name=function.connection.name,
            mcp_url=server_url,
        )

    discovered: list[Tool] = await _afetch_tools_from_server(function)

    logger.info(
        "Discovered MCP tools from server",
        tools_count=len(discovered),
        tool_names=[t.name for t in discovered],
        mcp_url=server_url,
    )

    # Filtering only happens when requested AND at least one filter is set.
    if apply_filters and (function.include_tools or function.exclude_tools):
        count_before = len(discovered)
        kept: list[Tool] = []
        for candidate in discovered:
            if _should_include_tool(
                candidate.name,
                function.include_tools,
                function.exclude_tools,
            ):
                kept.append(candidate)
        discovered = kept

        logger.info(
            "Filtered MCP tools",
            original_count=count_before,
            filtered_count=count_before - len(discovered),
            final_count=len(discovered),
            include_patterns=function.include_tools,
            exclude_patterns=function.exclude_tools,
        )

    # Build the simplified view and log each surviving tool as it is added.
    listing: list[MCPToolInfo] = []
    for server_tool in discovered:
        listing.append(
            MCPToolInfo(
                name=server_tool.name,
                description=server_tool.description,
                # A missing/empty schema is represented as an empty dict.
                input_schema=server_tool.inputSchema or {},
            )
        )

        summary = server_tool.description
        logger.debug(
            "MCP tool available",
            tool_name=server_tool.name,
            # Only the first 100 characters of the description are logged.
            tool_description=(summary[:100] if summary else None),
        )

    return listing
438
+
439
+
440
def list_mcp_tools(
    function: McpFunctionModel,
    apply_filters: bool = True,
) -> list[MCPToolInfo]:
    """
    Blocking counterpart of ``alist_mcp_tools``.

    Runs the async listing to completion with ``asyncio.run``. Code already
    executing inside an event loop should await ``alist_mcp_tools()`` instead.

    Args:
        function: The MCP function model configuration.
        apply_filters: Whether to apply include_tools/exclude_tools filters.

    Returns:
        List of MCPToolInfo objects describing available tools.

    Raises:
        RuntimeError: If connection to MCP server fails.
    """
    listing_coroutine = alist_mcp_tools(function, apply_filters)
    return asyncio.run(listing_coroutine)
460
+
461
+
462
async def acreate_mcp_tools(
    function: McpFunctionModel,
) -> Sequence[RunnableLike]:
    """
    Build executable LangChain tools for a Databricks MCP function (async).

    Fetches the server's tool list, applies the function's
    ``include_tools`` / ``exclude_tools`` patterns when either is set, and
    wraps every remaining tool in a LangChain tool. Each wrapper opens a new
    MCP client session on every invocation. Synchronous callers should use
    ``create_mcp_tools()``.

    Args:
        function: The MCP function model configuration.

    Returns:
        A sequence of LangChain tools that can be used by agents.

    Raises:
        RuntimeError: If connection to MCP server fails.
    """
    server_url = function.mcp_url
    logger.debug("Creating MCP tools (async)", mcp_url=server_url)

    discovered: list[Tool] = await _afetch_tools_from_server(function)

    logger.info(
        "Discovered MCP tools from server",
        tools_count=len(discovered),
        tool_names=[t.name for t in discovered],
        mcp_url=server_url,
    )

    # Narrow the list only when at least one filter is configured.
    if function.include_tools or function.exclude_tools:
        count_before = len(discovered)
        kept: list[Tool] = []
        for candidate in discovered:
            if _should_include_tool(
                candidate.name,
                function.include_tools,
                function.exclude_tools,
            ):
                kept.append(candidate)
        discovered = kept

        logger.info(
            "Filtered MCP tools",
            original_count=count_before,
            filtered_count=count_before - len(discovered),
            final_count=len(discovered),
            include_patterns=function.include_tools,
            exclude_patterns=function.exclude_tools,
        )

    # Log the tools that will actually be exposed.
    for server_tool in discovered:
        summary = server_tool.description
        logger.debug(
            "MCP tool available",
            tool_name=server_tool.name,
            # Only the first 100 characters of the description are logged.
            tool_description=(summary[:100] if summary else None),
        )

    def _create_tool_wrapper(mcp_tool: Tool) -> RunnableLike:
        """
        Wrap one MCP tool as a LangChain tool.

        The runtime context, when present, is forwarded to
        ``_build_connection_config`` for OBO authentication.
        """
        from langchain.tools import ToolRuntime

        fallback_description = f"MCP tool: {mcp_tool.name}"

        @create_tool(
            mcp_tool.name,
            description=mcp_tool.description or fallback_description,
            args_schema=mcp_tool.inputSchema,
        )
        async def tool_wrapper(
            runtime: ToolRuntime[Context] = None,
            **kwargs: Any,
        ) -> str:
            """Execute MCP tool with fresh session."""
            logger.trace("Invoking MCP tool", tool_name=mcp_tool.name, args=kwargs)

            # No runtime means no context, hence no OBO headers.
            context: Context | None = None if not runtime else runtime.context

            # A new client is built for every call so the configuration
            # reflects the current invocation's context.
            connection_config = _build_connection_config(function, context)
            invocation_client: MultiServerMCPClient = MultiServerMCPClient(
                {"mcp_function": connection_config}
            )

            try:
                async with invocation_client.session("mcp_function") as session:
                    result: CallToolResult = await session.call_tool(
                        mcp_tool.name, kwargs
                    )

                text_result: str = _extract_text_content(result)

                logger.trace(
                    "MCP tool completed",
                    tool_name=mcp_tool.name,
                    result_length=len(text_result),
                )

                return text_result

            except Exception as e:
                # Log for diagnostics, then let the caller see the failure.
                logger.error(
                    "MCP tool failed",
                    tool_name=mcp_tool.name,
                    error=str(e),
                )
                raise

        return tool_wrapper

    return [_create_tool_wrapper(server_tool) for server_tool in discovered]
580
+
581
+
582
+ def create_mcp_tools(
583
+ function: McpFunctionModel,
584
+ ) -> Sequence[RunnableLike]:
585
+ """
586
+ Sync wrapper: Create executable LangChain tools for invoking Databricks MCP functions.
587
+
588
+ For async contexts, use acreate_mcp_tools() directly.
589
+
590
+ Args:
591
+ function: The MCP function model configuration.
592
+
593
+ Returns:
594
+ A sequence of LangChain tools that can be used by agents.
595
+
596
+ Raises:
597
+ RuntimeError: If connection to MCP server fails.
598
+ """
599
+ mcp_url = function.mcp_url
600
+ logger.debug("Creating MCP tools", mcp_url=mcp_url)
601
+
602
+ # Fetch and filter tools using shared logic
603
+ # We need the raw Tool objects here, not MCPToolInfo
604
+ mcp_tools: list[Tool] = _fetch_tools_from_server(function)
605
+
606
+ # Log discovered tools
607
+ logger.info(
608
+ "Discovered MCP tools from server",
609
+ tools_count=len(mcp_tools),
610
+ tool_names=[t.name for t in mcp_tools],
611
+ mcp_url=mcp_url,
612
+ )
613
+
614
+ # Apply filtering if configured
615
+ if function.include_tools or function.exclude_tools:
616
+ original_count = len(mcp_tools)
617
+ mcp_tools = [
618
+ tool
619
+ for tool in mcp_tools
620
+ if _should_include_tool(
621
+ tool.name,
622
+ function.include_tools,
623
+ function.exclude_tools,
624
+ )
625
+ ]
626
+ filtered_count = original_count - len(mcp_tools)
627
+
628
+ logger.info(
629
+ "Filtered MCP tools",
630
+ original_count=original_count,
631
+ filtered_count=filtered_count,
632
+ final_count=len(mcp_tools),
633
+ include_patterns=function.include_tools,
634
+ exclude_patterns=function.exclude_tools,
635
+ )
636
+
637
+ # Log final tool list
638
+ for mcp_tool in mcp_tools:
639
+ logger.debug(
640
+ "MCP tool available",
641
+ tool_name=mcp_tool.name,
642
+ tool_description=(
643
+ mcp_tool.description[:100] if mcp_tool.description else None
644
+ ),
645
+ )
646
+
213
647
  def _create_tool_wrapper(mcp_tool: Tool) -> RunnableLike:
214
648
  """
215
649
  Create a LangChain tool wrapper for an MCP tool.
@@ -217,20 +651,28 @@ def create_mcp_tools(
217
651
  This wrapper handles:
218
652
  - Fresh session creation per invocation (stateless)
219
653
  - Content extraction to plain text (avoiding extra fields)
654
+ - OBO authentication via context headers
220
655
  """
656
+ from langchain.tools import ToolRuntime
221
657
 
222
658
  @create_tool(
223
659
  mcp_tool.name,
224
660
  description=mcp_tool.description or f"MCP tool: {mcp_tool.name}",
225
661
  args_schema=mcp_tool.inputSchema,
226
662
  )
227
- async def tool_wrapper(**kwargs: Any) -> str:
663
+ async def tool_wrapper(
664
+ runtime: ToolRuntime[Context] = None,
665
+ **kwargs: Any,
666
+ ) -> str:
228
667
  """Execute MCP tool with fresh session."""
229
668
  logger.trace("Invoking MCP tool", tool_name=mcp_tool.name, args=kwargs)
230
669
 
231
- # Create a fresh client/session for each invocation
232
- invocation_client = MultiServerMCPClient(
233
- {"mcp_function": _build_connection_config(function)}
670
+ # Get context for OBO support
671
+ context: Context | None = runtime.context if runtime else None
672
+
673
+ # Create a fresh client/session for each invocation with OBO support
674
+ invocation_client: MultiServerMCPClient = MultiServerMCPClient(
675
+ {"mcp_function": _build_connection_config(function, context)}
234
676
  )
235
677
 
236
678
  try:
@@ -240,7 +682,7 @@ def create_mcp_tools(
240
682
  )
241
683
 
242
684
  # Extract text content, avoiding extra fields
243
- text_result = _extract_text_content(result)
685
+ text_result: str = _extract_text_content(result)
244
686
 
245
687
  logger.trace(
246
688
  "MCP tool completed",