datarobot-genai 0.1.59__tar.gz → 0.1.70__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/PKG-INFO +2 -2
  2. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/pyproject.toml +2 -2
  3. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/agents/base.py +7 -0
  4. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/custom_model.py +5 -0
  5. datarobot_genai-0.1.70/src/datarobot_genai/core/mcp/common.py +218 -0
  6. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/utils/auth.py +64 -0
  7. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/crewai/base.py +34 -55
  8. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/crewai/mcp.py +4 -7
  9. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/auth.py +28 -25
  10. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/clients.py +67 -3
  11. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/config.py +0 -8
  12. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dr_mcp_server.py +10 -3
  13. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dr_mcp_server_logo.py +12 -1
  14. datarobot_genai-0.1.70/src/datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
  15. datarobot_genai-0.1.70/src/datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +128 -0
  16. datarobot_genai-0.1.70/src/datarobot_genai/drmcp/core/dynamic_prompts/register.py +206 -0
  17. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/mcp_instance.py +10 -0
  18. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/mcp_server_tools.py +2 -2
  19. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/routes.py +125 -28
  20. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/langgraph/agent.py +5 -6
  21. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/langgraph/mcp.py +5 -7
  22. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/llama_index/base.py +1 -2
  23. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/llama_index/mcp.py +4 -5
  24. datarobot_genai-0.1.70/src/datarobot_genai/nat/agent.py +258 -0
  25. datarobot_genai-0.1.59/src/datarobot_genai/core/mcp/common.py +0 -109
  26. datarobot_genai-0.1.59/src/datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +0 -91
  27. datarobot_genai-0.1.59/src/datarobot_genai/drmcp/core/dynamic_prompts/register.py +0 -150
  28. datarobot_genai-0.1.59/src/datarobot_genai/nat/agent.py +0 -137
  29. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/.gitignore +0 -0
  30. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/AUTHORS +0 -0
  31. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/LICENSE +0 -0
  32. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/README.md +0 -0
  33. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/__init__.py +0 -0
  34. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/__init__.py +0 -0
  35. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/agents/__init__.py +0 -0
  36. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/chat/__init__.py +0 -0
  37. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/chat/auth.py +0 -0
  38. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/chat/client.py +0 -0
  39. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/chat/responses.py +0 -0
  40. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/cli/__init__.py +0 -0
  41. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/cli/agent_environment.py +0 -0
  42. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/cli/agent_kernel.py +0 -0
  43. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/mcp/__init__.py +0 -0
  44. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/telemetry_agent.py +0 -0
  45. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/utils/__init__.py +0 -0
  46. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/core/utils/urls.py +0 -0
  47. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/crewai/__init__.py +0 -0
  48. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/crewai/agent.py +0 -0
  49. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/crewai/events.py +0 -0
  50. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/__init__.py +0 -0
  51. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/__init__.py +0 -0
  52. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/config_utils.py +0 -0
  53. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/constants.py +0 -0
  54. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/credentials.py +0 -0
  55. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +0 -0
  56. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_prompts/utils.py +0 -0
  57. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/__init__.py +0 -0
  58. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
  59. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +0 -0
  60. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +0 -0
  61. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +0 -0
  62. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +0 -0
  63. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +0 -0
  64. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +0 -0
  65. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +0 -0
  66. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +0 -0
  67. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +0 -0
  68. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +0 -0
  69. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/register.py +0 -0
  70. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/dynamic_tools/schema.py +0 -0
  71. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/exceptions.py +0 -0
  72. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/logging.py +0 -0
  73. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/memory_management/__init__.py +0 -0
  74. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/memory_management/manager.py +0 -0
  75. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/memory_management/memory_tools.py +0 -0
  76. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/routes_utils.py +0 -0
  77. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/server_life_cycle.py +0 -0
  78. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/telemetry.py +0 -0
  79. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/tool_filter.py +0 -0
  80. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/core/utils.py +0 -0
  81. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/server.py +0 -0
  82. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/test_utils/__init__.py +0 -0
  83. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/test_utils/integration_mcp_server.py +0 -0
  84. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +0 -0
  85. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +0 -0
  86. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/test_utils/openai_llm_mcp_client.py +0 -0
  87. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/test_utils/tool_base_ete.py +0 -0
  88. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/test_utils/utils.py +0 -0
  89. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/__init__.py +0 -0
  90. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/__init__.py +0 -0
  91. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/data.py +0 -0
  92. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/deployment.py +0 -0
  93. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/deployment_info.py +0 -0
  94. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/model.py +0 -0
  95. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/predict.py +0 -0
  96. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/predict_realtime.py +0 -0
  97. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/project.py +0 -0
  98. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/drmcp/tools/predictive/training.py +0 -0
  99. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/langgraph/__init__.py +0 -0
  100. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/llama_index/__init__.py +0 -0
  101. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/llama_index/agent.py +0 -0
  102. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/nat/__init__.py +0 -0
  103. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/nat/datarobot_llm_clients.py +0 -0
  104. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/nat/datarobot_llm_providers.py +0 -0
  105. {datarobot_genai-0.1.59 → datarobot_genai-0.1.70}/src/datarobot_genai/py.typed +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: datarobot-genai
3
- Version: 0.1.59
3
+ Version: 0.1.70
4
4
  Summary: Generic helpers for GenAI
5
5
  Project-URL: Homepage, https://github.com/datarobot-oss/datarobot-genai
6
6
  Author: DataRobot, Inc.
@@ -32,7 +32,7 @@ Requires-Dist: aiohttp<4.0.0,>=3.9.0; extra == 'drmcp'
32
32
  Requires-Dist: aiosignal<2.0.0,>=1.3.1; extra == 'drmcp'
33
33
  Requires-Dist: boto3<2.0.0,>=1.34.0; extra == 'drmcp'
34
34
  Requires-Dist: datarobot-asgi-middleware<1.0.0,>=0.2.0; extra == 'drmcp'
35
- Requires-Dist: fastmcp==2.13.0.2; extra == 'drmcp'
35
+ Requires-Dist: fastmcp<3.0.0,>=2.13.0.2; extra == 'drmcp'
36
36
  Requires-Dist: httpx<1.0.0,>=0.28.1; extra == 'drmcp'
37
37
  Requires-Dist: opentelemetry-api<2.0.0,>=1.22.0; extra == 'drmcp'
38
38
  Requires-Dist: opentelemetry-exporter-otlp-proto-http<2.0.0,>=1.22.0; extra == 'drmcp'
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "datarobot-genai"
7
- version = "0.1.59"
7
+ version = "0.1.70"
8
8
  description = "Generic helpers for GenAI"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10, <3.13"
@@ -84,7 +84,7 @@ drmcp = [
84
84
  "aiohttp>=3.9.0,<4.0.0",
85
85
  "aiohttp-retry>=2.8.3,<3.0.0",
86
86
  "aiosignal>=1.3.1,<2.0.0",
87
- "fastmcp==2.13.0.2",
87
+ "fastmcp>=2.13.0.2,<3.0.0",
88
88
  ]
89
89
 
90
90
  [tool.hatch.build.targets.wheel]
@@ -52,6 +52,7 @@ class BaseAgent(Generic[TTool], abc.ABC):
52
52
  verbose: bool | str | None = True,
53
53
  timeout: int | None = 90,
54
54
  authorization_context: dict[str, Any] | None = None,
55
+ forwarded_headers: dict[str, str] | None = None,
55
56
  **_: Any,
56
57
  ) -> None:
57
58
  self.api_key = api_key or os.environ.get("DATAROBOT_API_TOKEN")
@@ -68,6 +69,7 @@ class BaseAgent(Generic[TTool], abc.ABC):
68
69
  self.verbose = bool(verbose)
69
70
  self._mcp_tools: list[TTool] = []
70
71
  self._authorization_context = authorization_context or {}
72
+ self._forwarded_headers: dict[str, str] = forwarded_headers or {}
71
73
 
72
74
  def set_mcp_tools(self, tools: list[TTool]) -> None:
73
75
  self._mcp_tools = tools
@@ -86,6 +88,11 @@ class BaseAgent(Generic[TTool], abc.ABC):
86
88
  """Return the authorization context for this agent."""
87
89
  return self._authorization_context
88
90
 
91
+ @property
92
+ def forwarded_headers(self) -> dict[str, str]:
93
+ """Return the forwarded headers for this agent."""
94
+ return self._forwarded_headers
95
+
89
96
  def litellm_api_base(self, deployment_id: str | None) -> str:
90
97
  return get_api_base(self.api_base, deployment_id)
91
98
 
@@ -139,6 +139,11 @@ def chat_entrypoint(
139
139
  completion_create_params["authorization_context"] = resolve_authorization_context(
140
140
  completion_create_params, **kwargs
141
141
  )
142
+ # Keep only allowed headers from the forwarded_headers.
143
+ incoming_headers = kwargs.get("headers", {}) or {}
144
+ allowed_headers = {"x-datarobot-api-token", "x-datarobot-api-key"}
145
+ forwarded_headers = {k: v for k, v in incoming_headers.items() if k.lower() in allowed_headers}
146
+ completion_create_params["forwarded_headers"] = forwarded_headers
142
147
 
143
148
  # Instantiate user agent with all supplied completion params including auth context
144
149
  agent = agent_cls(**completion_create_params)
@@ -0,0 +1,218 @@
1
+ # Copyright 2025 DataRobot, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import logging
17
+ import re
18
+ from http import HTTPStatus
19
+ from typing import Any
20
+ from typing import Literal
21
+
22
+ import requests
23
+ from datarobot.core.config import DataRobotAppFrameworkBaseSettings
24
+ from pydantic import field_validator
25
+
26
+ from datarobot_genai.core.utils.auth import AuthContextHeaderHandler
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class MCPConfig(DataRobotAppFrameworkBaseSettings):
32
+ """Configuration for MCP server connection.
33
+
34
+ Derived values are exposed as properties rather than stored, avoiding
35
+ Pydantic field validation/serialization concerns for internal helpers.
36
+ """
37
+
38
+ external_mcp_url: str | None = None
39
+ external_mcp_headers: str | None = None
40
+ external_mcp_transport: Literal["sse", "streamable-http"] = "streamable-http"
41
+ mcp_deployment_id: str | None = None
42
+ datarobot_endpoint: str | None = None
43
+ datarobot_api_token: str | None = None
44
+ authorization_context: dict[str, Any] | None = None
45
+ forwarded_headers: dict[str, str] | None = None
46
+ mcp_server_port: int | None = None
47
+
48
+ _auth_context_handler: AuthContextHeaderHandler | None = None
49
+ _server_config: dict[str, Any] | None = None
50
+
51
+ @field_validator("external_mcp_headers", mode="before")
52
+ @classmethod
53
+ def validate_external_mcp_headers(cls, value: str | None) -> str | None:
54
+ if value is None:
55
+ return None
56
+
57
+ candidate = value.strip()
58
+
59
+ try:
60
+ json.loads(candidate)
61
+ except json.JSONDecodeError:
62
+ msg = "external_mcp_headers must be valid JSON"
63
+ logger.warning(msg)
64
+ return None
65
+
66
+ return candidate
67
+
68
+ @field_validator("mcp_deployment_id", mode="before")
69
+ @classmethod
70
+ def validate_mcp_deployment_id(cls, value: str | None) -> str | None:
71
+ if value is None:
72
+ return None
73
+
74
+ candidate = value.strip()
75
+
76
+ if not re.fullmatch(r"[0-9a-fA-F]{24}", candidate):
77
+ msg = "mcp_deployment_id must be a valid 24-character hex ID"
78
+ logger.warning(msg)
79
+ return None
80
+
81
+ return candidate
82
+
83
+ def _authorization_bearer_header(self) -> dict[str, str]:
84
+ """Return Authorization header with Bearer token or empty dict."""
85
+ if not self.datarobot_api_token:
86
+ return {}
87
+ auth = (
88
+ self.datarobot_api_token
89
+ if self.datarobot_api_token.startswith("Bearer ")
90
+ else f"Bearer {self.datarobot_api_token}"
91
+ )
92
+ return {"Authorization": auth}
93
+
94
+ @property
95
+ def auth_context_handler(self) -> AuthContextHeaderHandler:
96
+ if self._auth_context_handler is None:
97
+ self._auth_context_handler = AuthContextHeaderHandler()
98
+ return self._auth_context_handler
99
+
100
+ @property
101
+ def server_config(self) -> dict[str, Any] | None:
102
+ if self._server_config is None:
103
+ self._server_config = self._build_server_config()
104
+ return self._server_config
105
+
106
+ def _authorization_context_header(self) -> dict[str, str]:
107
+ """Return X-DataRobot-Authorization-Context header or empty dict."""
108
+ try:
109
+ return self.auth_context_handler.get_header(self.authorization_context)
110
+ except (LookupError, RuntimeError):
111
+ # Authorization context not available (e.g., in tests)
112
+ return {}
113
+
114
+ def _build_authenticated_headers(self) -> dict[str, str]:
115
+ """Build headers for authenticated requests.
116
+
117
+ Returns
118
+ -------
119
+ Dictionary containing forwarded headers (if available) and authentication headers.
120
+ """
121
+ headers: dict[str, str] = {}
122
+ if self.forwarded_headers:
123
+ headers.update(self.forwarded_headers)
124
+ headers.update(self._authorization_bearer_header())
125
+ headers.update(self._authorization_context_header())
126
+ return headers
127
+
128
+ def _check_localhost_server(self, url: str, timeout: float = 2.0) -> bool:
129
+ """Check if MCP server is running on localhost.
130
+
131
+ Parameters
132
+ ----------
133
+ url : str
134
+ The URL to check.
135
+ timeout : float, optional
136
+ Request timeout in seconds (default: 2.0).
137
+
138
+ Returns
139
+ -------
140
+ bool
141
+ True if server is running and responding with OK status, False otherwise.
142
+ """
143
+ try:
144
+ response = requests.get(url, timeout=timeout)
145
+ return (
146
+ response.status_code == HTTPStatus.OK
147
+ and response.json().get("message") == "DataRobot MCP Server is running"
148
+ )
149
+ except requests.RequestException as e:
150
+ logger.debug(f"Failed to connect to MCP server at {url}: {e}")
151
+ return False
152
+
153
+ def _build_server_config(self) -> dict[str, Any] | None:
154
+ """
155
+ Get MCP server configuration.
156
+
157
+ Returns
158
+ -------
159
+ Server configuration dict with url, transport, and optional headers,
160
+ or None if not configured.
161
+ """
162
+ if self.mcp_deployment_id:
163
+ # DataRobot deployment ID - requires authentication
164
+ if self.datarobot_endpoint is None:
165
+ raise ValueError(
166
+ "When using a DataRobot hosted MCP deployment, datarobot_endpoint must be set."
167
+ )
168
+ if self.datarobot_api_token is None:
169
+ raise ValueError(
170
+ "When using a DataRobot hosted MCP deployment, datarobot_api_token must be set."
171
+ )
172
+
173
+ base_url = self.datarobot_endpoint.rstrip("/")
174
+ if not base_url.endswith("/api/v2"):
175
+ base_url = f"{base_url}/api/v2"
176
+
177
+ url = f"{base_url}/deployments/{self.mcp_deployment_id}/directAccess/mcp"
178
+ headers = self._build_authenticated_headers()
179
+
180
+ logger.info(f"Using DataRobot hosted MCP deployment: {url}")
181
+
182
+ return {
183
+ "url": url,
184
+ "transport": "streamable-http",
185
+ "headers": headers,
186
+ }
187
+
188
+ if self.external_mcp_url:
189
+ # External MCP URL - no authentication needed
190
+ headers = {}
191
+
192
+ # Merge external headers if provided
193
+ if self.external_mcp_headers:
194
+ external_headers = json.loads(self.external_mcp_headers)
195
+ headers.update(external_headers)
196
+
197
+ logger.info(f"Using external MCP URL: {self.external_mcp_url}")
198
+
199
+ return {
200
+ "url": self.external_mcp_url.rstrip("/"),
201
+ "transport": self.external_mcp_transport,
202
+ "headers": headers,
203
+ }
204
+
205
+ # No MCP configuration found, setup localhost if running locally
206
+ if self.mcp_server_port:
207
+ url = f"http://localhost:{self.mcp_server_port}"
208
+ if self._check_localhost_server(url):
209
+ headers = self._build_authenticated_headers()
210
+ logger.info(f"Using localhost MCP server: {url}")
211
+ return {
212
+ "url": f"{url}/mcp",
213
+ "transport": "streamable-http",
214
+ "headers": headers,
215
+ }
216
+ logger.warning(f"MCP server is not running or not responding at {url}")
217
+
218
+ return None
@@ -16,9 +16,14 @@ import warnings
16
16
  from typing import Any
17
17
 
18
18
  import jwt
19
+ from datarobot.auth.datarobot.oauth import AsyncOAuth as DatarobotAsyncOAuthClient
20
+ from datarobot.auth.identity import Identity
21
+ from datarobot.auth.oauth import AsyncOAuthComponent
19
22
  from datarobot.auth.session import AuthCtx
20
23
  from datarobot.core.config import DataRobotAppFrameworkBaseSettings
24
+ from datarobot.models.genai.agent.auth import ToolAuth
21
25
  from datarobot.models.genai.agent.auth import get_authorization_context
26
+ from pydantic import BaseModel
22
27
 
23
28
  logger = logging.getLogger(__name__)
24
29
 
@@ -27,6 +32,13 @@ class AuthContextConfig(DataRobotAppFrameworkBaseSettings):
27
32
  session_secret_key: str = ""
28
33
 
29
34
 
35
+ class DRAppCtx(BaseModel):
36
+ """DataRobot application context from authorization metadata."""
37
+
38
+ email: str | None = None
39
+ api_key: str | None = None
40
+
41
+
30
42
  class AuthContextHeaderHandler:
31
43
  """Manages encoding and decoding of authorization context into JWT tokens.
32
44
 
@@ -146,6 +158,7 @@ class AuthContextHeaderHandler:
146
158
 
147
159
  auth_ctx_dict = self.decode(token)
148
160
  if not auth_ctx_dict:
161
+ logger.debug("Failed to decode auth context from token")
149
162
  return None
150
163
 
151
164
  try:
@@ -153,3 +166,54 @@ class AuthContextHeaderHandler:
153
166
  except Exception as e:
154
167
  logger.error(f"Failed to create AuthCtx from decoded token: {e}", exc_info=True)
155
168
  return None
169
+
170
+
171
+ class AsyncOAuthTokenProvider:
172
+ """Manages OAuth access tokens using generic OAuth client."""
173
+
174
+ def __init__(self, auth_ctx: AuthCtx) -> None:
175
+ self.auth_ctx = auth_ctx
176
+ self.oauth_client = self._create_oauth_client()
177
+
178
+ def _get_identity(self, provider_type: str | None) -> Identity:
179
+ """Retrieve the appropriate identity from the authentication context."""
180
+ identities = [x for x in self.auth_ctx.identities if x.provider_identity_id is not None]
181
+
182
+ if not identities:
183
+ raise ValueError("No identities found in authorization context.")
184
+
185
+ if provider_type is None:
186
+ if len(identities) > 1:
187
+ raise ValueError(
188
+ "Multiple identities found. Please specify 'provider_type' parameter."
189
+ )
190
+ return identities[0]
191
+
192
+ identity = next((id for id in identities if id.provider_type == provider_type), None)
193
+
194
+ if identity is None:
195
+ raise ValueError(f"No identity found for provider '{provider_type}'.")
196
+
197
+ return identity
198
+
199
+ async def get_token(self, auth_type: ToolAuth, provider_type: str | None = None) -> str:
200
+ """Get OAuth access token using the specified method."""
201
+ if auth_type != ToolAuth.OBO:
202
+ raise ValueError(
203
+ f"Unsupported auth type: {auth_type}. Only {ToolAuth.OBO} is supported."
204
+ )
205
+
206
+ identity = self._get_identity(provider_type)
207
+ token_data = await self.oauth_client.refresh_access_token(
208
+ identity_id=identity.provider_identity_id
209
+ )
210
+ return token_data.access_token
211
+
212
+ def _create_oauth_client(self) -> AsyncOAuthComponent:
213
+ """Create either DataRobot or Authlib OAuth client based on
214
+ authorization context.
215
+
216
+ Note: at the moment, only DataRobot OAuth client is supported.
217
+ """
218
+ logger.debug("Using DataRobot OAuth client")
219
+ return DatarobotAsyncOAuthClient()
@@ -80,6 +80,37 @@ class CrewAIAgent(BaseAgent[BaseTool], abc.ABC):
80
80
  """
81
81
  raise NotImplementedError
82
82
 
83
+ def _extract_pipeline_interactions(self) -> MultiTurnSample | None:
84
+ """Extract pipeline interactions from event listener if available."""
85
+ if not hasattr(self, "event_listener"):
86
+ return None
87
+ try:
88
+ listener = getattr(self, "event_listener", None)
89
+ messages = getattr(listener, "messages", None) if listener is not None else None
90
+ return create_pipeline_interactions_from_messages(messages)
91
+ except Exception:
92
+ return None
93
+
94
+ def _extract_usage_metrics(self, crew_output: Any) -> UsageMetrics:
95
+ """Extract usage metrics from crew output."""
96
+ token_usage = getattr(crew_output, "token_usage", None)
97
+ if token_usage is not None:
98
+ return {
99
+ "completion_tokens": int(getattr(token_usage, "completion_tokens", 0)),
100
+ "prompt_tokens": int(getattr(token_usage, "prompt_tokens", 0)),
101
+ "total_tokens": int(getattr(token_usage, "total_tokens", 0)),
102
+ }
103
+ return default_usage_metrics()
104
+
105
+ def _process_crew_output(
106
+ self, crew_output: Any
107
+ ) -> tuple[str, MultiTurnSample | None, UsageMetrics]:
108
+ """Process crew output into response tuple."""
109
+ response_text = str(crew_output.raw)
110
+ pipeline_interactions = self._extract_pipeline_interactions()
111
+ usage_metrics = self._extract_usage_metrics(crew_output)
112
+ return response_text, pipeline_interactions, usage_metrics
113
+
83
114
  async def invoke(self, completion_create_params: CompletionCreateParams) -> InvokeReturn:
84
115
  """Run the CrewAI workflow with the provided completion parameters."""
85
116
  user_prompt_content = extract_user_prompt_content(completion_create_params)
@@ -92,9 +123,8 @@ class CrewAIAgent(BaseAgent[BaseTool], abc.ABC):
92
123
 
93
124
  # Use MCP context manager to handle connection lifecycle
94
125
  with mcp_tools_context(
95
- api_base=self.api_base,
96
- api_key=self.api_key,
97
126
  authorization_context=self._authorization_context,
127
+ forwarded_headers=self.forwarded_headers,
98
128
  ) as mcp_tools:
99
129
  # Set MCP tools for all agents if MCP is not configured this is effectively a no-op
100
130
  self.set_mcp_tools(mcp_tools)
@@ -117,64 +147,13 @@ class CrewAIAgent(BaseAgent[BaseTool], abc.ABC):
117
147
  async def _gen() -> AsyncGenerator[
118
148
  tuple[str, MultiTurnSample | None, UsageMetrics]
119
149
  ]:
120
- # Run kickoff in a worker thread.
121
150
  crew_output = await asyncio.to_thread(
122
151
  crew.kickoff,
123
152
  inputs=self.make_kickoff_inputs(user_prompt_content),
124
153
  )
125
-
126
- pipeline_interactions = None
127
- if hasattr(self, "event_listener"):
128
- try:
129
- listener = getattr(self, "event_listener", None)
130
- messages = (
131
- getattr(listener, "messages", None)
132
- if listener is not None
133
- else None
134
- )
135
- pipeline_interactions = create_pipeline_interactions_from_messages(
136
- messages
137
- )
138
- except Exception:
139
- pipeline_interactions = None
140
-
141
- token_usage = getattr(crew_output, "token_usage", None)
142
- if token_usage is not None:
143
- usage_metrics: UsageMetrics = {
144
- "completion_tokens": int(getattr(token_usage, "completion_tokens", 0)),
145
- "prompt_tokens": int(getattr(token_usage, "prompt_tokens", 0)),
146
- "total_tokens": int(getattr(token_usage, "total_tokens", 0)),
147
- }
148
- else:
149
- usage_metrics = default_usage_metrics()
150
-
151
- # Finalize stream with empty chunk carrying interactions and usage
152
- yield "", pipeline_interactions, usage_metrics
154
+ yield self._process_crew_output(crew_output)
153
155
 
154
156
  return _gen()
155
157
 
156
- # Non-streaming: run to completion and return final result
157
158
  crew_output = crew.kickoff(inputs=self.make_kickoff_inputs(user_prompt_content))
158
-
159
- response_text = str(crew_output.raw)
160
-
161
- pipeline_interactions = None
162
- if hasattr(self, "event_listener"):
163
- try:
164
- listener = getattr(self, "event_listener", None)
165
- messages = getattr(listener, "messages", None) if listener is not None else None
166
- pipeline_interactions = create_pipeline_interactions_from_messages(messages)
167
- except Exception:
168
- pipeline_interactions = None
169
-
170
- token_usage = getattr(crew_output, "token_usage", None)
171
- if token_usage is not None:
172
- usage_metrics: UsageMetrics = {
173
- "completion_tokens": int(getattr(token_usage, "completion_tokens", 0)),
174
- "prompt_tokens": int(getattr(token_usage, "prompt_tokens", 0)),
175
- "total_tokens": int(getattr(token_usage, "total_tokens", 0)),
176
- }
177
- else:
178
- usage_metrics = default_usage_metrics()
179
-
180
- return response_text, pipeline_interactions, usage_metrics
159
+ return self._process_crew_output(crew_output)
@@ -29,15 +29,14 @@ from datarobot_genai.core.mcp.common import MCPConfig
29
29
 
30
30
  @contextmanager
31
31
  def mcp_tools_context(
32
- api_base: str | None = None,
33
- api_key: str | None = None,
34
32
  authorization_context: dict[str, Any] | None = None,
33
+ forwarded_headers: dict[str, str] | None = None,
35
34
  ) -> Generator[list[Any], None, None]:
36
35
  """Context manager for MCP tools that handles connection lifecycle."""
37
36
  config = MCPConfig(
38
- api_base=api_base, api_key=api_key, authorization_context=authorization_context
37
+ authorization_context=authorization_context,
38
+ forwarded_headers=forwarded_headers,
39
39
  )
40
-
41
40
  # If no MCP server configured, return empty tools list
42
41
  if not config.server_config:
43
42
  print("No MCP server configured, using empty tools list", flush=True)
@@ -47,10 +46,8 @@ def mcp_tools_context(
47
46
  print(f"Connecting to MCP server: {config.server_config['url']}", flush=True)
48
47
 
49
48
  # Use MCPServerAdapter as context manager with the server config
50
- adapter_setting = config.server_config.copy()
51
- adapter_setting["transport"] = "streamable-http"
52
49
  try:
53
- with MCPServerAdapter(adapter_setting) as tools:
50
+ with MCPServerAdapter(config.server_config) as tools:
54
51
  print(
55
52
  f"Successfully connected to MCP server, got {len(tools)} tools",
56
53
  flush=True,
@@ -18,7 +18,6 @@ import logging
18
18
  from typing import Any
19
19
 
20
20
  from datarobot.auth.session import AuthCtx
21
- from datarobot.models.genai.agent.auth import OAuthAccessTokenProvider
22
21
  from datarobot.models.genai.agent.auth import ToolAuth
23
22
  from fastmcp.server.dependencies import get_context
24
23
  from fastmcp.server.dependencies import get_http_headers
@@ -27,12 +26,15 @@ from fastmcp.server.middleware import Middleware
27
26
  from fastmcp.server.middleware import MiddlewareContext
28
27
  from fastmcp.tools.tool import ToolResult
29
28
 
29
+ from datarobot_genai.core.utils.auth import AsyncOAuthTokenProvider
30
30
  from datarobot_genai.core.utils.auth import AuthContextHeaderHandler
31
- from datarobot_genai.drmcp import get_config
32
31
 
33
32
  logger = logging.getLogger(__name__)
34
33
 
35
34
 
35
+ AUTH_CTX_KEY = "authorization_context"
36
+
37
+
36
38
  class OAuthMiddleWare(Middleware):
37
39
  """Middleware that parses `x-datarobot-authorization-context` for tool calls.
38
40
 
@@ -45,16 +47,8 @@ class OAuthMiddleWare(Middleware):
45
47
  Handler for encoding/decoding JWT tokens containing auth context.
46
48
  """
47
49
 
48
- def __init__(self, secret_key: str | None = None) -> None:
49
- """Initialize the middleware with authentication handler.
50
-
51
- Parameters
52
- ----------
53
- secret_key : Optional[str]
54
- Secret key for JWT validation. If None, uses the value from config.
55
- """
56
- secret_key = secret_key or get_config().session_secret_key
57
- self.auth_handler = AuthContextHeaderHandler(secret_key)
50
+ def __init__(self, auth_handler: AuthContextHeaderHandler | None = None) -> None:
51
+ self.auth_handler = auth_handler or AuthContextHeaderHandler()
58
52
 
59
53
  async def on_call_tool(
60
54
  self, context: MiddlewareContext, call_next: CallNext[Any, ToolResult]
@@ -74,9 +68,12 @@ class OAuthMiddleWare(Middleware):
74
68
  The result from the tool execution.
75
69
  """
76
70
  auth_context = self._extract_auth_context()
71
+ if not auth_context:
72
+ logger.debug("No valid authorization context extracted from request headers.")
77
73
 
78
74
  if context.fastmcp_context is not None:
79
- context.fastmcp_context.auth_context = auth_context
75
+ context.fastmcp_context.set_state(AUTH_CTX_KEY, auth_context)
76
+ logger.debug("Authorization context attached to state.")
80
77
 
81
78
  return await call_next(context)
82
79
 
@@ -99,8 +96,8 @@ class OAuthMiddleWare(Middleware):
99
96
  return None
100
97
 
101
98
 
102
- async def get_auth_context() -> AuthCtx:
103
- """Retrieve the AuthCtx from the current request context, if available.
99
+ async def must_get_auth_context() -> AuthCtx:
100
+ """Retrieve the AuthCtx from the current request context or raise error.
104
101
 
105
102
  Raises
106
103
  ------
@@ -113,14 +110,15 @@ async def get_auth_context() -> AuthCtx:
113
110
  The authorization context associated with the current request.
114
111
  """
115
112
  context = get_context()
116
- auth_ctx = getattr(context, "auth_context", None)
113
+
114
+ auth_ctx = context.get_state(AUTH_CTX_KEY)
117
115
  if not auth_ctx:
118
- raise RuntimeError("No authorization context found.")
116
+ raise RuntimeError("Could not retrieve authorization context from FastMCP context state.")
119
117
 
120
118
  return auth_ctx
121
119
 
122
120
 
123
- async def get_access_token(provider: str | None = None) -> str:
121
+ async def get_access_token(provider_type: str | None = None) -> str:
124
122
  """Retrieve access token from the DataRobot OAuth Provider Service.
125
123
 
126
124
  OAuth access tokens can be retrieved only for providers where the user completed
@@ -132,7 +130,7 @@ async def get_access_token(provider: str | None = None) -> str:
132
130
 
133
131
  Parameters
134
132
  ----------
135
- provider : str, optional
133
+ provider_type : str, optional
136
134
  The name of the OAuth provider. It should match the name of the provider configured
137
135
  during provider setup. If no value is provided and only one OAuth provider exists, that
138
136
  provider will be used. If multiple providers exist and none is specified, an error will be
@@ -142,12 +140,18 @@ async def get_access_token(provider: str | None = None) -> str:
142
140
  -------
143
141
  The oauth access token.
144
142
  """
145
- token_provider = OAuthAccessTokenProvider(await get_auth_context())
146
- access_token = token_provider.get_token(ToolAuth.OBO, provider)
147
- return access_token
143
+ auth_ctx = await must_get_auth_context()
144
+ logger.debug("Retrieved authorization context")
145
+
146
+ oauth_token_provider = AsyncOAuthTokenProvider(auth_ctx)
147
+ oauth_access_token = await oauth_token_provider.get_token(
148
+ auth_type=ToolAuth.OBO,
149
+ provider_type=provider_type,
150
+ )
151
+ return oauth_access_token
148
152
 
149
153
 
150
- def initialize_oauth_middleware(mcp: Any, secret_key: str | None = None) -> None:
154
+ def initialize_oauth_middleware(mcp: Any) -> None:
151
155
  """Initialize and register OAuth middleware with the MCP server.
152
156
 
153
157
  Parameters
@@ -157,6 +161,5 @@ def initialize_oauth_middleware(mcp: Any, secret_key: str | None = None) -> None
157
161
  secret_key : Optional[str]
158
162
  Secret key for JWT validation. If None, uses the value from config.
159
163
  """
160
- middleware = OAuthMiddleWare(secret_key=secret_key)
161
- mcp.add_middleware(middleware)
164
+ mcp.add_middleware(OAuthMiddleWare())
162
165
  logger.info("OAuth middleware registered successfully")