datarobot-genai 0.1.59__tar.gz → 0.1.70__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/PKG-INFO +2 -2
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/pyproject.toml +2 -2
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/agents/base.py +7 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/custom_model.py +5 -0
- datarobot_genai-0.1.70/src/datarobot_genai/core/mcp/common.py +218 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/utils/auth.py +64 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/crewai/base.py +34 -55
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/crewai/mcp.py +4 -7
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/auth.py +28 -25
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/clients.py +67 -3
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/config.py +0 -8
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dr_mcp_server.py +10 -3
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dr_mcp_server_logo.py +12 -1
- datarobot_genai-0.1.70/src/datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
- datarobot_genai-0.1.70/src/datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +128 -0
- datarobot_genai-0.1.70/src/datarobot_genai/drmcp/core/dynamic_prompts/register.py +206 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/mcp_instance.py +10 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/mcp_server_tools.py +2 -2
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/routes.py +125 -28
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/langgraph/agent.py +5 -6
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/langgraph/mcp.py +5 -7
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/llama_index/base.py +1 -2
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/llama_index/mcp.py +4 -5
- datarobot_genai-0.1.70/src/datarobot_genai/nat/agent.py +258 -0
- datarobot_genai-0.1.59/src/datarobot_genai/core/mcp/common.py +0 -109
- datarobot_genai-0.1.59/src/datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +0 -91
- datarobot_genai-0.1.59/src/datarobot_genai/drmcp/core/dynamic_prompts/register.py +0 -150
- datarobot_genai-0.1.59/src/datarobot_genai/nat/agent.py +0 -137
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/.gitignore +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/AUTHORS +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/LICENSE +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/README.md +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/agents/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/chat/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/chat/auth.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/chat/client.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/chat/responses.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/cli/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/cli/agent_environment.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/cli/agent_kernel.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/mcp/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/telemetry_agent.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/utils/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/utils/urls.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/crewai/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/crewai/agent.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/crewai/events.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/config_utils.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/constants.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/credentials.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_prompts/utils.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/register.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/schema.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/exceptions.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/logging.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/memory_management/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/memory_management/manager.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/memory_management/memory_tools.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/routes_utils.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/server_life_cycle.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/telemetry.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/tool_filter.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/utils.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/server.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/test_utils/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/test_utils/integration_mcp_server.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/test_utils/openai_llm_mcp_client.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/test_utils/tool_base_ete.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/test_utils/utils.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/data.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/deployment.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/deployment_info.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/model.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/predict.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/predict_realtime.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/project.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/training.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/langgraph/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/llama_index/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/llama_index/agent.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/nat/__init__.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/nat/datarobot_llm_clients.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/nat/datarobot_llm_providers.py +0 -0
- {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/py.typed +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: datarobot-genai
|
|
3
|
-
Version: 0.1.59
|
|
3
|
+
Version: 0.1.70
|
|
4
4
|
Summary: Generic helpers for GenAI
|
|
5
5
|
Project-URL: Homepage, https://github.com/datarobot-oss/datarobot-genai
|
|
6
6
|
Author: DataRobot, Inc.
|
|
@@ -32,7 +32,7 @@ Requires-Dist: aiohttp<4.0.0,>=3.9.0; extra == 'drmcp'
|
|
|
32
32
|
Requires-Dist: aiosignal<2.0.0,>=1.3.1; extra == 'drmcp'
|
|
33
33
|
Requires-Dist: boto3<2.0.0,>=1.34.0; extra == 'drmcp'
|
|
34
34
|
Requires-Dist: datarobot-asgi-middleware<1.0.0,>=0.2.0; extra == 'drmcp'
|
|
35
|
-
Requires-Dist: fastmcp
|
|
35
|
+
Requires-Dist: fastmcp<3.0.0,>=2.13.0.2; extra == 'drmcp'
|
|
36
36
|
Requires-Dist: httpx<1.0.0,>=0.28.1; extra == 'drmcp'
|
|
37
37
|
Requires-Dist: opentelemetry-api<2.0.0,>=1.22.0; extra == 'drmcp'
|
|
38
38
|
Requires-Dist: opentelemetry-exporter-otlp-proto-http<2.0.0,>=1.22.0; extra == 'drmcp'
|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "datarobot-genai"
|
|
7
|
-
version = "0.1.59"
|
|
7
|
+
version = "0.1.70"
|
|
8
8
|
description = "Generic helpers for GenAI"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
requires-python = ">=3.10, <3.13"
|
|
@@ -84,7 +84,7 @@ drmcp = [
|
|
|
84
84
|
"aiohttp>=3.9.0,<4.0.0",
|
|
85
85
|
"aiohttp-retry>=2.8.3,<3.0.0",
|
|
86
86
|
"aiosignal>=1.3.1,<2.0.0",
|
|
87
|
-
"fastmcp
|
|
87
|
+
"fastmcp>=2.13.0.2,<3.0.0",
|
|
88
88
|
]
|
|
89
89
|
|
|
90
90
|
[tool.hatch.build.targets.wheel]
|
|
@@ -52,6 +52,7 @@ class BaseAgent(Generic[TTool], abc.ABC):
|
|
|
52
52
|
verbose: bool | str | None = True,
|
|
53
53
|
timeout: int | None = 90,
|
|
54
54
|
authorization_context: dict[str, Any] | None = None,
|
|
55
|
+
forwarded_headers: dict[str, str] | None = None,
|
|
55
56
|
**_: Any,
|
|
56
57
|
) -> None:
|
|
57
58
|
self.api_key = api_key or os.environ.get("DATAROBOT_API_TOKEN")
|
|
@@ -68,6 +69,7 @@ class BaseAgent(Generic[TTool], abc.ABC):
|
|
|
68
69
|
self.verbose = bool(verbose)
|
|
69
70
|
self._mcp_tools: list[TTool] = []
|
|
70
71
|
self._authorization_context = authorization_context or {}
|
|
72
|
+
self._forwarded_headers: dict[str, str] = forwarded_headers or {}
|
|
71
73
|
|
|
72
74
|
def set_mcp_tools(self, tools: list[TTool]) -> None:
|
|
73
75
|
self._mcp_tools = tools
|
|
@@ -86,6 +88,11 @@ class BaseAgent(Generic[TTool], abc.ABC):
|
|
|
86
88
|
"""Return the authorization context for this agent."""
|
|
87
89
|
return self._authorization_context
|
|
88
90
|
|
|
91
|
+
@property
|
|
92
|
+
def forwarded_headers(self) -> dict[str, str]:
|
|
93
|
+
"""Return the forwarded headers for this agent."""
|
|
94
|
+
return self._forwarded_headers
|
|
95
|
+
|
|
89
96
|
def litellm_api_base(self, deployment_id: str | None) -> str:
|
|
90
97
|
return get_api_base(self.api_base, deployment_id)
|
|
91
98
|
|
|
@@ -139,6 +139,11 @@ def chat_entrypoint(
|
|
|
139
139
|
completion_create_params["authorization_context"] = resolve_authorization_context(
|
|
140
140
|
completion_create_params, **kwargs
|
|
141
141
|
)
|
|
142
|
+
# Keep only allowed headers from the forwarded_headers.
|
|
143
|
+
incoming_headers = kwargs.get("headers", {}) or {}
|
|
144
|
+
allowed_headers = {"x-datarobot-api-token", "x-datarobot-api-key"}
|
|
145
|
+
forwarded_headers = {k: v for k, v in incoming_headers.items() if k.lower() in allowed_headers}
|
|
146
|
+
completion_create_params["forwarded_headers"] = forwarded_headers
|
|
142
147
|
|
|
143
148
|
# Instantiate user agent with all supplied completion params including auth context
|
|
144
149
|
agent = agent_cls(**completion_create_params)
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc. and its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import logging
|
|
17
|
+
import re
|
|
18
|
+
from http import HTTPStatus
|
|
19
|
+
from typing import Any
|
|
20
|
+
from typing import Literal
|
|
21
|
+
|
|
22
|
+
import requests
|
|
23
|
+
from datarobot.core.config import DataRobotAppFrameworkBaseSettings
|
|
24
|
+
from pydantic import field_validator
|
|
25
|
+
|
|
26
|
+
from datarobot_genai.core.utils.auth import AuthContextHeaderHandler
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class MCPConfig(DataRobotAppFrameworkBaseSettings):
|
|
32
|
+
"""Configuration for MCP server connection.
|
|
33
|
+
|
|
34
|
+
Derived values are exposed as properties rather than stored, avoiding
|
|
35
|
+
Pydantic field validation/serialization concerns for internal helpers.
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
external_mcp_url: str | None = None
|
|
39
|
+
external_mcp_headers: str | None = None
|
|
40
|
+
external_mcp_transport: Literal["sse", "streamable-http"] = "streamable-http"
|
|
41
|
+
mcp_deployment_id: str | None = None
|
|
42
|
+
datarobot_endpoint: str | None = None
|
|
43
|
+
datarobot_api_token: str | None = None
|
|
44
|
+
authorization_context: dict[str, Any] | None = None
|
|
45
|
+
forwarded_headers: dict[str, str] | None = None
|
|
46
|
+
mcp_server_port: int | None = None
|
|
47
|
+
|
|
48
|
+
_auth_context_handler: AuthContextHeaderHandler | None = None
|
|
49
|
+
_server_config: dict[str, Any] | None = None
|
|
50
|
+
|
|
51
|
+
@field_validator("external_mcp_headers", mode="before")
|
|
52
|
+
@classmethod
|
|
53
|
+
def validate_external_mcp_headers(cls, value: str | None) -> str | None:
|
|
54
|
+
if value is None:
|
|
55
|
+
return None
|
|
56
|
+
|
|
57
|
+
candidate = value.strip()
|
|
58
|
+
|
|
59
|
+
try:
|
|
60
|
+
json.loads(candidate)
|
|
61
|
+
except json.JSONDecodeError:
|
|
62
|
+
msg = "external_mcp_headers must be valid JSON"
|
|
63
|
+
logger.warning(msg)
|
|
64
|
+
return None
|
|
65
|
+
|
|
66
|
+
return candidate
|
|
67
|
+
|
|
68
|
+
@field_validator("mcp_deployment_id", mode="before")
|
|
69
|
+
@classmethod
|
|
70
|
+
def validate_mcp_deployment_id(cls, value: str | None) -> str | None:
|
|
71
|
+
if value is None:
|
|
72
|
+
return None
|
|
73
|
+
|
|
74
|
+
candidate = value.strip()
|
|
75
|
+
|
|
76
|
+
if not re.fullmatch(r"[0-9a-fA-F]{24}", candidate):
|
|
77
|
+
msg = "mcp_deployment_id must be a valid 24-character hex ID"
|
|
78
|
+
logger.warning(msg)
|
|
79
|
+
return None
|
|
80
|
+
|
|
81
|
+
return candidate
|
|
82
|
+
|
|
83
|
+
def _authorization_bearer_header(self) -> dict[str, str]:
|
|
84
|
+
"""Return Authorization header with Bearer token or empty dict."""
|
|
85
|
+
if not self.datarobot_api_token:
|
|
86
|
+
return {}
|
|
87
|
+
auth = (
|
|
88
|
+
self.datarobot_api_token
|
|
89
|
+
if self.datarobot_api_token.startswith("Bearer ")
|
|
90
|
+
else f"Bearer {self.datarobot_api_token}"
|
|
91
|
+
)
|
|
92
|
+
return {"Authorization": auth}
|
|
93
|
+
|
|
94
|
+
@property
|
|
95
|
+
def auth_context_handler(self) -> AuthContextHeaderHandler:
|
|
96
|
+
if self._auth_context_handler is None:
|
|
97
|
+
self._auth_context_handler = AuthContextHeaderHandler()
|
|
98
|
+
return self._auth_context_handler
|
|
99
|
+
|
|
100
|
+
@property
|
|
101
|
+
def server_config(self) -> dict[str, Any] | None:
|
|
102
|
+
if self._server_config is None:
|
|
103
|
+
self._server_config = self._build_server_config()
|
|
104
|
+
return self._server_config
|
|
105
|
+
|
|
106
|
+
def _authorization_context_header(self) -> dict[str, str]:
|
|
107
|
+
"""Return X-DataRobot-Authorization-Context header or empty dict."""
|
|
108
|
+
try:
|
|
109
|
+
return self.auth_context_handler.get_header(self.authorization_context)
|
|
110
|
+
except (LookupError, RuntimeError):
|
|
111
|
+
# Authorization context not available (e.g., in tests)
|
|
112
|
+
return {}
|
|
113
|
+
|
|
114
|
+
def _build_authenticated_headers(self) -> dict[str, str]:
|
|
115
|
+
"""Build headers for authenticated requests.
|
|
116
|
+
|
|
117
|
+
Returns
|
|
118
|
+
-------
|
|
119
|
+
Dictionary containing forwarded headers (if available) and authentication headers.
|
|
120
|
+
"""
|
|
121
|
+
headers: dict[str, str] = {}
|
|
122
|
+
if self.forwarded_headers:
|
|
123
|
+
headers.update(self.forwarded_headers)
|
|
124
|
+
headers.update(self._authorization_bearer_header())
|
|
125
|
+
headers.update(self._authorization_context_header())
|
|
126
|
+
return headers
|
|
127
|
+
|
|
128
|
+
def _check_localhost_server(self, url: str, timeout: float = 2.0) -> bool:
|
|
129
|
+
"""Check if MCP server is running on localhost.
|
|
130
|
+
|
|
131
|
+
Parameters
|
|
132
|
+
----------
|
|
133
|
+
url : str
|
|
134
|
+
The URL to check.
|
|
135
|
+
timeout : float, optional
|
|
136
|
+
Request timeout in seconds (default: 2.0).
|
|
137
|
+
|
|
138
|
+
Returns
|
|
139
|
+
-------
|
|
140
|
+
bool
|
|
141
|
+
True if server is running and responding with OK status, False otherwise.
|
|
142
|
+
"""
|
|
143
|
+
try:
|
|
144
|
+
response = requests.get(url, timeout=timeout)
|
|
145
|
+
return (
|
|
146
|
+
response.status_code == HTTPStatus.OK
|
|
147
|
+
and response.json().get("message") == "DataRobot MCP Server is running"
|
|
148
|
+
)
|
|
149
|
+
except requests.RequestException as e:
|
|
150
|
+
logger.debug(f"Failed to connect to MCP server at {url}: {e}")
|
|
151
|
+
return False
|
|
152
|
+
|
|
153
|
+
def _build_server_config(self) -> dict[str, Any] | None:
|
|
154
|
+
"""
|
|
155
|
+
Get MCP server configuration.
|
|
156
|
+
|
|
157
|
+
Returns
|
|
158
|
+
-------
|
|
159
|
+
Server configuration dict with url, transport, and optional headers,
|
|
160
|
+
or None if not configured.
|
|
161
|
+
"""
|
|
162
|
+
if self.mcp_deployment_id:
|
|
163
|
+
# DataRobot deployment ID - requires authentication
|
|
164
|
+
if self.datarobot_endpoint is None:
|
|
165
|
+
raise ValueError(
|
|
166
|
+
"When using a DataRobot hosted MCP deployment, datarobot_endpoint must be set."
|
|
167
|
+
)
|
|
168
|
+
if self.datarobot_api_token is None:
|
|
169
|
+
raise ValueError(
|
|
170
|
+
"When using a DataRobot hosted MCP deployment, datarobot_api_token must be set."
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
base_url = self.datarobot_endpoint.rstrip("/")
|
|
174
|
+
if not base_url.endswith("/api/v2"):
|
|
175
|
+
base_url = f"{base_url}/api/v2"
|
|
176
|
+
|
|
177
|
+
url = f"{base_url}/deployments/{self.mcp_deployment_id}/directAccess/mcp"
|
|
178
|
+
headers = self._build_authenticated_headers()
|
|
179
|
+
|
|
180
|
+
logger.info(f"Using DataRobot hosted MCP deployment: {url}")
|
|
181
|
+
|
|
182
|
+
return {
|
|
183
|
+
"url": url,
|
|
184
|
+
"transport": "streamable-http",
|
|
185
|
+
"headers": headers,
|
|
186
|
+
}
|
|
187
|
+
|
|
188
|
+
if self.external_mcp_url:
|
|
189
|
+
# External MCP URL - no authentication needed
|
|
190
|
+
headers = {}
|
|
191
|
+
|
|
192
|
+
# Merge external headers if provided
|
|
193
|
+
if self.external_mcp_headers:
|
|
194
|
+
external_headers = json.loads(self.external_mcp_headers)
|
|
195
|
+
headers.update(external_headers)
|
|
196
|
+
|
|
197
|
+
logger.info(f"Using external MCP URL: {self.external_mcp_url}")
|
|
198
|
+
|
|
199
|
+
return {
|
|
200
|
+
"url": self.external_mcp_url.rstrip("/"),
|
|
201
|
+
"transport": self.external_mcp_transport,
|
|
202
|
+
"headers": headers,
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
# No MCP configuration found, setup localhost if running locally
|
|
206
|
+
if self.mcp_server_port:
|
|
207
|
+
url = f"http://localhost:{self.mcp_server_port}"
|
|
208
|
+
if self._check_localhost_server(url):
|
|
209
|
+
headers = self._build_authenticated_headers()
|
|
210
|
+
logger.info(f"Using localhost MCP server: {url}")
|
|
211
|
+
return {
|
|
212
|
+
"url": f"{url}/mcp",
|
|
213
|
+
"transport": "streamable-http",
|
|
214
|
+
"headers": headers,
|
|
215
|
+
}
|
|
216
|
+
logger.warning(f"MCP server is not running or not responding at {url}")
|
|
217
|
+
|
|
218
|
+
return None
|
|
@@ -16,9 +16,14 @@ import warnings
|
|
|
16
16
|
from typing import Any
|
|
17
17
|
|
|
18
18
|
import jwt
|
|
19
|
+
from datarobot.auth.datarobot.oauth import AsyncOAuth as DatarobotAsyncOAuthClient
|
|
20
|
+
from datarobot.auth.identity import Identity
|
|
21
|
+
from datarobot.auth.oauth import AsyncOAuthComponent
|
|
19
22
|
from datarobot.auth.session import AuthCtx
|
|
20
23
|
from datarobot.core.config import DataRobotAppFrameworkBaseSettings
|
|
24
|
+
from datarobot.models.genai.agent.auth import ToolAuth
|
|
21
25
|
from datarobot.models.genai.agent.auth import get_authorization_context
|
|
26
|
+
from pydantic import BaseModel
|
|
22
27
|
|
|
23
28
|
logger = logging.getLogger(__name__)
|
|
24
29
|
|
|
@@ -27,6 +32,13 @@ class AuthContextConfig(DataRobotAppFrameworkBaseSettings):
|
|
|
27
32
|
session_secret_key: str = ""
|
|
28
33
|
|
|
29
34
|
|
|
35
|
+
class DRAppCtx(BaseModel):
|
|
36
|
+
"""DataRobot application context from authorization metadata."""
|
|
37
|
+
|
|
38
|
+
email: str | None = None
|
|
39
|
+
api_key: str | None = None
|
|
40
|
+
|
|
41
|
+
|
|
30
42
|
class AuthContextHeaderHandler:
|
|
31
43
|
"""Manages encoding and decoding of authorization context into JWT tokens.
|
|
32
44
|
|
|
@@ -146,6 +158,7 @@ class AuthContextHeaderHandler:
|
|
|
146
158
|
|
|
147
159
|
auth_ctx_dict = self.decode(token)
|
|
148
160
|
if not auth_ctx_dict:
|
|
161
|
+
logger.debug("Failed to decode auth context from token")
|
|
149
162
|
return None
|
|
150
163
|
|
|
151
164
|
try:
|
|
@@ -153,3 +166,54 @@ class AuthContextHeaderHandler:
|
|
|
153
166
|
except Exception as e:
|
|
154
167
|
logger.error(f"Failed to create AuthCtx from decoded token: {e}", exc_info=True)
|
|
155
168
|
return None
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class AsyncOAuthTokenProvider:
|
|
172
|
+
"""Manages OAuth access tokens using generic OAuth client."""
|
|
173
|
+
|
|
174
|
+
def __init__(self, auth_ctx: AuthCtx) -> None:
|
|
175
|
+
self.auth_ctx = auth_ctx
|
|
176
|
+
self.oauth_client = self._create_oauth_client()
|
|
177
|
+
|
|
178
|
+
def _get_identity(self, provider_type: str | None) -> Identity:
|
|
179
|
+
"""Retrieve the appropriate identity from the authentication context."""
|
|
180
|
+
identities = [x for x in self.auth_ctx.identities if x.provider_identity_id is not None]
|
|
181
|
+
|
|
182
|
+
if not identities:
|
|
183
|
+
raise ValueError("No identities found in authorization context.")
|
|
184
|
+
|
|
185
|
+
if provider_type is None:
|
|
186
|
+
if len(identities) > 1:
|
|
187
|
+
raise ValueError(
|
|
188
|
+
"Multiple identities found. Please specify 'provider_type' parameter."
|
|
189
|
+
)
|
|
190
|
+
return identities[0]
|
|
191
|
+
|
|
192
|
+
identity = next((id for id in identities if id.provider_type == provider_type), None)
|
|
193
|
+
|
|
194
|
+
if identity is None:
|
|
195
|
+
raise ValueError(f"No identity found for provider '{provider_type}'.")
|
|
196
|
+
|
|
197
|
+
return identity
|
|
198
|
+
|
|
199
|
+
async def get_token(self, auth_type: ToolAuth, provider_type: str | None = None) -> str:
|
|
200
|
+
"""Get OAuth access token using the specified method."""
|
|
201
|
+
if auth_type != ToolAuth.OBO:
|
|
202
|
+
raise ValueError(
|
|
203
|
+
f"Unsupported auth type: {auth_type}. Only {ToolAuth.OBO} is supported."
|
|
204
|
+
)
|
|
205
|
+
|
|
206
|
+
identity = self._get_identity(provider_type)
|
|
207
|
+
token_data = await self.oauth_client.refresh_access_token(
|
|
208
|
+
identity_id=identity.provider_identity_id
|
|
209
|
+
)
|
|
210
|
+
return token_data.access_token
|
|
211
|
+
|
|
212
|
+
def _create_oauth_client(self) -> AsyncOAuthComponent:
|
|
213
|
+
"""Create either DataRobot or Authlib OAuth client based on
|
|
214
|
+
authorization context.
|
|
215
|
+
|
|
216
|
+
Note: at the moment, only DataRobot OAuth client is supported.
|
|
217
|
+
"""
|
|
218
|
+
logger.debug("Using DataRobot OAuth client")
|
|
219
|
+
return DatarobotAsyncOAuthClient()
|
|
@@ -80,6 +80,37 @@ class CrewAIAgent(BaseAgent[BaseTool], abc.ABC):
|
|
|
80
80
|
"""
|
|
81
81
|
raise NotImplementedError
|
|
82
82
|
|
|
83
|
+
def _extract_pipeline_interactions(self) -> MultiTurnSample | None:
|
|
84
|
+
"""Extract pipeline interactions from event listener if available."""
|
|
85
|
+
if not hasattr(self, "event_listener"):
|
|
86
|
+
return None
|
|
87
|
+
try:
|
|
88
|
+
listener = getattr(self, "event_listener", None)
|
|
89
|
+
messages = getattr(listener, "messages", None) if listener is not None else None
|
|
90
|
+
return create_pipeline_interactions_from_messages(messages)
|
|
91
|
+
except Exception:
|
|
92
|
+
return None
|
|
93
|
+
|
|
94
|
+
def _extract_usage_metrics(self, crew_output: Any) -> UsageMetrics:
|
|
95
|
+
"""Extract usage metrics from crew output."""
|
|
96
|
+
token_usage = getattr(crew_output, "token_usage", None)
|
|
97
|
+
if token_usage is not None:
|
|
98
|
+
return {
|
|
99
|
+
"completion_tokens": int(getattr(token_usage, "completion_tokens", 0)),
|
|
100
|
+
"prompt_tokens": int(getattr(token_usage, "prompt_tokens", 0)),
|
|
101
|
+
"total_tokens": int(getattr(token_usage, "total_tokens", 0)),
|
|
102
|
+
}
|
|
103
|
+
return default_usage_metrics()
|
|
104
|
+
|
|
105
|
+
def _process_crew_output(
|
|
106
|
+
self, crew_output: Any
|
|
107
|
+
) -> tuple[str, MultiTurnSample | None, UsageMetrics]:
|
|
108
|
+
"""Process crew output into response tuple."""
|
|
109
|
+
response_text = str(crew_output.raw)
|
|
110
|
+
pipeline_interactions = self._extract_pipeline_interactions()
|
|
111
|
+
usage_metrics = self._extract_usage_metrics(crew_output)
|
|
112
|
+
return response_text, pipeline_interactions, usage_metrics
|
|
113
|
+
|
|
83
114
|
async def invoke(self, completion_create_params: CompletionCreateParams) -> InvokeReturn:
|
|
84
115
|
"""Run the CrewAI workflow with the provided completion parameters."""
|
|
85
116
|
user_prompt_content = extract_user_prompt_content(completion_create_params)
|
|
@@ -92,9 +123,8 @@ class CrewAIAgent(BaseAgent[BaseTool], abc.ABC):
|
|
|
92
123
|
|
|
93
124
|
# Use MCP context manager to handle connection lifecycle
|
|
94
125
|
with mcp_tools_context(
|
|
95
|
-
api_base=self.api_base,
|
|
96
|
-
api_key=self.api_key,
|
|
97
126
|
authorization_context=self._authorization_context,
|
|
127
|
+
forwarded_headers=self.forwarded_headers,
|
|
98
128
|
) as mcp_tools:
|
|
99
129
|
# Set MCP tools for all agents if MCP is not configured this is effectively a no-op
|
|
100
130
|
self.set_mcp_tools(mcp_tools)
|
|
@@ -117,64 +147,13 @@ class CrewAIAgent(BaseAgent[BaseTool], abc.ABC):
|
|
|
117
147
|
async def _gen() -> AsyncGenerator[
|
|
118
148
|
tuple[str, MultiTurnSample | None, UsageMetrics]
|
|
119
149
|
]:
|
|
120
|
-
# Run kickoff in a worker thread.
|
|
121
150
|
crew_output = await asyncio.to_thread(
|
|
122
151
|
crew.kickoff,
|
|
123
152
|
inputs=self.make_kickoff_inputs(user_prompt_content),
|
|
124
153
|
)
|
|
125
|
-
|
|
126
|
-
pipeline_interactions = None
|
|
127
|
-
if hasattr(self, "event_listener"):
|
|
128
|
-
try:
|
|
129
|
-
listener = getattr(self, "event_listener", None)
|
|
130
|
-
messages = (
|
|
131
|
-
getattr(listener, "messages", None)
|
|
132
|
-
if listener is not None
|
|
133
|
-
else None
|
|
134
|
-
)
|
|
135
|
-
pipeline_interactions = create_pipeline_interactions_from_messages(
|
|
136
|
-
messages
|
|
137
|
-
)
|
|
138
|
-
except Exception:
|
|
139
|
-
pipeline_interactions = None
|
|
140
|
-
|
|
141
|
-
token_usage = getattr(crew_output, "token_usage", None)
|
|
142
|
-
if token_usage is not None:
|
|
143
|
-
usage_metrics: UsageMetrics = {
|
|
144
|
-
"completion_tokens": int(getattr(token_usage, "completion_tokens", 0)),
|
|
145
|
-
"prompt_tokens": int(getattr(token_usage, "prompt_tokens", 0)),
|
|
146
|
-
"total_tokens": int(getattr(token_usage, "total_tokens", 0)),
|
|
147
|
-
}
|
|
148
|
-
else:
|
|
149
|
-
usage_metrics = default_usage_metrics()
|
|
150
|
-
|
|
151
|
-
# Finalize stream with empty chunk carrying interactions and usage
|
|
152
|
-
yield "", pipeline_interactions, usage_metrics
|
|
154
|
+
yield self._process_crew_output(crew_output)
|
|
153
155
|
|
|
154
156
|
return _gen()
|
|
155
157
|
|
|
156
|
-
# Non-streaming: run to completion and return final result
|
|
157
158
|
crew_output = crew.kickoff(inputs=self.make_kickoff_inputs(user_prompt_content))
|
|
158
|
-
|
|
159
|
-
response_text = str(crew_output.raw)
|
|
160
|
-
|
|
161
|
-
pipeline_interactions = None
|
|
162
|
-
if hasattr(self, "event_listener"):
|
|
163
|
-
try:
|
|
164
|
-
listener = getattr(self, "event_listener", None)
|
|
165
|
-
messages = getattr(listener, "messages", None) if listener is not None else None
|
|
166
|
-
pipeline_interactions = create_pipeline_interactions_from_messages(messages)
|
|
167
|
-
except Exception:
|
|
168
|
-
pipeline_interactions = None
|
|
169
|
-
|
|
170
|
-
token_usage = getattr(crew_output, "token_usage", None)
|
|
171
|
-
if token_usage is not None:
|
|
172
|
-
usage_metrics: UsageMetrics = {
|
|
173
|
-
"completion_tokens": int(getattr(token_usage, "completion_tokens", 0)),
|
|
174
|
-
"prompt_tokens": int(getattr(token_usage, "prompt_tokens", 0)),
|
|
175
|
-
"total_tokens": int(getattr(token_usage, "total_tokens", 0)),
|
|
176
|
-
}
|
|
177
|
-
else:
|
|
178
|
-
usage_metrics = default_usage_metrics()
|
|
179
|
-
|
|
180
|
-
return response_text, pipeline_interactions, usage_metrics
|
|
159
|
+
return self._process_crew_output(crew_output)
|
|
@@ -29,15 +29,14 @@ from datarobot_genai.core.mcp.common import MCPConfig
|
|
|
29
29
|
|
|
30
30
|
@contextmanager
|
|
31
31
|
def mcp_tools_context(
|
|
32
|
-
api_base: str | None = None,
|
|
33
|
-
api_key: str | None = None,
|
|
34
32
|
authorization_context: dict[str, Any] | None = None,
|
|
33
|
+
forwarded_headers: dict[str, str] | None = None,
|
|
35
34
|
) -> Generator[list[Any], None, None]:
|
|
36
35
|
"""Context manager for MCP tools that handles connection lifecycle."""
|
|
37
36
|
config = MCPConfig(
|
|
38
|
-
|
|
37
|
+
authorization_context=authorization_context,
|
|
38
|
+
forwarded_headers=forwarded_headers,
|
|
39
39
|
)
|
|
40
|
-
|
|
41
40
|
# If no MCP server configured, return empty tools list
|
|
42
41
|
if not config.server_config:
|
|
43
42
|
print("No MCP server configured, using empty tools list", flush=True)
|
|
@@ -47,10 +46,8 @@ def mcp_tools_context(
|
|
|
47
46
|
print(f"Connecting to MCP server: {config.server_config['url']}", flush=True)
|
|
48
47
|
|
|
49
48
|
# Use MCPServerAdapter as context manager with the server config
|
|
50
|
-
adapter_setting = config.server_config.copy()
|
|
51
|
-
adapter_setting["transport"] = "streamable-http"
|
|
52
49
|
try:
|
|
53
|
-
with MCPServerAdapter(
|
|
50
|
+
with MCPServerAdapter(config.server_config) as tools:
|
|
54
51
|
print(
|
|
55
52
|
f"Successfully connected to MCP server, got {len(tools)} tools",
|
|
56
53
|
flush=True,
|
|
@@ -18,7 +18,6 @@ import logging
|
|
|
18
18
|
from typing import Any
|
|
19
19
|
|
|
20
20
|
from datarobot.auth.session import AuthCtx
|
|
21
|
-
from datarobot.models.genai.agent.auth import OAuthAccessTokenProvider
|
|
22
21
|
from datarobot.models.genai.agent.auth import ToolAuth
|
|
23
22
|
from fastmcp.server.dependencies import get_context
|
|
24
23
|
from fastmcp.server.dependencies import get_http_headers
|
|
@@ -27,12 +26,15 @@ from fastmcp.server.middleware import Middleware
|
|
|
27
26
|
from fastmcp.server.middleware import MiddlewareContext
|
|
28
27
|
from fastmcp.tools.tool import ToolResult
|
|
29
28
|
|
|
29
|
+
from datarobot_genai.core.utils.auth import AsyncOAuthTokenProvider
|
|
30
30
|
from datarobot_genai.core.utils.auth import AuthContextHeaderHandler
|
|
31
|
-
from datarobot_genai.drmcp import get_config
|
|
32
31
|
|
|
33
32
|
logger = logging.getLogger(__name__)
|
|
34
33
|
|
|
35
34
|
|
|
35
|
+
AUTH_CTX_KEY = "authorization_context"
|
|
36
|
+
|
|
37
|
+
|
|
36
38
|
class OAuthMiddleWare(Middleware):
|
|
37
39
|
"""Middleware that parses `x-datarobot-authorization-context` for tool calls.
|
|
38
40
|
|
|
@@ -45,16 +47,8 @@ class OAuthMiddleWare(Middleware):
|
|
|
45
47
|
Handler for encoding/decoding JWT tokens containing auth context.
|
|
46
48
|
"""
|
|
47
49
|
|
|
48
|
-
def __init__(self,
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
Parameters
|
|
52
|
-
----------
|
|
53
|
-
secret_key : Optional[str]
|
|
54
|
-
Secret key for JWT validation. If None, uses the value from config.
|
|
55
|
-
"""
|
|
56
|
-
secret_key = secret_key or get_config().session_secret_key
|
|
57
|
-
self.auth_handler = AuthContextHeaderHandler(secret_key)
|
|
50
|
+
def __init__(self, auth_handler: AuthContextHeaderHandler | None = None) -> None:
|
|
51
|
+
self.auth_handler = auth_handler or AuthContextHeaderHandler()
|
|
58
52
|
|
|
59
53
|
async def on_call_tool(
|
|
60
54
|
self, context: MiddlewareContext, call_next: CallNext[Any, ToolResult]
|
|
@@ -74,9 +68,12 @@ class OAuthMiddleWare(Middleware):
|
|
|
74
68
|
The result from the tool execution.
|
|
75
69
|
"""
|
|
76
70
|
auth_context = self._extract_auth_context()
|
|
71
|
+
if not auth_context:
|
|
72
|
+
logger.debug("No valid authorization context extracted from request headers.")
|
|
77
73
|
|
|
78
74
|
if context.fastmcp_context is not None:
|
|
79
|
-
context.fastmcp_context.
|
|
75
|
+
context.fastmcp_context.set_state(AUTH_CTX_KEY, auth_context)
|
|
76
|
+
logger.debug("Authorization context attached to state.")
|
|
80
77
|
|
|
81
78
|
return await call_next(context)
|
|
82
79
|
|
|
@@ -99,8 +96,8 @@ class OAuthMiddleWare(Middleware):
|
|
|
99
96
|
return None
|
|
100
97
|
|
|
101
98
|
|
|
102
|
-
async def get_auth_context() -> AuthCtx:
|
|
103
|
-
"""Retrieve the AuthCtx from the current request context
|
|
99
|
+
async def must_get_auth_context() -> AuthCtx:
|
|
100
|
+
"""Retrieve the AuthCtx from the current request context or raise error.
|
|
104
101
|
|
|
105
102
|
Raises
|
|
106
103
|
------
|
|
@@ -113,14 +110,15 @@ async def get_auth_context() -> AuthCtx:
|
|
|
113
110
|
The authorization context associated with the current request.
|
|
114
111
|
"""
|
|
115
112
|
context = get_context()
|
|
116
|
-
|
|
113
|
+
|
|
114
|
+
auth_ctx = context.get_state(AUTH_CTX_KEY)
|
|
117
115
|
if not auth_ctx:
|
|
118
|
-
raise RuntimeError("
|
|
116
|
+
raise RuntimeError("Could not retrieve authorization context from FastMCP context state.")
|
|
119
117
|
|
|
120
118
|
return auth_ctx
|
|
121
119
|
|
|
122
120
|
|
|
123
|
-
async def get_access_token(provider: str | None = None) -> str:
|
|
121
|
+
async def get_access_token(provider_type: str | None = None) -> str:
|
|
124
122
|
"""Retrieve access token from the DataRobot OAuth Provider Service.
|
|
125
123
|
|
|
126
124
|
OAuth access tokens can be retrieved only for providers where the user completed
|
|
@@ -132,7 +130,7 @@ async def get_access_token(provider: str | None = None) -> str:
|
|
|
132
130
|
|
|
133
131
|
Parameters
|
|
134
132
|
----------
|
|
135
|
-
|
|
133
|
+
provider_type : str, optional
|
|
136
134
|
The name of the OAuth provider. It should match the name of the provider configured
|
|
137
135
|
during provider setup. If no value is provided and only one OAuth provider exists, that
|
|
138
136
|
provider will be used. If multiple providers exist and none is specified, an error will be
|
|
@@ -142,12 +140,18 @@ async def get_access_token(provider: str | None = None) -> str:
|
|
|
142
140
|
-------
|
|
143
141
|
The oauth access token.
|
|
144
142
|
"""
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
143
|
+
auth_ctx = await must_get_auth_context()
|
|
144
|
+
logger.debug("Retrieved authorization context")
|
|
145
|
+
|
|
146
|
+
oauth_token_provider = AsyncOAuthTokenProvider(auth_ctx)
|
|
147
|
+
oauth_access_token = await oauth_token_provider.get_token(
|
|
148
|
+
auth_type=ToolAuth.OBO,
|
|
149
|
+
provider_type=provider_type,
|
|
150
|
+
)
|
|
151
|
+
return oauth_access_token
|
|
148
152
|
|
|
149
153
|
|
|
150
|
-
def initialize_oauth_middleware(mcp: Any, secret_key: str | None = None) -> None:
|
|
154
|
+
def initialize_oauth_middleware(mcp: Any) -> None:
|
|
151
155
|
"""Initialize and register OAuth middleware with the MCP server.
|
|
152
156
|
|
|
153
157
|
Parameters
|
|
@@ -157,6 +161,5 @@ def initialize_oauth_middleware(mcp: Any, secret_key: str | None = None) -> None
|
|
|
157
161
|
secret_key : Optional[str]
|
|
158
162
|
Secret key for JWT validation. If None, uses the value from config.
|
|
159
163
|
"""
|
|
160
|
-
|
|
161
|
-
mcp.add_middleware(middleware)
|
|
164
|
+
mcp.add_middleware(OAuthMiddleWare())
|
|
162
165
|
logger.info("OAuth middleware registered successfully")
|