datarobot-genai 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_genai/__init__.py +19 -0
- datarobot_genai/core/__init__.py +0 -0
- datarobot_genai/core/agents/__init__.py +43 -0
- datarobot_genai/core/agents/base.py +195 -0
- datarobot_genai/core/chat/__init__.py +19 -0
- datarobot_genai/core/chat/auth.py +146 -0
- datarobot_genai/core/chat/client.py +178 -0
- datarobot_genai/core/chat/responses.py +297 -0
- datarobot_genai/core/cli/__init__.py +18 -0
- datarobot_genai/core/cli/agent_environment.py +47 -0
- datarobot_genai/core/cli/agent_kernel.py +211 -0
- datarobot_genai/core/custom_model.py +141 -0
- datarobot_genai/core/mcp/__init__.py +0 -0
- datarobot_genai/core/mcp/common.py +218 -0
- datarobot_genai/core/telemetry_agent.py +126 -0
- datarobot_genai/core/utils/__init__.py +3 -0
- datarobot_genai/core/utils/auth.py +234 -0
- datarobot_genai/core/utils/urls.py +64 -0
- datarobot_genai/crewai/__init__.py +24 -0
- datarobot_genai/crewai/agent.py +42 -0
- datarobot_genai/crewai/base.py +159 -0
- datarobot_genai/crewai/events.py +117 -0
- datarobot_genai/crewai/mcp.py +59 -0
- datarobot_genai/drmcp/__init__.py +78 -0
- datarobot_genai/drmcp/core/__init__.py +13 -0
- datarobot_genai/drmcp/core/auth.py +165 -0
- datarobot_genai/drmcp/core/clients.py +180 -0
- datarobot_genai/drmcp/core/config.py +250 -0
- datarobot_genai/drmcp/core/config_utils.py +174 -0
- datarobot_genai/drmcp/core/constants.py +18 -0
- datarobot_genai/drmcp/core/credentials.py +190 -0
- datarobot_genai/drmcp/core/dr_mcp_server.py +316 -0
- datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
- datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
- datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
- datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +128 -0
- datarobot_genai/drmcp/core/dynamic_prompts/register.py +206 -0
- datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
- datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
- datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
- datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
- datarobot_genai/drmcp/core/exceptions.py +25 -0
- datarobot_genai/drmcp/core/logging.py +98 -0
- datarobot_genai/drmcp/core/mcp_instance.py +542 -0
- datarobot_genai/drmcp/core/mcp_server_tools.py +129 -0
- datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
- datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
- datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
- datarobot_genai/drmcp/core/routes.py +436 -0
- datarobot_genai/drmcp/core/routes_utils.py +30 -0
- datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
- datarobot_genai/drmcp/core/telemetry.py +424 -0
- datarobot_genai/drmcp/core/tool_filter.py +108 -0
- datarobot_genai/drmcp/core/utils.py +131 -0
- datarobot_genai/drmcp/server.py +19 -0
- datarobot_genai/drmcp/test_utils/__init__.py +13 -0
- datarobot_genai/drmcp/test_utils/integration_mcp_server.py +102 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +96 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +94 -0
- datarobot_genai/drmcp/test_utils/openai_llm_mcp_client.py +234 -0
- datarobot_genai/drmcp/test_utils/tool_base_ete.py +151 -0
- datarobot_genai/drmcp/test_utils/utils.py +91 -0
- datarobot_genai/drmcp/tools/__init__.py +14 -0
- datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
- datarobot_genai/drmcp/tools/predictive/data.py +97 -0
- datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
- datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
- datarobot_genai/drmcp/tools/predictive/model.py +148 -0
- datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
- datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
- datarobot_genai/drmcp/tools/predictive/project.py +72 -0
- datarobot_genai/drmcp/tools/predictive/training.py +651 -0
- datarobot_genai/langgraph/__init__.py +0 -0
- datarobot_genai/langgraph/agent.py +341 -0
- datarobot_genai/langgraph/mcp.py +73 -0
- datarobot_genai/llama_index/__init__.py +16 -0
- datarobot_genai/llama_index/agent.py +50 -0
- datarobot_genai/llama_index/base.py +299 -0
- datarobot_genai/llama_index/mcp.py +79 -0
- datarobot_genai/nat/__init__.py +0 -0
- datarobot_genai/nat/agent.py +258 -0
- datarobot_genai/nat/datarobot_llm_clients.py +249 -0
- datarobot_genai/nat/datarobot_llm_providers.py +130 -0
- datarobot_genai/py.typed +0 -0
- datarobot_genai-0.2.0.dist-info/METADATA +139 -0
- datarobot_genai-0.2.0.dist-info/RECORD +101 -0
- datarobot_genai-0.2.0.dist-info/WHEEL +4 -0
- datarobot_genai-0.2.0.dist-info/entry_points.txt +3 -0
- datarobot_genai-0.2.0.dist-info/licenses/AUTHORS +2 -0
- datarobot_genai-0.2.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,542 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
from collections.abc import Callable
|
|
17
|
+
from functools import wraps
|
|
18
|
+
from typing import Any
|
|
19
|
+
from typing import overload
|
|
20
|
+
|
|
21
|
+
from fastmcp import Context
|
|
22
|
+
from fastmcp import FastMCP
|
|
23
|
+
from fastmcp.exceptions import NotFoundError
|
|
24
|
+
from fastmcp.prompts.prompt import Prompt
|
|
25
|
+
from fastmcp.tools import FunctionTool
|
|
26
|
+
from fastmcp.tools import Tool
|
|
27
|
+
from fastmcp.utilities.types import NotSet
|
|
28
|
+
from fastmcp.utilities.types import NotSetT
|
|
29
|
+
from mcp.types import AnyFunction
|
|
30
|
+
from mcp.types import Tool as MCPTool
|
|
31
|
+
from mcp.types import ToolAnnotations
|
|
32
|
+
|
|
33
|
+
from .config import MCPServerConfig
|
|
34
|
+
from .config import get_config
|
|
35
|
+
from .dynamic_prompts.utils import get_prompt_name_no_duplicate
|
|
36
|
+
from .logging import log_execution
|
|
37
|
+
from .memory_management.manager import MemoryManager
|
|
38
|
+
from .memory_management.manager import get_memory_manager
|
|
39
|
+
from .telemetry import trace_execution
|
|
40
|
+
from .tool_filter import filter_tools_by_tags
|
|
41
|
+
from .tool_filter import list_all_tags
|
|
42
|
+
|
|
43
|
+
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
async def get_agent_and_storage_ids(
|
|
47
|
+
args: tuple[Any, ...], kwargs: dict[str, Any]
|
|
48
|
+
) -> tuple[str | None, str | None]:
|
|
49
|
+
"""
|
|
50
|
+
Extract agent ID from request context and get corresponding storage ID.
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
args: Positional arguments that may contain a Context object
|
|
54
|
+
kwargs: Keyword arguments that may contain a Context object
|
|
55
|
+
|
|
56
|
+
Returns
|
|
57
|
+
-------
|
|
58
|
+
Tuple of (agent_id, storage_id), both may be None if not found
|
|
59
|
+
"""
|
|
60
|
+
# Find the context argument if it exists
|
|
61
|
+
ctx = next((arg for arg in args if isinstance(arg, Context)), kwargs.get("ctx"))
|
|
62
|
+
|
|
63
|
+
# Extract X-Agent-Id if context and headers exist
|
|
64
|
+
agent_id = None
|
|
65
|
+
if (
|
|
66
|
+
ctx
|
|
67
|
+
and ctx.request_context
|
|
68
|
+
and ctx.request_context.request
|
|
69
|
+
and hasattr(ctx.request_context.request, "headers")
|
|
70
|
+
):
|
|
71
|
+
headers = ctx.request_context.request.headers
|
|
72
|
+
agent_id = headers.get("x-agent-id")
|
|
73
|
+
|
|
74
|
+
# If agent_id was found, get the active storage_id
|
|
75
|
+
storage_id = None
|
|
76
|
+
if agent_id and MemoryManager.is_initialized():
|
|
77
|
+
memory_manager = get_memory_manager()
|
|
78
|
+
if memory_manager:
|
|
79
|
+
storage_id = await memory_manager.get_active_storage_id_for_agent(agent_id)
|
|
80
|
+
|
|
81
|
+
return agent_id, storage_id
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class TaggedFastMCP(FastMCP):
    """Extended FastMCP that supports tags, deployments and other annotations directly in the
    tool decorator.

    On top of FastMCP it keeps two in-memory registries:

    * ``_deployments_map``: deployment ID -> name of the tool registered for it.
    * ``_prompts_map``: prompt template ID -> (prompt template version ID, prompt name).
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # deployment ID -> tool name
        self._deployments_map: dict[str, str] = {}
        # prompt template ID -> (prompt template version ID, prompt name)
        self._prompts_map: dict[str, tuple[str, str]] = {}

    @overload
    def tool(
        self,
        name_or_fn: AnyFunction,
        *,
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        tags: set[str] | None = None,
        output_schema: dict[str, Any] | None | NotSetT = NotSet,
        annotations: ToolAnnotations | dict[str, Any] | None = None,
        exclude_args: list[str] | None = None,
        meta: dict[str, Any] | None = None,
        enabled: bool | None = None,
    ) -> FunctionTool: ...

    @overload
    def tool(
        self,
        name_or_fn: str | None = None,
        *,
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        tags: set[str] | None = None,
        output_schema: dict[str, Any] | None | NotSetT = NotSet,
        annotations: ToolAnnotations | dict[str, Any] | None = None,
        exclude_args: list[str] | None = None,
        meta: dict[str, Any] | None = None,
        enabled: bool | None = None,
    ) -> Callable[[AnyFunction], FunctionTool]: ...

    def tool(
        self,
        name_or_fn: str | Callable[..., Any] | None = None,
        *,
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        tags: set[str] | None = None,
        output_schema: dict[str, Any] | None | NotSetT = NotSet,
        annotations: ToolAnnotations | dict[str, Any] | None = None,
        exclude_args: list[str] | None = None,
        meta: dict[str, Any] | None = None,
        enabled: bool | None = None,
        **kwargs: Any,
    ) -> Callable[[AnyFunction], FunctionTool] | FunctionTool:
        """
        Extend tool decorator that supports tags and other annotations, while remaining
        signature-compatible with FastMCP.tool to avoid recursion issues with partials.

        ``annotations`` may be passed as a dict, in which case it is converted to
        ``ToolAnnotations``. When ``tags`` is given, a sorted list of the tags is also
        stored on the annotations as ``annotations.tags``. Any additional keyword
        arguments are accepted for compatibility but are not forwarded.
        """
        if kwargs:
            logger.debug(f"Ignoring unsupported tool() keyword arguments: {sorted(kwargs)}")

        if isinstance(annotations, dict):
            annotations = ToolAnnotations(**annotations)

        # Ensure tags are available both via native fastmcp `tags` and inside annotations
        if tags is not None:
            if annotations is None:
                annotations = ToolAnnotations()  # type: ignore[call-arg]
            annotations.tags = sorted(tags)  # type: ignore[attr-defined, union-attr]

        return super().tool(
            name_or_fn,
            name=name,
            title=title,
            description=description,
            tags=tags,
            output_schema=output_schema,
            annotations=annotations,
            exclude_args=exclude_args,
            meta=meta,
            enabled=enabled,
        )

    async def list_tools(
        self, tags: list[str] | None = None, match_all: bool = False
    ) -> list[MCPTool]:
        """
        List all available tools, optionally filtered by tags.

        Args:
            tags: Optional list of tags to filter by. If None, returns all tools.
            match_all: If True, tool must have all specified tags (AND logic).
                      If False, tool must have at least one tag (OR logic).
                      Only used when tags is provided.

        Returns
        -------
            List of MCPTool objects that match the tag criteria.
        """
        # Get all tools from the parent class
        all_tools = await super()._list_tools_mcp()

        # If no tags specified, return all tools
        if not tags:
            return all_tools

        # Filter tools by tags
        return filter_tools_by_tags(list(all_tools), tags, match_all)  # type: ignore[return-value]

    async def get_all_tags(self) -> list[str]:
        """
        Get all unique tags from all registered tools.

        Returns
        -------
            List of all unique tags sorted alphabetically.
        """
        all_tools = await self._list_tools_mcp()
        return list_all_tags(list(all_tools))

    async def get_deployment_mapping(self) -> dict[str, str]:
        """
        Get a copy of the deployment mapping for all registered dynamic tools.

        Returns
        -------
            Dictionary mapping deployment IDs to tool names.
        """
        return self._deployments_map.copy()

    async def set_deployment_mapping(self, deployment_id: str, tool_name: str) -> None:
        """
        Add or update the mapping of a deployment ID to a tool name.

        If the deployment was already mapped to a different tool, that previously
        registered tool is removed from the server (if it is still registered).

        Args:
            deployment_id: The ID of the deployment.
            tool_name: The name of the tool associated with the deployment.
        """
        existing = self._deployments_map.get(deployment_id)
        if existing and existing != tool_name:
            logger.debug(
                f"Deployment ID {deployment_id} already mapped to {existing}, updating to "
                f"{tool_name}"
            )
            try:
                self.remove_tool(existing)
            except NotFoundError:
                logger.debug(f"Tool {existing} not found in registry, skipping removal")
        self._deployments_map[deployment_id] = tool_name

    async def remove_deployment_mapping(self, deployment_id: str) -> None:
        """
        Remove the mapping of a deployment ID to a tool name.

        The tool that was mapped to the deployment is also removed from the server
        (if it is still registered).

        Args:
            deployment_id: The ID of the deployment to remove.
        """
        removed = self._deployments_map.pop(deployment_id, None)
        if removed is not None:
            logger.debug(f"Removed deployment mapping for ID {deployment_id} with tool {removed}")
            try:
                self.remove_tool(removed)
            except NotFoundError:
                logger.debug(f"Tool {removed} not found in registry, skipping removal")

    async def get_prompt_mapping(self) -> dict[str, tuple[str, str]]:
        """
        Get a copy of the prompt mapping for all registered dynamic prompts.

        Returns
        -------
            Dictionary mapping prompt template id to prompt template version id and name
        """
        return self._prompts_map.copy()

    async def set_prompt_mapping(
        self, prompt_template_id: str, prompt_template_version_id: str, prompt_name: str
    ) -> None:
        """
        Add or update the mapping of a prompt template ID to its version ID and prompt name.

        Args:
            prompt_template_id: The ID of the prompt template.
            prompt_template_version_id: The ID of the prompt template version.
            prompt_name: The prompt name associated with the prompt template id and version.
        """
        existing_prompt_template = self._prompts_map.get(prompt_template_id)

        if existing_prompt_template:
            existing_prompt_template_version_id, _ = existing_prompt_template

            logger.debug(
                f"Prompt template ID {prompt_template_id} "
                f"already mapped to {existing_prompt_template_version_id}. "
                f"Updating to version id = {prompt_template_version_id} and name = {prompt_name}"
            )

        self._prompts_map[prompt_template_id] = (prompt_template_version_id, prompt_name)

    async def remove_prompt_mapping(
        self, prompt_template_id: str, prompt_template_version_id: str
    ) -> None:
        """
        Remove the mapping of a prompt_template ID to a version and prompt name.

        Nothing is removed if the template is unknown or is mapped to a different
        version. Otherwise, every prompt on this server whose ``meta`` carries the same
        template ID and version ID is disabled, and the mapping is dropped.

        Args:
            prompt_template_id: The ID of the prompt template to remove.
            prompt_template_version_id: The ID of the prompt template version to remove.
        """
        existing_prompt_template = self._prompts_map.get(prompt_template_id)
        if not existing_prompt_template:
            logger.debug(
                f"Do not found prompt template with id = {prompt_template_id} in registry, "
                f"skipping removal."
            )
            return

        existing_prompt_template_version_id, _ = existing_prompt_template
        if existing_prompt_template_version_id != prompt_template_version_id:
            logger.debug(
                f"Found prompt template with id = {prompt_template_id} in registry, "
                f"but with different version = {existing_prompt_template_version_id}, "
                f"skipping removal."
            )
            return

        # Query *this* server's prompts (not the module-level singleton).
        prompts_d = await self.get_prompts()
        for prompt in prompts_d.values():
            if (
                prompt.meta is not None
                and prompt.meta.get("prompt_template_id", "") == prompt_template_id
                and prompt.meta.get("prompt_template_version_id", "")
                == prompt_template_version_id
            ):
                prompt.disable()

        self._prompts_map.pop(prompt_template_id, None)
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
# Server configuration, resolved once at import time. It supplies the server name
# and the duplicate-registration policies for tools and prompts.
mcp_server_configs: MCPServerConfig = get_config()

# Shared TaggedFastMCP instance; the decorators and registration helpers in this
# module register tools and prompts on it.
mcp = TaggedFastMCP(
    on_duplicate_prompts=mcp_server_configs.prompt_registration_duplicate_behavior,
    on_duplicate_tools=mcp_server_configs.tool_registration_duplicate_behavior,
    name=mcp_server_configs.mcp_server_name,
)
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def dr_core_mcp_tool(
|
|
340
|
+
name: str | None = None,
|
|
341
|
+
description: str | None = None,
|
|
342
|
+
tags: set[str] | None = None,
|
|
343
|
+
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
|
344
|
+
"""Combine decorator that includes mcp.tool() and dr_mcp_extras()."""
|
|
345
|
+
|
|
346
|
+
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
|
347
|
+
instrumented = dr_mcp_extras()(func)
|
|
348
|
+
mcp.tool(name=name, description=description, tags=tags)(instrumented)
|
|
349
|
+
return instrumented
|
|
350
|
+
|
|
351
|
+
return decorator
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def _accepts_keyword(func: Callable[..., Any], param_name: str) -> bool:
    """Return whether ``func`` can receive ``param_name`` as a keyword argument.

    True if ``func`` declares a positional-or-keyword / keyword-only parameter of
    that name or accepts ``**kwargs``. Also True when the signature cannot be
    inspected, so the historical always-inject behaviour is kept in that case.
    """
    import inspect

    try:
        parameters = inspect.signature(func).parameters
    except (TypeError, ValueError):
        return True

    param = parameters.get(param_name)
    if param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters.values())


async def memory_aware_wrapper(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """
    Add memory management capabilities to any async function.

    Extracts agent and storage IDs from the context. When both are found, they are
    passed to ``func`` as the ``agent_id`` / ``storage_id`` keyword arguments, but
    only if ``func`` can accept them. Functions without those parameters (and
    without ``**kwargs``) are called unchanged instead of failing with
    ``TypeError: unexpected keyword argument``.

    Args:
        func: The async function to wrap
        *args: Positional arguments to pass to the function
        **kwargs: Keyword arguments to pass to the function

    Returns
    -------
        The result of calling the wrapped function
    """
    # Get agent and storage IDs from context
    agent_id, storage_id = await get_agent_and_storage_ids(args, kwargs)

    # Add IDs to kwargs if found and the target function can take them
    if agent_id and storage_id:
        for key, value in (("agent_id", agent_id), ("storage_id", storage_id)):
            if _accepts_keyword(func, key):
                kwargs[key] = value
            else:
                logger.debug(
                    f"{getattr(func, '__name__', func)} does not accept '{key}', not injecting it"
                )

    # Call the original function
    return await func(*args, **kwargs)
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def dr_mcp_tool(
|
|
381
|
+
name: str | None = None,
|
|
382
|
+
description: str | None = None,
|
|
383
|
+
tags: set[str] | None = None,
|
|
384
|
+
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
|
385
|
+
"""Combine decorator that includes mcp.tool(), dr_mcp_extras(), and capture memory ids from
|
|
386
|
+
the request headers if they exist.
|
|
387
|
+
|
|
388
|
+
Args:
|
|
389
|
+
name: Tool name
|
|
390
|
+
description: Tool description
|
|
391
|
+
tags: Optional set of tags to apply to the tool
|
|
392
|
+
"""
|
|
393
|
+
|
|
394
|
+
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
|
395
|
+
@wraps(func)
|
|
396
|
+
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
397
|
+
return await memory_aware_wrapper(func, *args, **kwargs)
|
|
398
|
+
|
|
399
|
+
# Apply the MCP decorators
|
|
400
|
+
instrumented = dr_mcp_extras()(wrapper)
|
|
401
|
+
mcp.tool(name=name, description=description, tags=tags)(instrumented)
|
|
402
|
+
return instrumented
|
|
403
|
+
|
|
404
|
+
return decorator
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def dr_mcp_extras(
    type: str = "tool",
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that adds tracing and execution logging to a function.

    The function is first wrapped with ``trace_execution(trace_type=type)`` and the
    result is then wrapped with ``log_execution``.

    Args:
        type: default is "tool", other options are "prompt", "resource"
    """
    # Capture under a non-builtin name for use inside the closure.
    trace_type = type

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        traced = trace_execution(trace_type=trace_type)(func)
        return log_execution(traced)

    return decorator
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
async def register_tools(
    fn: AnyFunction,
    name: str | None = None,
    title: str | None = None,
    description: str | None = None,
    tags: set[str] | None = None,
    deployment_id: str | None = None,
) -> Tool:
    """
    Register new tools after server has started.

    ``fn`` is wrapped so that calls go through ``memory_aware_wrapper``, then
    instrumented with ``dr_mcp_extras()``. Tags and the deployment ID (when given)
    are also stored on the tool's annotations.

    Args:
        fn: The function to register as a tool
        name: Optional name for the tool (defaults to function name)
        title: Optional human-readable title for the tool
        description: Optional description of what the tool does
        tags: Optional set of tags to apply to the tool
        deployment_id: Optional deployment ID associated with the tool

    Returns
    -------
        The registered Tool object

    Raises
    ------
        RuntimeError: If the tool is not listed by the server after registration.
    """
    tool_name = name or fn.__name__
    logger.info(f"Registering new tool: {tool_name}")

    # Create a memory-aware version of the function
    @wraps(fn)
    async def memory_aware_fn(*args: Any, **kwargs: Any) -> Any:
        return await memory_aware_wrapper(fn, *args, **kwargs)

    # Apply dr_mcp_extras to the memory-aware function
    wrapped_fn = dr_mcp_extras()(memory_aware_fn)

    # Create annotations with tags, deployment_id if provided
    annotations = ToolAnnotations()  # type: ignore[call-arg]
    if tags is not None:
        # Store a sorted list, as TaggedFastMCP.tool does, so the annotation has a
        # deterministic order (set iteration order is not) and is JSON-friendly.
        annotations.tags = sorted(tags)  # type: ignore[attr-defined]
    if deployment_id is not None:
        annotations.deployment_id = deployment_id  # type: ignore[attr-defined]

    tool = Tool.from_function(
        fn=wrapped_fn,
        name=tool_name,
        title=title,
        description=description,
        annotations=annotations,
        tags=tags,
    )

    # Register the tool
    registered_tool = mcp.add_tool(tool)

    # Map deployment ID to tool name if provided
    if deployment_id:
        await mcp.set_deployment_mapping(deployment_id, tool_name)

    # Verify tool is registered
    tools = await mcp.list_tools()
    if not any(listed.name == tool_name for listed in tools):
        raise RuntimeError(f"Tool {tool_name} was not registered successfully")
    logger.info(f"Registered tools: {len(tools)}")

    return registered_tool
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
async def register_prompt(
    fn: AnyFunction,
    name: str | None = None,
    title: str | None = None,
    description: str | None = None,
    tags: set[str] | None = None,
    meta: dict[str, Any] | None = None,
    prompt_template: tuple[str, str] | None = None,
) -> Prompt:
    """
    Register a prompt on the server after it has started.

    ``fn`` is instrumented with ``dr_mcp_extras(type="prompt")``. It is added under
    the name returned by ``get_prompt_name_no_duplicate`` (presumably a de-duplicated
    variant of the requested name). When ``prompt_template`` is given, its
    (template ID, version ID) pair is mapped to that final name.

    Args:
        fn: The function to register as a prompt
        name: Optional name for the prompt (defaults to function name)
        title: Optional human-readable title for the prompt
        description: Optional description of what the prompt does
        tags: Optional set of tags to apply to the prompt
        meta: Optional dict of metadata to apply to the prompt
        prompt_template: Optional (id, version id) of the prompt template

    Returns
    -------
        The Prompt object returned by ``mcp.add_prompt``.

    Raises
    ------
        RuntimeError: If no prompt with the final name is present afterwards.
    """
    requested_name = name or fn.__name__
    logger.info(f"Registering new prompt: {requested_name}")
    instrumented_fn = dr_mcp_extras(type="prompt")(fn)

    final_name = await get_prompt_name_no_duplicate(mcp, requested_name)

    registered_prompt = mcp.add_prompt(
        Prompt.from_function(
            fn=instrumented_fn,
            name=final_name,
            title=title,
            description=description,
            tags=tags,
            meta=meta,
        )
    )

    if prompt_template:
        template_id, template_version_id = prompt_template
        await mcp.set_prompt_mapping(template_id, template_version_id, final_name)

    # Sanity check: the server must now expose a prompt under the final name.
    prompts = await mcp.get_prompts()
    if final_name not in {p.name for p in prompts.values()}:
        raise RuntimeError(f"Prompt {final_name} was not registered successfully")
    logger.info(f"Registered prompts: {len(prompts)}")

    return registered_prompt
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
|
|
17
|
+
from .mcp_instance import dr_core_mcp_tool
|
|
18
|
+
from .mcp_instance import mcp
|
|
19
|
+
|
|
20
|
+
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dr_core_mcp_tool(tags={"mcp_server_tools", "metadata"})
async def get_all_available_tags() -> str:
    """
    List all unique tags from all registered tools.

    Returns
    -------
        A string containing all available tags, one per line.
    """
    # Tag vocabulary collected by the shared server instance.
    available = await mcp.get_all_tags()
    # An empty vocabulary gets a readable message instead of an empty string.
    return "\n".join(sorted(available)) if available else "No tags found in any tools."
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dr_core_mcp_tool(tags={"mcp_server_tools", "metadata", "discovery"})
async def list_tools_by_tags(tags: list[str] | None = None, match_all: bool = False) -> str:
    """
    List tools filtered by tags.

    Args:
        tags: Optional list of tags to filter by. If None, returns all tools.
        match_all: If True, tool must have all specified tags (AND logic).
                  If False, tool must have at least one tag (OR logic).
                  Only used when tags is provided.

    Returns
    -------
        A formatted string listing tools that match the tag criteria.
    """
    tools = await mcp.list_tools(tags=tags, match_all=match_all)
    logic = "all" if match_all else "any"

    if not tools:
        if tags:
            return f"No tools found with {logic} of the tags: {', '.join(tags)}"
        return "No tools found."

    header = f"Tools with {logic} of the tags: {', '.join(tags)}" if tags else "All available tools:"
    result = [header, ""]

    for i, tool in enumerate(tools, 1):
        # Tags are stored as an extra ``tags`` attribute on the ToolAnnotations model.
        # Pydantic v2 models expose extras as attributes; there is no ``extra``
        # attribute, so the previous ``annotations.extra`` lookup never found them.
        raw_tags = getattr(tool.annotations, "tags", None) if tool.annotations else None
        tool_tags = sorted(raw_tags) if raw_tags else []

        result.append(f"{i}. {tool.name}")
        result.append(f" Description: {tool.description}")
        if tool_tags:
            result.append(f" Tags: {', '.join(tool_tags)}")
        result.append("")

    return "\n".join(result)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@dr_core_mcp_tool(tags={"mcp_server_tools", "metadata", "discovery"})
async def get_tool_info_by_name(tool_name: str) -> str:
    """
    Get detailed information about a specific tool by name.

    Args:
        tool_name: The name of the tool to get information about.

    Returns
    -------
        A formatted string with detailed information about the tool.
    """
    all_tools = await mcp.list_tools()

    tool = next((candidate for candidate in all_tools if candidate.name == tool_name), None)
    if tool is None:
        return f"Tool '{tool_name}' not found."

    result = [f"Tool: {tool.name}", f"Description: {tool.description}"]

    # Tags live in the extra ``tags`` attribute of the ToolAnnotations model
    # (pydantic v2 models have no ``extra`` attribute to read them from).
    raw_tags = getattr(tool.annotations, "tags", None) if tool.annotations else None
    tool_tags = sorted(raw_tags) if raw_tags else []
    result.append(f"Tags: {', '.join(tool_tags)}" if tool_tags else "Tags: None")

    # Get input schema info. NOTE(review): inputSchema on mcp.types.Tool is expected to
    # be a JSON-schema dict, so "properties" is read by key; attribute access is kept
    # as a fallback for schema objects.
    schema = tool.inputSchema
    if isinstance(schema, dict):
        properties = schema.get("properties")
    else:
        properties = getattr(schema, "properties", None)

    if properties:
        result.append("Parameters:")
        for param_name, param_info in properties.items():
            param_type = param_info.get("type", "unknown")
            param_desc = param_info.get("description", "No description")
            result.append(f" - {param_name} ({param_type}): {param_desc}")

    return "\n".join(result)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|