datarobot-genai 0.2.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_genai/__init__.py +19 -0
- datarobot_genai/core/__init__.py +0 -0
- datarobot_genai/core/agents/__init__.py +43 -0
- datarobot_genai/core/agents/base.py +195 -0
- datarobot_genai/core/chat/__init__.py +19 -0
- datarobot_genai/core/chat/auth.py +146 -0
- datarobot_genai/core/chat/client.py +178 -0
- datarobot_genai/core/chat/responses.py +297 -0
- datarobot_genai/core/cli/__init__.py +18 -0
- datarobot_genai/core/cli/agent_environment.py +47 -0
- datarobot_genai/core/cli/agent_kernel.py +211 -0
- datarobot_genai/core/custom_model.py +141 -0
- datarobot_genai/core/mcp/__init__.py +0 -0
- datarobot_genai/core/mcp/common.py +218 -0
- datarobot_genai/core/telemetry_agent.py +126 -0
- datarobot_genai/core/utils/__init__.py +3 -0
- datarobot_genai/core/utils/auth.py +234 -0
- datarobot_genai/core/utils/urls.py +64 -0
- datarobot_genai/crewai/__init__.py +24 -0
- datarobot_genai/crewai/agent.py +42 -0
- datarobot_genai/crewai/base.py +159 -0
- datarobot_genai/crewai/events.py +117 -0
- datarobot_genai/crewai/mcp.py +59 -0
- datarobot_genai/drmcp/__init__.py +78 -0
- datarobot_genai/drmcp/core/__init__.py +13 -0
- datarobot_genai/drmcp/core/auth.py +165 -0
- datarobot_genai/drmcp/core/clients.py +180 -0
- datarobot_genai/drmcp/core/config.py +364 -0
- datarobot_genai/drmcp/core/config_utils.py +174 -0
- datarobot_genai/drmcp/core/constants.py +18 -0
- datarobot_genai/drmcp/core/credentials.py +190 -0
- datarobot_genai/drmcp/core/dr_mcp_server.py +350 -0
- datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
- datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
- datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
- datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +70 -0
- datarobot_genai/drmcp/core/dynamic_prompts/register.py +205 -0
- datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
- datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
- datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
- datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
- datarobot_genai/drmcp/core/exceptions.py +25 -0
- datarobot_genai/drmcp/core/logging.py +98 -0
- datarobot_genai/drmcp/core/mcp_instance.py +515 -0
- datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
- datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
- datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
- datarobot_genai/drmcp/core/routes.py +439 -0
- datarobot_genai/drmcp/core/routes_utils.py +30 -0
- datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
- datarobot_genai/drmcp/core/telemetry.py +424 -0
- datarobot_genai/drmcp/core/tool_config.py +111 -0
- datarobot_genai/drmcp/core/tool_filter.py +117 -0
- datarobot_genai/drmcp/core/utils.py +138 -0
- datarobot_genai/drmcp/server.py +19 -0
- datarobot_genai/drmcp/test_utils/__init__.py +13 -0
- datarobot_genai/drmcp/test_utils/clients/__init__.py +0 -0
- datarobot_genai/drmcp/test_utils/clients/anthropic.py +68 -0
- datarobot_genai/drmcp/test_utils/clients/base.py +300 -0
- datarobot_genai/drmcp/test_utils/clients/dr_gateway.py +58 -0
- datarobot_genai/drmcp/test_utils/clients/openai.py +68 -0
- datarobot_genai/drmcp/test_utils/elicitation_test_tool.py +89 -0
- datarobot_genai/drmcp/test_utils/integration_mcp_server.py +109 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +133 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +107 -0
- datarobot_genai/drmcp/test_utils/test_interactive.py +205 -0
- datarobot_genai/drmcp/test_utils/tool_base_ete.py +220 -0
- datarobot_genai/drmcp/test_utils/utils.py +91 -0
- datarobot_genai/drmcp/tools/__init__.py +14 -0
- datarobot_genai/drmcp/tools/clients/__init__.py +14 -0
- datarobot_genai/drmcp/tools/clients/atlassian.py +188 -0
- datarobot_genai/drmcp/tools/clients/confluence.py +584 -0
- datarobot_genai/drmcp/tools/clients/gdrive.py +832 -0
- datarobot_genai/drmcp/tools/clients/jira.py +334 -0
- datarobot_genai/drmcp/tools/clients/microsoft_graph.py +479 -0
- datarobot_genai/drmcp/tools/clients/s3.py +28 -0
- datarobot_genai/drmcp/tools/confluence/__init__.py +14 -0
- datarobot_genai/drmcp/tools/confluence/tools.py +321 -0
- datarobot_genai/drmcp/tools/gdrive/__init__.py +0 -0
- datarobot_genai/drmcp/tools/gdrive/tools.py +347 -0
- datarobot_genai/drmcp/tools/jira/__init__.py +14 -0
- datarobot_genai/drmcp/tools/jira/tools.py +243 -0
- datarobot_genai/drmcp/tools/microsoft_graph/__init__.py +13 -0
- datarobot_genai/drmcp/tools/microsoft_graph/tools.py +198 -0
- datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
- datarobot_genai/drmcp/tools/predictive/data.py +133 -0
- datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
- datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
- datarobot_genai/drmcp/tools/predictive/model.py +148 -0
- datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
- datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
- datarobot_genai/drmcp/tools/predictive/project.py +90 -0
- datarobot_genai/drmcp/tools/predictive/training.py +661 -0
- datarobot_genai/langgraph/__init__.py +0 -0
- datarobot_genai/langgraph/agent.py +341 -0
- datarobot_genai/langgraph/mcp.py +73 -0
- datarobot_genai/llama_index/__init__.py +16 -0
- datarobot_genai/llama_index/agent.py +50 -0
- datarobot_genai/llama_index/base.py +299 -0
- datarobot_genai/llama_index/mcp.py +79 -0
- datarobot_genai/nat/__init__.py +0 -0
- datarobot_genai/nat/agent.py +275 -0
- datarobot_genai/nat/datarobot_auth_provider.py +110 -0
- datarobot_genai/nat/datarobot_llm_clients.py +318 -0
- datarobot_genai/nat/datarobot_llm_providers.py +130 -0
- datarobot_genai/nat/datarobot_mcp_client.py +266 -0
- datarobot_genai/nat/helpers.py +87 -0
- datarobot_genai/py.typed +0 -0
- datarobot_genai-0.2.31.dist-info/METADATA +145 -0
- datarobot_genai-0.2.31.dist-info/RECORD +125 -0
- datarobot_genai-0.2.31.dist-info/WHEEL +4 -0
- datarobot_genai-0.2.31.dist-info/entry_points.txt +5 -0
- datarobot_genai-0.2.31.dist-info/licenses/AUTHORS +2 -0
- datarobot_genai-0.2.31.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,515 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
from collections.abc import Callable
|
|
17
|
+
from functools import wraps
|
|
18
|
+
from typing import Any
|
|
19
|
+
from typing import TypedDict
|
|
20
|
+
|
|
21
|
+
from fastmcp import Context
|
|
22
|
+
from fastmcp import FastMCP
|
|
23
|
+
from fastmcp.exceptions import NotFoundError
|
|
24
|
+
from fastmcp.prompts.prompt import Prompt
|
|
25
|
+
from fastmcp.server.dependencies import get_context
|
|
26
|
+
from fastmcp.tools import Tool
|
|
27
|
+
from mcp.types import AnyFunction
|
|
28
|
+
from mcp.types import Tool as MCPTool
|
|
29
|
+
from mcp.types import ToolAnnotations
|
|
30
|
+
from typing_extensions import Unpack
|
|
31
|
+
|
|
32
|
+
from .config import MCPServerConfig
|
|
33
|
+
from .config import get_config
|
|
34
|
+
from .dynamic_prompts.utils import get_prompt_name_no_duplicate
|
|
35
|
+
from .logging import log_execution
|
|
36
|
+
from .memory_management.manager import MemoryManager
|
|
37
|
+
from .memory_management.manager import get_memory_manager
|
|
38
|
+
from .telemetry import trace_execution
|
|
39
|
+
from .tool_filter import filter_tools_by_tags
|
|
40
|
+
from .tool_filter import list_all_tags
|
|
41
|
+
|
|
42
|
+
# Module-level logger named after this module, per the standard logging convention.
logger = logging.getLogger(__name__)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
async def get_agent_and_storage_ids(
    args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[str | None, str | None]:
    """
    Extract agent ID from request context and get corresponding storage ID.

    The FastMCP ``Context`` is located, in order of preference, as:

    1. the first positional argument that is a ``Context`` instance;
    2. the ``"ctx"`` keyword argument (the conventional parameter name);
    3. any other keyword argument whose value is a ``Context`` instance, so tools
       that name their context parameter differently (e.g. ``context``) are
       supported too.

    Args:
        args: Positional arguments that may contain a Context object
        kwargs: Keyword arguments that may contain a Context object

    Returns
    -------
        Tuple of (agent_id, storage_id). ``agent_id`` is the ``x-agent-id``
        request header, or None when there is no context/request/header.
        ``storage_id`` is None unless an agent ID was found and the memory
        manager is initialized; otherwise it is whatever
        ``get_active_storage_id_for_agent`` returns.
    """
    # Find the context argument if it exists (same precedence as before, plus a
    # fallback that scans the remaining keyword arguments by type).
    ctx = next((arg for arg in args if isinstance(arg, Context)), None)
    if ctx is None:
        ctx = kwargs.get("ctx")
    if ctx is None:
        ctx = next((value for value in kwargs.values() if isinstance(value, Context)), None)

    # Extract X-Agent-Id if context and headers exist
    agent_id = None
    if (
        ctx
        and ctx.request_context
        and ctx.request_context.request
        and hasattr(ctx.request_context.request, "headers")
    ):
        headers = ctx.request_context.request.headers
        agent_id = headers.get("x-agent-id")

    # If agent_id was found, get the active storage_id
    storage_id = None
    if agent_id and MemoryManager.is_initialized():
        memory_manager = get_memory_manager()
        if memory_manager:
            storage_id = await memory_manager.get_active_storage_id_for_agent(agent_id)

    return agent_id, storage_id
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class TaggedFastMCP(FastMCP):
    """Extended FastMCP that supports tags, deployments and other annotations directly in the
    tool decorator.

    On top of FastMCP it keeps two in-memory registries, maintained by the dynamic
    tool and prompt registration helpers:

    * ``_deployments_map`` maps a deployment ID to the name of the tool serving it.
    * ``_prompts_map`` maps a prompt template ID to a
      ``(prompt template version ID, prompt name)`` tuple.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # deployment_id -> tool name
        self._deployments_map: dict[str, str] = {}
        # prompt_template_id -> (prompt_template_version_id, prompt name)
        self._prompts_map: dict[str, tuple[str, str]] = {}

    async def notify_prompts_changed(self) -> None:
        """
        Notify connected clients that the prompt list has changed.

        This method attempts to send a prompts/list_changed notification to inform
        clients that they should refresh their prompt list.

        Note: In stateless HTTP mode (default for this server), notifications may not
        reach clients since each request is independent. This method still logs the
        change for auditing purposes and will work if the server is configured for
        stateful connections.

        See: https://github.com/modelcontextprotocol/python-sdk/issues/710
        """
        logger.info("Prompt list changed - attempting to notify connected clients")

        # Try to use FastMCP's built-in notification mechanism if in an MCP context
        try:
            context = get_context()
            context._queue_prompt_list_changed()
            logger.debug("Queued prompts_changed notification via MCP context")
        except RuntimeError:
            # No active MCP context - this is expected when called from REST API
            logger.debug(
                "No active MCP context for notification. "
                "In stateless mode, clients will see changes on next request."
            )

    async def list_tools(
        self, tags: list[str] | None = None, match_all: bool = False
    ) -> list[MCPTool]:
        """
        List all available tools, optionally filtered by tags.

        Args:
            tags: Optional list of tags to filter by. If None or empty, returns all tools.
            match_all: If True, tool must have all specified tags (AND logic).
                      If False, tool must have at least one tag (OR logic).
                      Only used when tags is provided.

        Returns
        -------
            List of MCPTool objects that match the tag criteria.
        """
        # Get all tools from the parent class
        all_tools = await super()._list_tools_mcp()

        # If no tags specified, return all tools
        if not tags:
            return all_tools

        # Filter tools by tags
        filtered_tools = filter_tools_by_tags(list(all_tools), tags, match_all)

        return filtered_tools  # type: ignore[return-value]

    async def get_all_tags(self) -> list[str]:
        """
        Get all unique tags from all registered tools.

        Returns
        -------
            List of all unique tags sorted alphabetically.
        """
        all_tools = await self._list_tools_mcp()
        return list_all_tags(list(all_tools))

    async def get_deployment_mapping(self) -> dict[str, str]:
        """
        Get the deployment-ID-to-tool-name mapping for all registered dynamic tools.

        Returns
        -------
            A shallow copy of the dictionary mapping deployment IDs to tool names;
            mutating it does not affect the server's registry.
        """
        return self._deployments_map.copy()

    async def set_deployment_mapping(self, deployment_id: str, tool_name: str) -> None:
        """
        Add or update the mapping of a deployment ID to a tool name.

        If the deployment was previously mapped to a *different* tool name, that old
        tool is removed from the server (a missing tool is logged and ignored).

        Args:
            deployment_id: The ID of the deployment.
            tool_name: The name of the tool associated with the deployment.
        """
        existing = self._deployments_map.get(deployment_id)
        if existing and existing != tool_name:
            logger.debug(
                f"Deployment ID {deployment_id} already mapped to {existing}, updating to "
                f"{tool_name}"
            )
            try:
                self.remove_tool(existing)
            except NotFoundError:
                logger.debug(f"Tool {existing} not found in registry, skipping removal")
        self._deployments_map[deployment_id] = tool_name

    async def remove_deployment_mapping(self, deployment_id: str) -> None:
        """
        Remove the mapping of a deployment ID to a tool name.

        When a mapping existed, the mapped tool is also removed from the server
        (a missing tool is logged and ignored). Unknown deployment IDs are a no-op.

        Args:
            deployment_id: The ID of the deployment to remove.
        """
        removed = self._deployments_map.pop(deployment_id, None)
        if removed is not None:
            logger.debug(f"Removed deployment mapping for ID {deployment_id} with tool {removed}")
            try:
                self.remove_tool(removed)
            except NotFoundError:
                logger.debug(f"Tool {removed} not found in registry, skipping removal")

    async def get_prompt_mapping(self) -> dict[str, tuple[str, str]]:
        """
        Get the prompt-template mapping for all registered dynamic prompts.

        Returns
        -------
            A shallow copy of the dictionary mapping prompt template id to a
            ``(prompt template version id, prompt name)`` tuple.
        """
        return self._prompts_map.copy()

    async def set_prompt_mapping(
        self, prompt_template_id: str, prompt_template_version_id: str, prompt_name: str
    ) -> None:
        """
        Add or update the mapping of a prompt template ID to its version ID and prompt name.

        If the template is already mapped, the previous mapping is removed first via
        ``remove_prompt_mapping``, which disables the prompts registered for the
        previously mapped version.

        Args:
            prompt_template_id: The ID of the prompt template.
            prompt_template_version_id: The ID of the prompt template version.
            prompt_name: The prompt name associated with the prompt template id and version.
        """
        existing_prompt_template = self._prompts_map.get(prompt_template_id)

        if existing_prompt_template:
            existing_prompt_template_version_id, _ = existing_prompt_template

            logger.debug(
                f"Prompt template ID {prompt_template_id} "
                f"already mapped to {existing_prompt_template_version_id}. "
                f"Updating to version id = {prompt_template_version_id} and name = {prompt_name}"
            )
            await self.remove_prompt_mapping(
                prompt_template_id, existing_prompt_template_version_id
            )

        self._prompts_map[prompt_template_id] = (prompt_template_version_id, prompt_name)

    async def remove_prompt_mapping(
        self, prompt_template_id: str, prompt_template_version_id: str
    ) -> None:
        """
        Remove the mapping of a prompt_template ID to a version and prompt name.

        Prompts whose ``meta`` carries both this ``prompt_template_id`` and
        ``prompt_template_version_id`` are disabled (not deleted), the mapping entry
        is dropped and clients are notified. If the template is mapped to a
        different version, or is not mapped at all, nothing is changed.

        Args:
            prompt_template_id: The ID of the prompt template to remove.
            prompt_template_version_id: The ID of the prompt template version to remove.
        """
        if existing_prompt_template := self._prompts_map.get(prompt_template_id):
            existing_prompt_template_version_id, _ = existing_prompt_template
            if existing_prompt_template_version_id != prompt_template_version_id:
                logger.debug(
                    f"Found prompt template with id = {prompt_template_id} in registry, "
                    f"but with different version = {existing_prompt_template_version_id}, "
                    "skipping removal."
                )
            else:
                prompts_d = await self.get_prompts()
                for prompt in prompts_d.values():
                    if (
                        prompt.meta is not None
                        and prompt.meta.get("prompt_template_id", "") == prompt_template_id
                        and prompt.meta.get("prompt_template_version_id", "")
                        == prompt_template_version_id
                    ):
                        prompt.disable()

                self._prompts_map.pop(prompt_template_id, None)

                # Notify clients that the prompt list has changed
                await self.notify_prompts_changed()
        else:
            logger.debug(
                f"Do not found prompt template with id = {prompt_template_id} in registry, "
                "skipping removal."
            )
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
# Create the tagged MCP instance shared by the decorators and registration helpers
# in this module. Its name and the duplicate-registration behaviour for tools and
# prompts all come from the server configuration.
mcp_server_configs: MCPServerConfig = get_config()

mcp = TaggedFastMCP(
    on_duplicate_prompts=mcp_server_configs.prompt_registration_duplicate_behavior,
    on_duplicate_tools=mcp_server_configs.tool_registration_duplicate_behavior,
    name=mcp_server_configs.mcp_server_name,
)
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
class ToolKwargs(TypedDict, total=False):
    """Keyword arguments passed through to FastMCP's mcp.tool() decorator.

    All parameters are optional and forwarded directly to FastMCP tool registration.
    ``total=False`` makes every key optional, so callers supply only the options
    they need; ``dr_core_mcp_tool`` and ``dr_mcp_tool`` forward the collected keys
    unchanged via ``mcp.tool(**kwargs)``.
    See FastMCP documentation for full details on each parameter.
    """

    name: str | None
    title: str | None
    description: str | None
    icons: list[Any] | None
    tags: set[str] | None
    output_schema: dict[str, Any] | None
    # Typed loosely here; FastMCP validates the actual value (presumably a
    # ToolAnnotations or dict) — TODO confirm against the FastMCP version in use.
    annotations: Any | None
    exclude_args: list[str] | None
    meta: dict[str, Any] | None
    enabled: bool | None
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def dr_core_mcp_tool(
    **kwargs: Unpack[ToolKwargs],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Combined decorator that applies dr_mcp_extras() and registers the result via mcp.tool().

    Unlike ``dr_mcp_tool``, no memory-ID injection wrapper is added: the decorated
    function is instrumented (log_execution + trace_execution) and registered as-is.

    All keyword arguments are passed through to FastMCP's mcp.tool() decorator.
    See ToolKwargs for available parameters.

    Returns
    -------
        A decorator that returns the instrumented function (not the FastMCP tool object).
    """

    def register(func: Callable[..., Any]) -> Callable[..., Any]:
        with_extras = dr_mcp_extras()(func)
        tool_decorator = mcp.tool(**kwargs)
        tool_decorator(with_extras)
        return with_extras

    return register
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def _accepts_keyword_argument(func: Callable[..., Any], name: str) -> bool:
    """Return True if ``func`` can be called with the keyword argument ``name``."""
    import inspect

    try:
        signature = inspect.signature(func)
    except (TypeError, ValueError):
        # Signature cannot be introspected: keep the historical behaviour and
        # pass the argument through.
        return True

    for param in signature.parameters.values():
        if param.kind is inspect.Parameter.VAR_KEYWORD:
            return True
        if param.name == name and param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


async def memory_aware_wrapper(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """
    Add memory management capabilities to any async function.

    Extracts agent and storage IDs from the context and, when both are found, adds
    them to kwargs as ``agent_id`` / ``storage_id``. Each ID is only injected if
    ``func``'s signature accepts it (an explicit parameter of that name or
    ``**kwargs``). Tools that do not use memory are therefore called unchanged,
    instead of failing with an unexpected-keyword ``TypeError``.

    Args:
        func: The async function to wrap
        *args: Positional arguments to pass to the function
        **kwargs: Keyword arguments to pass to the function

    Returns
    -------
        The result of calling the wrapped function
    """
    # Get agent and storage IDs from context
    agent_id, storage_id = await get_agent_and_storage_ids(args, kwargs)

    # Add IDs to kwargs if found and accepted by the target function
    if agent_id and storage_id:
        if _accepts_keyword_argument(func, "agent_id"):
            kwargs["agent_id"] = agent_id
        if _accepts_keyword_argument(func, "storage_id"):
            kwargs["storage_id"] = storage_id

    # Call the original function
    return await func(*args, **kwargs)
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def dr_mcp_tool(
    **kwargs: Unpack[ToolKwargs],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Combined decorator: memory-ID capture, dr_mcp_extras() instrumentation and mcp.tool().

    The decorated async function is wrapped so that each call goes through
    ``memory_aware_wrapper``, which reads memory IDs from the request headers when
    present. The wrapper is then instrumented with ``dr_mcp_extras()`` and
    registered with FastMCP.

    All keyword arguments are passed through to FastMCP's mcp.tool() decorator.
    See ToolKwargs for available parameters.

    Returns
    -------
        A decorator that returns the instrumented wrapper (not the FastMCP tool object).
    """

    def register(func: Callable[..., Any]) -> Callable[..., Any]:
        async def call_with_memory(*call_args: Any, **call_kwargs: Any) -> Any:
            return await memory_aware_wrapper(func, *call_args, **call_kwargs)

        # Copy func's metadata (name, docstring, __wrapped__) onto the wrapper so
        # signature introspection sees ``func`` rather than ``*args, **kwargs``.
        call_with_memory = wraps(func)(call_with_memory)

        instrumented = dr_mcp_extras()(call_with_memory)
        mcp.tool(**kwargs)(instrumented)
        return instrumented

    return register
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def dr_mcp_extras(
    type: str = "tool",
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Combined decorator that applies trace_execution() and then log_execution().

    The function is first wrapped by ``trace_execution(trace_type=type)``; the traced
    result is then wrapped by ``log_execution``, making logging the outermost layer.

    Args:
        type: trace type forwarded as ``trace_type``; default is "tool", other
            options are "prompt", "resource"
    """
    # The public parameter is named ``type`` (shadowing the builtin); bind it to a
    # clearer local name for use inside the closure.
    trace_type = type

    def apply(func: Callable[..., Any]) -> Callable[..., Any]:
        traced = trace_execution(trace_type=trace_type)(func)
        return log_execution(traced)

    return apply
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
async def register_tools(
    fn: AnyFunction,
    name: str | None = None,
    title: str | None = None,
    description: str | None = None,
    tags: set[str] | None = None,
    deployment_id: str | None = None,
) -> Tool:
    """
    Register a new tool on the shared ``mcp`` server after it has started.

    The function is wrapped so that calls go through ``memory_aware_wrapper`` and
    is instrumented with ``dr_mcp_extras()`` before being added as a tool. When a
    deployment ID is given, it is attached as an extra ``deployment_id`` attribute
    on the tool's annotations. A truthy deployment ID is also recorded in the
    server's deployment mapping.

    Args:
        fn: The function to register as a tool
        name: Optional name for the tool (defaults to function name)
        title: Optional human-readable title for the tool
        description: Optional description of what the tool does
        tags: Optional set of tags to apply to the tool
        deployment_id: Optional deployment ID associated with the tool

    Returns
    -------
        The registered Tool object

    Raises
    ------
        RuntimeError: If the tool cannot be found in the server's tool list after
            registration.
    """
    tool_name = name or fn.__name__
    logger.info(f"Registering new tool: {tool_name}")

    async def call_with_memory(*args: Any, **kwargs: Any) -> Any:
        return await memory_aware_wrapper(fn, *args, **kwargs)

    # Keep fn's name/docstring/__wrapped__ on the memory-aware wrapper.
    call_with_memory = wraps(fn)(call_with_memory)
    instrumented_fn = dr_mcp_extras()(call_with_memory)

    # Annotations are only needed to carry the deployment ID.
    tool_annotations: ToolAnnotations | None = None  # type: ignore[assignment]
    if deployment_id is not None:
        tool_annotations = ToolAnnotations()  # type: ignore[call-arg]
        tool_annotations.deployment_id = deployment_id  # type: ignore[attr-defined]

    new_tool = Tool.from_function(
        fn=instrumented_fn,
        name=tool_name,
        title=title,
        description=description,
        annotations=tool_annotations,
        tags=tags,
    )
    registered_tool = mcp.add_tool(new_tool)

    if deployment_id:
        await mcp.set_deployment_mapping(deployment_id, tool_name)

    # Sanity check: the new tool must now be listed by the server.
    listed_tools = await mcp.list_tools()
    listed_names = {listed.name for listed in listed_tools}
    if tool_name not in listed_names:
        raise RuntimeError(f"Tool {tool_name} was not registered successfully")
    logger.info(f"Registered tools: {len(listed_tools)}")

    return registered_tool
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
async def register_prompt(
    fn: AnyFunction,
    name: str | None = None,
    title: str | None = None,
    description: str | None = None,
    tags: set[str] | None = None,
    meta: dict[str, Any] | None = None,
    prompt_template: tuple[str, str] | None = None,
) -> Prompt:
    """
    Register a new prompt on the shared ``mcp`` server after it has started.

    The function is instrumented with ``dr_mcp_extras(type="prompt")``. The final
    prompt name comes from ``get_prompt_name_no_duplicate`` (presumably to avoid
    clashing with existing prompts). After registration, clients are notified that
    the prompt list changed.

    Args:
        fn: The function to register as a prompt
        name: Optional name for the prompt (defaults to function name)
        title: Optional human-readable title for the prompt
        description: Optional description of what the prompt does
        tags: Optional set of tags to apply to the prompt
        meta: Optional dict of metadata to apply to the prompt
        prompt_template: Optional (id, version id) of the prompt template

    Returns
    -------
        The registered Prompt object

    Raises
    ------
        RuntimeError: If the prompt cannot be found among the server's prompts after
            registration.
    """
    requested_name = name or fn.__name__
    logger.info(f"Registering new prompt: {requested_name}")
    instrumented_fn = dr_mcp_extras(type="prompt")(fn)

    unique_name = await get_prompt_name_no_duplicate(mcp, requested_name)

    new_prompt = Prompt.from_function(
        fn=instrumented_fn,
        name=unique_name,
        title=title,
        description=description,
        tags=tags,
        meta=meta,
    )

    # Update the template mapping BEFORE adding the prompt. set_prompt_mapping may
    # disable prompts of the previously mapped version (matched via their meta), and
    # the new prompt must not be among them.
    if prompt_template:
        template_id, template_version_id = prompt_template
        await mcp.set_prompt_mapping(template_id, template_version_id, unique_name)

    registered_prompt = mcp.add_prompt(new_prompt)

    # Sanity check: the new prompt must now be known to the server.
    all_prompts = await mcp.get_prompts()
    known_names = {existing.name for existing in all_prompts.values()}
    if unique_name not in known_names:
        raise RuntimeError(f"Prompt {unique_name} was not registered successfully")
    logger.info(f"Registered prompts: {len(all_prompts)}")

    await mcp.notify_prompts_changed()

    return registered_prompt
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|