datarobot-genai 0.1.64__tar.gz → 0.1.71__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/PKG-INFO +2 -2
  2. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/pyproject.toml +2 -2
  3. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/agents/base.py +7 -0
  4. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/custom_model.py +5 -0
  5. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/mcp/common.py +87 -30
  6. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/crewai/base.py +34 -53
  7. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/crewai/mcp.py +5 -1
  8. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dr_mcp_server.py +10 -3
  9. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dr_mcp_server_logo.py +13 -2
  10. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +45 -0
  11. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/mcp_server_tools.py +2 -2
  12. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/routes.py +14 -1
  13. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/langgraph/agent.py +33 -14
  14. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/langgraph/mcp.py +7 -1
  15. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/llama_index/base.py +1 -0
  16. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/llama_index/mcp.py +6 -1
  17. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/nat/datarobot_llm_clients.py +66 -7
  18. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/nat/datarobot_llm_providers.py +32 -0
  19. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/.gitignore +0 -0
  20. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/AUTHORS +0 -0
  21. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/LICENSE +0 -0
  22. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/README.md +0 -0
  23. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/__init__.py +0 -0
  24. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/__init__.py +0 -0
  25. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/agents/__init__.py +0 -0
  26. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/chat/__init__.py +0 -0
  27. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/chat/auth.py +0 -0
  28. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/chat/client.py +0 -0
  29. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/chat/responses.py +0 -0
  30. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/cli/__init__.py +0 -0
  31. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/cli/agent_environment.py +0 -0
  32. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/cli/agent_kernel.py +0 -0
  33. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/mcp/__init__.py +0 -0
  34. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/telemetry_agent.py +0 -0
  35. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/utils/__init__.py +0 -0
  36. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/utils/auth.py +0 -0
  37. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/core/utils/urls.py +0 -0
  38. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/crewai/__init__.py +0 -0
  39. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/crewai/agent.py +0 -0
  40. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/crewai/events.py +0 -0
  41. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/__init__.py +0 -0
  42. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/__init__.py +0 -0
  43. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/auth.py +0 -0
  44. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/clients.py +0 -0
  45. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/config.py +0 -0
  46. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/config_utils.py +0 -0
  47. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/constants.py +0 -0
  48. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/credentials.py +0 -0
  49. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +0 -0
  50. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +0 -0
  51. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_prompts/register.py +0 -0
  52. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_prompts/utils.py +0 -0
  53. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_tools/__init__.py +0 -0
  54. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
  55. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +0 -0
  56. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +0 -0
  57. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +0 -0
  58. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +0 -0
  59. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +0 -0
  60. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +0 -0
  61. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +0 -0
  62. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +0 -0
  63. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +0 -0
  64. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +0 -0
  65. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_tools/register.py +0 -0
  66. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/dynamic_tools/schema.py +0 -0
  67. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/exceptions.py +0 -0
  68. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/logging.py +0 -0
  69. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/mcp_instance.py +0 -0
  70. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/memory_management/__init__.py +0 -0
  71. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/memory_management/manager.py +0 -0
  72. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/memory_management/memory_tools.py +0 -0
  73. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/routes_utils.py +0 -0
  74. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/server_life_cycle.py +0 -0
  75. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/telemetry.py +0 -0
  76. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/tool_filter.py +0 -0
  77. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/core/utils.py +0 -0
  78. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/server.py +0 -0
  79. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/test_utils/__init__.py +0 -0
  80. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/test_utils/integration_mcp_server.py +0 -0
  81. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +0 -0
  82. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +0 -0
  83. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/test_utils/openai_llm_mcp_client.py +0 -0
  84. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/test_utils/tool_base_ete.py +0 -0
  85. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/test_utils/utils.py +0 -0
  86. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/tools/__init__.py +0 -0
  87. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/tools/predictive/__init__.py +0 -0
  88. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/tools/predictive/data.py +0 -0
  89. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/tools/predictive/deployment.py +0 -0
  90. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/tools/predictive/deployment_info.py +0 -0
  91. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/tools/predictive/model.py +0 -0
  92. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/tools/predictive/predict.py +0 -0
  93. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/tools/predictive/predict_realtime.py +0 -0
  94. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/tools/predictive/project.py +0 -0
  95. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/drmcp/tools/predictive/training.py +0 -0
  96. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/langgraph/__init__.py +0 -0
  97. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/llama_index/__init__.py +0 -0
  98. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/llama_index/agent.py +0 -0
  99. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/nat/__init__.py +0 -0
  100. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/nat/agent.py +0 -0
  101. {datarobot_genai-0.1.64 → datarobot_genai-0.1.71}/src/datarobot_genai/py.typed +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: datarobot-genai
3
- Version: 0.1.64
3
+ Version: 0.1.71
4
4
  Summary: Generic helpers for GenAI
5
5
  Project-URL: Homepage, https://github.com/datarobot-oss/datarobot-genai
6
6
  Author: DataRobot, Inc.
@@ -32,7 +32,7 @@ Requires-Dist: aiohttp<4.0.0,>=3.9.0; extra == 'drmcp'
32
32
  Requires-Dist: aiosignal<2.0.0,>=1.3.1; extra == 'drmcp'
33
33
  Requires-Dist: boto3<2.0.0,>=1.34.0; extra == 'drmcp'
34
34
  Requires-Dist: datarobot-asgi-middleware<1.0.0,>=0.2.0; extra == 'drmcp'
35
- Requires-Dist: fastmcp==2.13.0.2; extra == 'drmcp'
35
+ Requires-Dist: fastmcp<3.0.0,>=2.13.0.2; extra == 'drmcp'
36
36
  Requires-Dist: httpx<1.0.0,>=0.28.1; extra == 'drmcp'
37
37
  Requires-Dist: opentelemetry-api<2.0.0,>=1.22.0; extra == 'drmcp'
38
38
  Requires-Dist: opentelemetry-exporter-otlp-proto-http<2.0.0,>=1.22.0; extra == 'drmcp'
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "datarobot-genai"
7
- version = "0.1.64"
7
+ version = "0.1.71"
8
8
  description = "Generic helpers for GenAI"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10, <3.13"
@@ -84,7 +84,7 @@ drmcp = [
84
84
  "aiohttp>=3.9.0,<4.0.0",
85
85
  "aiohttp-retry>=2.8.3,<3.0.0",
86
86
  "aiosignal>=1.3.1,<2.0.0",
87
- "fastmcp==2.13.0.2",
87
+ "fastmcp>=2.13.0.2,<3.0.0",
88
88
  ]
89
89
 
90
90
  [tool.hatch.build.targets.wheel]
@@ -52,6 +52,7 @@ class BaseAgent(Generic[TTool], abc.ABC):
52
52
  verbose: bool | str | None = True,
53
53
  timeout: int | None = 90,
54
54
  authorization_context: dict[str, Any] | None = None,
55
+ forwarded_headers: dict[str, str] | None = None,
55
56
  **_: Any,
56
57
  ) -> None:
57
58
  self.api_key = api_key or os.environ.get("DATAROBOT_API_TOKEN")
@@ -68,6 +69,7 @@ class BaseAgent(Generic[TTool], abc.ABC):
68
69
  self.verbose = bool(verbose)
69
70
  self._mcp_tools: list[TTool] = []
70
71
  self._authorization_context = authorization_context or {}
72
+ self._forwarded_headers: dict[str, str] = forwarded_headers or {}
71
73
 
72
74
  def set_mcp_tools(self, tools: list[TTool]) -> None:
73
75
  self._mcp_tools = tools
@@ -86,6 +88,11 @@ class BaseAgent(Generic[TTool], abc.ABC):
86
88
  """Return the authorization context for this agent."""
87
89
  return self._authorization_context
88
90
 
91
+ @property
92
+ def forwarded_headers(self) -> dict[str, str]:
93
+ """Return the forwarded headers for this agent."""
94
+ return self._forwarded_headers
95
+
89
96
  def litellm_api_base(self, deployment_id: str | None) -> str:
90
97
  return get_api_base(self.api_base, deployment_id)
91
98
 
@@ -139,6 +139,11 @@ def chat_entrypoint(
139
139
  completion_create_params["authorization_context"] = resolve_authorization_context(
140
140
  completion_create_params, **kwargs
141
141
  )
142
+ # Keep only allowed headers from the forwarded_headers.
143
+ incoming_headers = kwargs.get("headers", {}) or {}
144
+ allowed_headers = {"x-datarobot-api-token", "x-datarobot-api-key"}
145
+ forwarded_headers = {k: v for k, v in incoming_headers.items() if k.lower() in allowed_headers}
146
+ completion_create_params["forwarded_headers"] = forwarded_headers
142
147
 
143
148
  # Instantiate user agent with all supplied completion params including auth context
144
149
  agent = agent_cls(**completion_create_params)
@@ -13,15 +13,20 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import json
16
+ import logging
16
17
  import re
18
+ from http import HTTPStatus
17
19
  from typing import Any
18
20
  from typing import Literal
19
21
 
22
+ import requests
20
23
  from datarobot.core.config import DataRobotAppFrameworkBaseSettings
21
24
  from pydantic import field_validator
22
25
 
23
26
  from datarobot_genai.core.utils.auth import AuthContextHeaderHandler
24
27
 
28
+ logger = logging.getLogger(__name__)
29
+
25
30
 
26
31
  class MCPConfig(DataRobotAppFrameworkBaseSettings):
27
32
  """Configuration for MCP server connection.
@@ -37,6 +42,8 @@ class MCPConfig(DataRobotAppFrameworkBaseSettings):
37
42
  datarobot_endpoint: str | None = None
38
43
  datarobot_api_token: str | None = None
39
44
  authorization_context: dict[str, Any] | None = None
45
+ forwarded_headers: dict[str, str] | None = None
46
+ mcp_server_port: int | None = None
40
47
 
41
48
  _auth_context_handler: AuthContextHeaderHandler | None = None
42
49
  _server_config: dict[str, Any] | None = None
@@ -47,17 +54,14 @@ class MCPConfig(DataRobotAppFrameworkBaseSettings):
47
54
  if value is None:
48
55
  return None
49
56
 
50
- if not isinstance(value, str):
51
- msg = "external_mcp_headers must be a JSON string"
52
- raise TypeError(msg)
53
-
54
57
  candidate = value.strip()
55
58
 
56
59
  try:
57
60
  json.loads(candidate)
58
- except json.JSONDecodeError as exc:
61
+ except json.JSONDecodeError:
59
62
  msg = "external_mcp_headers must be valid JSON"
60
- raise ValueError(msg) from exc
63
+ logger.warning(msg)
64
+ return None
61
65
 
62
66
  return candidate
63
67
 
@@ -67,15 +71,12 @@ class MCPConfig(DataRobotAppFrameworkBaseSettings):
67
71
  if value is None:
68
72
  return None
69
73
 
70
- if not isinstance(value, str):
71
- msg = "mcp_deployment_id must be a string"
72
- raise TypeError(msg)
73
-
74
74
  candidate = value.strip()
75
75
 
76
76
  if not re.fullmatch(r"[0-9a-fA-F]{24}", candidate):
77
77
  msg = "mcp_deployment_id must be a valid 24-character hex ID"
78
- raise ValueError(msg)
78
+ logger.warning(msg)
79
+ return None
79
80
 
80
81
  return candidate
81
82
 
@@ -110,6 +111,45 @@ class MCPConfig(DataRobotAppFrameworkBaseSettings):
110
111
  # Authorization context not available (e.g., in tests)
111
112
  return {}
112
113
 
114
+ def _build_authenticated_headers(self) -> dict[str, str]:
115
+ """Build headers for authenticated requests.
116
+
117
+ Returns
118
+ -------
119
+ Dictionary containing forwarded headers (if available) and authentication headers.
120
+ """
121
+ headers: dict[str, str] = {}
122
+ if self.forwarded_headers:
123
+ headers.update(self.forwarded_headers)
124
+ headers.update(self._authorization_bearer_header())
125
+ headers.update(self._authorization_context_header())
126
+ return headers
127
+
128
+ def _check_localhost_server(self, url: str, timeout: float = 2.0) -> bool:
129
+ """Check if MCP server is running on localhost.
130
+
131
+ Parameters
132
+ ----------
133
+ url : str
134
+ The URL to check.
135
+ timeout : float, optional
136
+ Request timeout in seconds (default: 2.0).
137
+
138
+ Returns
139
+ -------
140
+ bool
141
+ True if server is running and responding with OK status, False otherwise.
142
+ """
143
+ try:
144
+ response = requests.get(url, timeout=timeout)
145
+ return (
146
+ response.status_code == HTTPStatus.OK
147
+ and response.json().get("message") == "DataRobot MCP Server is running"
148
+ )
149
+ except requests.RequestException as e:
150
+ logger.debug(f"Failed to connect to MCP server at {url}: {e}")
151
+ return False
152
+
113
153
  def _build_server_config(self) -> dict[str, Any] | None:
114
154
  """
115
155
  Get MCP server configuration.
@@ -119,20 +159,7 @@ class MCPConfig(DataRobotAppFrameworkBaseSettings):
119
159
  Server configuration dict with url, transport, and optional headers,
120
160
  or None if not configured.
121
161
  """
122
- if self.external_mcp_url:
123
- # External MCP URL - no authentication needed
124
- if self.external_mcp_headers:
125
- headers = json.loads(self.external_mcp_headers)
126
- else:
127
- headers = {}
128
-
129
- config = {
130
- "url": self.external_mcp_url.rstrip("/"),
131
- "transport": self.external_mcp_transport,
132
- "headers": headers,
133
- }
134
- return config
135
- elif self.mcp_deployment_id:
162
+ if self.mcp_deployment_id:
136
163
  # DataRobot deployment ID - requires authentication
137
164
  if self.datarobot_endpoint is None:
138
165
  raise ValueError(
@@ -142,15 +169,15 @@ class MCPConfig(DataRobotAppFrameworkBaseSettings):
142
169
  raise ValueError(
143
170
  "When using a DataRobot hosted MCP deployment, datarobot_api_token must be set."
144
171
  )
172
+
145
173
  base_url = self.datarobot_endpoint.rstrip("/")
146
174
  if not base_url.endswith("/api/v2"):
147
- base_url = base_url + "/api/v2"
175
+ base_url = f"{base_url}/api/v2"
176
+
148
177
  url = f"{base_url}/deployments/{self.mcp_deployment_id}/directAccess/mcp"
178
+ headers = self._build_authenticated_headers()
149
179
 
150
- headers = {
151
- **self._authorization_bearer_header(),
152
- **self._authorization_context_header(),
153
- }
180
+ logger.info(f"Using DataRobot hosted MCP deployment: {url}")
154
181
 
155
182
  return {
156
183
  "url": url,
@@ -158,4 +185,34 @@ class MCPConfig(DataRobotAppFrameworkBaseSettings):
158
185
  "headers": headers,
159
186
  }
160
187
 
188
+ if self.external_mcp_url:
189
+ # External MCP URL - no authentication needed
190
+ headers = {}
191
+
192
+ # Merge external headers if provided
193
+ if self.external_mcp_headers:
194
+ external_headers = json.loads(self.external_mcp_headers)
195
+ headers.update(external_headers)
196
+
197
+ logger.info(f"Using external MCP URL: {self.external_mcp_url}")
198
+
199
+ return {
200
+ "url": self.external_mcp_url.rstrip("/"),
201
+ "transport": self.external_mcp_transport,
202
+ "headers": headers,
203
+ }
204
+
205
+ # No MCP configuration found; set up localhost if running locally
206
+ if self.mcp_server_port:
207
+ url = f"http://localhost:{self.mcp_server_port}"
208
+ if self._check_localhost_server(url):
209
+ headers = self._build_authenticated_headers()
210
+ logger.info(f"Using localhost MCP server: {url}")
211
+ return {
212
+ "url": f"{url}/mcp",
213
+ "transport": "streamable-http",
214
+ "headers": headers,
215
+ }
216
+ logger.warning(f"MCP server is not running or not responding at {url}")
217
+
161
218
  return None
@@ -80,6 +80,37 @@ class CrewAIAgent(BaseAgent[BaseTool], abc.ABC):
80
80
  """
81
81
  raise NotImplementedError
82
82
 
83
+ def _extract_pipeline_interactions(self) -> MultiTurnSample | None:
84
+ """Extract pipeline interactions from event listener if available."""
85
+ if not hasattr(self, "event_listener"):
86
+ return None
87
+ try:
88
+ listener = getattr(self, "event_listener", None)
89
+ messages = getattr(listener, "messages", None) if listener is not None else None
90
+ return create_pipeline_interactions_from_messages(messages)
91
+ except Exception:
92
+ return None
93
+
94
+ def _extract_usage_metrics(self, crew_output: Any) -> UsageMetrics:
95
+ """Extract usage metrics from crew output."""
96
+ token_usage = getattr(crew_output, "token_usage", None)
97
+ if token_usage is not None:
98
+ return {
99
+ "completion_tokens": int(getattr(token_usage, "completion_tokens", 0)),
100
+ "prompt_tokens": int(getattr(token_usage, "prompt_tokens", 0)),
101
+ "total_tokens": int(getattr(token_usage, "total_tokens", 0)),
102
+ }
103
+ return default_usage_metrics()
104
+
105
+ def _process_crew_output(
106
+ self, crew_output: Any
107
+ ) -> tuple[str, MultiTurnSample | None, UsageMetrics]:
108
+ """Process crew output into response tuple."""
109
+ response_text = str(crew_output.raw)
110
+ pipeline_interactions = self._extract_pipeline_interactions()
111
+ usage_metrics = self._extract_usage_metrics(crew_output)
112
+ return response_text, pipeline_interactions, usage_metrics
113
+
83
114
  async def invoke(self, completion_create_params: CompletionCreateParams) -> InvokeReturn:
84
115
  """Run the CrewAI workflow with the provided completion parameters."""
85
116
  user_prompt_content = extract_user_prompt_content(completion_create_params)
@@ -93,6 +124,7 @@ class CrewAIAgent(BaseAgent[BaseTool], abc.ABC):
93
124
  # Use MCP context manager to handle connection lifecycle
94
125
  with mcp_tools_context(
95
126
  authorization_context=self._authorization_context,
127
+ forwarded_headers=self.forwarded_headers,
96
128
  ) as mcp_tools:
97
129
  # Set MCP tools for all agents; if MCP is not configured, this is effectively a no-op
98
130
  self.set_mcp_tools(mcp_tools)
@@ -115,64 +147,13 @@ class CrewAIAgent(BaseAgent[BaseTool], abc.ABC):
115
147
  async def _gen() -> AsyncGenerator[
116
148
  tuple[str, MultiTurnSample | None, UsageMetrics]
117
149
  ]:
118
- # Run kickoff in a worker thread.
119
150
  crew_output = await asyncio.to_thread(
120
151
  crew.kickoff,
121
152
  inputs=self.make_kickoff_inputs(user_prompt_content),
122
153
  )
123
-
124
- pipeline_interactions = None
125
- if hasattr(self, "event_listener"):
126
- try:
127
- listener = getattr(self, "event_listener", None)
128
- messages = (
129
- getattr(listener, "messages", None)
130
- if listener is not None
131
- else None
132
- )
133
- pipeline_interactions = create_pipeline_interactions_from_messages(
134
- messages
135
- )
136
- except Exception:
137
- pipeline_interactions = None
138
-
139
- token_usage = getattr(crew_output, "token_usage", None)
140
- if token_usage is not None:
141
- usage_metrics: UsageMetrics = {
142
- "completion_tokens": int(getattr(token_usage, "completion_tokens", 0)),
143
- "prompt_tokens": int(getattr(token_usage, "prompt_tokens", 0)),
144
- "total_tokens": int(getattr(token_usage, "total_tokens", 0)),
145
- }
146
- else:
147
- usage_metrics = default_usage_metrics()
148
-
149
- # Finalize stream with empty chunk carrying interactions and usage
150
- yield "", pipeline_interactions, usage_metrics
154
+ yield self._process_crew_output(crew_output)
151
155
 
152
156
  return _gen()
153
157
 
154
- # Non-streaming: run to completion and return final result
155
158
  crew_output = crew.kickoff(inputs=self.make_kickoff_inputs(user_prompt_content))
156
-
157
- response_text = str(crew_output.raw)
158
-
159
- pipeline_interactions = None
160
- if hasattr(self, "event_listener"):
161
- try:
162
- listener = getattr(self, "event_listener", None)
163
- messages = getattr(listener, "messages", None) if listener is not None else None
164
- pipeline_interactions = create_pipeline_interactions_from_messages(messages)
165
- except Exception:
166
- pipeline_interactions = None
167
-
168
- token_usage = getattr(crew_output, "token_usage", None)
169
- if token_usage is not None:
170
- usage_metrics: UsageMetrics = {
171
- "completion_tokens": int(getattr(token_usage, "completion_tokens", 0)),
172
- "prompt_tokens": int(getattr(token_usage, "prompt_tokens", 0)),
173
- "total_tokens": int(getattr(token_usage, "total_tokens", 0)),
174
- }
175
- else:
176
- usage_metrics = default_usage_metrics()
177
-
178
- return response_text, pipeline_interactions, usage_metrics
159
+ return self._process_crew_output(crew_output)
@@ -30,9 +30,13 @@ from datarobot_genai.core.mcp.common import MCPConfig
30
30
  @contextmanager
31
31
  def mcp_tools_context(
32
32
  authorization_context: dict[str, Any] | None = None,
33
+ forwarded_headers: dict[str, str] | None = None,
33
34
  ) -> Generator[list[Any], None, None]:
34
35
  """Context manager for MCP tools that handles connection lifecycle."""
35
- config = MCPConfig(authorization_context=authorization_context)
36
+ config = MCPConfig(
37
+ authorization_context=authorization_context,
38
+ forwarded_headers=forwarded_headers,
39
+ )
36
40
  # If no MCP server configured, return empty tools list
37
41
  if not config.server_config:
38
42
  print("No MCP server configured, using empty tools list", flush=True)
@@ -184,13 +184,17 @@ class DataRobotMCPServer:
184
184
  prompts = asyncio.run(self._mcp._list_prompts_mcp())
185
185
  resources = asyncio.run(self._mcp._list_resources_mcp())
186
186
 
187
- self._logger.info(f"Registered tools: {len(tools)}")
187
+ tools_count = len(tools)
188
+ prompts_count = len(prompts)
189
+ resources_count = len(resources)
190
+
191
+ self._logger.info(f"Registered tools: {tools_count}")
188
192
  for tool in tools:
189
193
  self._logger.info(f" > {tool.name}")
190
- self._logger.info(f"Registered prompts: {len(prompts)}")
194
+ self._logger.info(f"Registered prompts: {prompts_count}")
191
195
  for prompt in prompts:
192
196
  self._logger.info(f" > {prompt.name}")
193
- self._logger.info(f"Registered resources: {len(resources)}")
197
+ self._logger.info(f"Registered resources: {resources_count}")
194
198
  for resource in resources:
195
199
  self._logger.info(f" > {resource.name}")
196
200
 
@@ -209,6 +213,9 @@ class DataRobotMCPServer:
209
213
  self._mcp,
210
214
  self._mcp_transport,
211
215
  port=self._config.mcp_server_port,
216
+ tools_count=tools_count,
217
+ prompts_count=prompts_count,
218
+ resources_count=resources_count,
212
219
  )
213
220
 
214
221
  if self._mcp_transport == "stdio":
@@ -38,7 +38,7 @@ def _apply_green(text: str) -> str:
38
38
  return "\n".join(colored_lines)
39
39
 
40
40
 
41
- DR_LOGO_ASCII = _apply_green("""\
41
+ DR_LOGO_ASCII = _apply_green(r"""
42
42
  ____ _ ____ _ _
43
43
  | _ \ __ _| |_ __ _| _ \ ___ | |__ ___ | |_
44
44
  | | | |/ _` | __/ _` | |_) / _ \| '_ \ / _ \| __|
@@ -54,6 +54,9 @@ def log_server_custom_banner(
54
54
  host: str | None = None,
55
55
  port: int | None = None,
56
56
  path: str | None = None,
57
+ tools_count: int | None = None,
58
+ prompts_count: int | None = None,
59
+ resources_count: int | None = None,
57
60
  ) -> None:
58
61
  """
59
62
  Create and log a formatted banner with server information and logo.
@@ -64,13 +67,20 @@ def log_server_custom_banner(
64
67
  host: Host address (for HTTP transports)
65
68
  port: Port number (for HTTP transports)
66
69
  path: Server path (for HTTP transports)
70
+ tools_count: Number of tools registered
71
+ prompts_count: Number of prompts registered
72
+ resources_count: Number of resources registered
67
73
  """
68
74
  # Create the logo text
69
75
  # Use Text with no_wrap and markup disabled to preserve ANSI escape codes
70
76
  logo_text = Text.from_ansi(DR_LOGO_ASCII, no_wrap=True)
71
77
 
72
78
  # Create the main title
73
- title_text = Text(f"DataRobot MCP Server {datarobot_genai_version}", style="bold green")
79
+ title_text = Text(f"DataRobot MCP Server {datarobot_genai_version}", style="dim green")
80
+ stats_text = Text(
81
+ f"{tools_count} tools, {prompts_count} prompts, {resources_count} resources",
82
+ style="bold green",
83
+ )
74
84
 
75
85
  # Create the information table
76
86
  info_table = Table.grid(padding=(0, 1))
@@ -107,6 +117,7 @@ def log_server_custom_banner(
107
117
  Align.center(logo_text),
108
118
  "",
109
119
  Align.center(title_text),
120
+ Align.center(stats_text),
110
121
  "",
111
122
  "",
112
123
  Align.center(info_table),
@@ -18,6 +18,8 @@ from fastmcp.prompts.prompt import Prompt
18
18
 
19
19
  from datarobot_genai.drmcp.core.dynamic_prompts.dr_lib import get_datarobot_prompt_template
20
20
  from datarobot_genai.drmcp.core.dynamic_prompts.dr_lib import get_datarobot_prompt_template_version
21
+ from datarobot_genai.drmcp.core.dynamic_prompts.dr_lib import get_datarobot_prompt_template_versions
22
+ from datarobot_genai.drmcp.core.dynamic_prompts.dr_lib import get_datarobot_prompt_templates
21
23
  from datarobot_genai.drmcp.core.dynamic_prompts.register import (
22
24
  register_prompt_from_datarobot_prompt_management,
23
25
  )
@@ -83,3 +85,46 @@ async def delete_registered_prompt_template(prompt_template_id: str) -> bool:
83
85
  f"version {prompt_template_version_id}"
84
86
  )
85
87
  return True
88
+
89
+
90
+ async def refresh_registered_prompt_template() -> None:
91
+ """Refresh all registered prompt templates in the MCP instance."""
92
+ prompt_templates = get_datarobot_prompt_templates()
93
+ prompt_templates_ids = {p.id for p in prompt_templates}
94
+ prompt_templates_versions = get_datarobot_prompt_template_versions(list(prompt_templates_ids))
95
+
96
+ mcp_prompt_templates_mappings = await mcp.get_prompt_mapping()
97
+
98
+ for prompt_template in prompt_templates:
99
+ prompt_template_versions = prompt_templates_versions.get(prompt_template.id)
100
+ if not prompt_template_versions:
101
+ continue
102
+
103
+ latest_version = max(prompt_template_versions, key=lambda v: v.version)
104
+
105
+ if prompt_template.id not in mcp_prompt_templates_mappings:
106
+ # New prompt template -> add
107
+ await register_prompt_from_datarobot_prompt_management(
108
+ prompt_template=prompt_template, prompt_template_version=latest_version
109
+ )
110
+ continue
111
+
112
+ mcp_prompt_template_version, mcp_prompt = mcp_prompt_templates_mappings[prompt_template.id]
113
+
114
+ if mcp_prompt_template_version != latest_version:
115
+ # Current version saved in MCP is not the latest one => update it
116
+ await register_prompt_from_datarobot_prompt_management(
117
+ prompt_template=prompt_template, prompt_template_version=latest_version
118
+ )
119
+ continue
120
+
121
+ # Else => mcp_prompt_template_version == latest_version
122
+ # For now it means nothing changed, as there's no possibility to edit a prompt template version.
123
+
124
+ for mcp_prompt_template_id, (
125
+ mcp_prompt_template_version_id,
126
+ _,
127
+ ) in mcp_prompt_templates_mappings.items():
128
+ if mcp_prompt_template_id not in prompt_templates_ids:
129
+ # We also need to delete prompt templates that are registered in MCP but no longer exist in DataRobot.
130
+ await mcp.remove_prompt_mapping(mcp_prompt_template_id, mcp_prompt_template_version_id)
@@ -51,7 +51,7 @@ async def list_tools_by_tags(tags: list[str] | None = None, match_all: bool = Fa
51
51
  -------
52
52
  A formatted string listing tools that match the tag criteria.
53
53
  """
54
- tools = await mcp._list_tools_mcp(tags=tags, match_all=match_all)
54
+ tools = await mcp.list_tools(tags=tags, match_all=match_all)
55
55
 
56
56
  if not tools:
57
57
  if tags:
@@ -95,7 +95,7 @@ async def get_tool_info_by_name(tool_name: str) -> str:
95
95
  -------
96
96
  A formatted string with detailed information about the tool.
97
97
  """
98
- all_tools = await mcp._list_tools_mcp()
98
+ all_tools = await mcp.list_tools()
99
99
 
100
100
  for tool in all_tools:
101
101
  if tool.name == tool_name:
@@ -19,6 +19,7 @@ from starlette.requests import Request
19
19
  from starlette.responses import JSONResponse
20
20
 
21
21
  from .dynamic_prompts.controllers import delete_registered_prompt_template
22
+ from .dynamic_prompts.controllers import refresh_registered_prompt_template
22
23
  from .dynamic_prompts.controllers import register_prompt_from_prompt_template_id_and_version
23
24
  from .dynamic_tools.deployment.controllers import delete_registered_tool_deployment
24
25
  from .dynamic_tools.deployment.controllers import get_registered_tool_deployments
@@ -418,6 +419,18 @@ def register_routes(mcp: TaggedFastMCP) -> None:
418
419
  )
419
420
  except Exception as e:
420
421
  return JSONResponse(
421
- status_code=HTTPStatus.BAD_REQUEST,
422
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
422
423
  content={"error": f"Failed to add prompt template: {str(e)}"},
423
424
  )
425
+
426
+ @mcp.custom_route(prefix_mount_path("/registeredPrompts"), methods=["PUT"])
427
+ async def refresh_prompt_templates(_: Request) -> JSONResponse:
428
+ """Refresh prompt templates."""
429
+ try:
430
+ await refresh_registered_prompt_template()
431
+ return JSONResponse(status_code=HTTPStatus.NO_CONTENT, content=None)
432
+ except Exception as e:
433
+ return JSONResponse(
434
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
435
+ content={"error": f"Failed to refresh prompt templates: {str(e)}"},
436
+ )
@@ -84,26 +84,45 @@ class LangGraphAgent(BaseAgent[BaseTool], abc.ABC):
84
84
  async def wrapped_generator() -> AsyncGenerator[
85
85
  tuple[str, Any | None, UsageMetrics], None
86
86
  ]:
87
- async with mcp_tools_context(
88
- authorization_context=self._authorization_context,
89
- ) as mcp_tools:
90
- self.set_mcp_tools(mcp_tools)
91
- result = await self._invoke(completion_create_params)
92
-
93
- # Yield all items from the result generator
94
- # The context will be closed when this generator is exhausted
95
- # Cast to async generator since we know stream=True means it's a generator
96
- result_generator = cast(
97
- AsyncGenerator[tuple[str, Any | None, UsageMetrics], None], result
98
- )
99
- async for item in result_generator:
100
- yield item
87
+ try:
88
+ async with mcp_tools_context(
89
+ authorization_context=self._authorization_context,
90
+ forwarded_headers=self.forwarded_headers,
91
+ ) as mcp_tools:
92
+ self.set_mcp_tools(mcp_tools)
93
+ result = await self._invoke(completion_create_params)
94
+
95
+ # Yield all items from the result generator
96
+ # The context will be closed when this generator is exhausted
97
+ # Cast to async generator since we know stream=True means it's a generator
98
+ result_generator = cast(
99
+ AsyncGenerator[tuple[str, Any | None, UsageMetrics], None], result
100
+ )
101
+ async for item in result_generator:
102
+ yield item
103
+ except RuntimeError as e:
104
+ error_message = str(e).lower()
105
+ if "different task" in error_message and "cancel scope" in error_message:
106
+ # Due to anyio task group constraints when consuming async generators
107
+ # across task boundaries, we cannot always clean up properly.
108
+ # The underlying HTTP client/connection pool should handle resource cleanup
109
+ # via timeouts and connection pooling, but this
110
+ # may lead to delayed resource release.
111
+ logger.debug(
112
+ "MCP context cleanup attempted in different task. "
113
+ "This is a limitation when consuming async generators "
114
+ "across task boundaries."
115
+ )
116
+ else:
117
+ # Re-raise if it's a different RuntimeError
118
+ raise
101
119
 
102
120
  return wrapped_generator()
103
121
  else:
104
122
  # For non-streaming, use async with directly
105
123
  async with mcp_tools_context(
106
124
  authorization_context=self._authorization_context,
125
+ forwarded_headers=self.forwarded_headers,
107
126
  ) as mcp_tools:
108
127
  self.set_mcp_tools(mcp_tools)
109
128
  result = await self._invoke(completion_create_params)
@@ -28,6 +28,7 @@ from datarobot_genai.core.mcp.common import MCPConfig
28
28
  @asynccontextmanager
29
29
  async def mcp_tools_context(
30
30
  authorization_context: dict[str, Any] | None = None,
31
+ forwarded_headers: dict[str, str] | None = None,
31
32
  ) -> AsyncGenerator[list[BaseTool], None]:
32
33
  """Yield a list of LangChain BaseTool instances loaded via MCP.
33
34
 
@@ -37,8 +38,13 @@ async def mcp_tools_context(
37
38
  ----------
38
39
  authorization_context : dict[str, Any] | None
39
40
  Authorization context to use for MCP connections
41
+ forwarded_headers : dict[str, str] | None
42
+ Forwarded headers, e.g. x-datarobot-api-key to use for MCP authentication
40
43
  """
41
- mcp_config = MCPConfig(authorization_context=authorization_context)
44
+ mcp_config = MCPConfig(
45
+ authorization_context=authorization_context,
46
+ forwarded_headers=forwarded_headers,
47
+ )
42
48
  server_config = mcp_config.server_config
43
49
 
44
50
  if not server_config:
@@ -84,6 +84,7 @@ class LlamaIndexAgent(BaseAgent[BaseTool], abc.ABC):
84
84
  # Load MCP tools (if configured) asynchronously before building workflow
85
85
  mcp_tools = await load_mcp_tools(
86
86
  authorization_context=self._authorization_context,
87
+ forwarded_headers=self.forwarded_headers,
87
88
  )
88
89
  self.set_mcp_tools(mcp_tools)
89
90
 
@@ -30,18 +30,23 @@ from datarobot_genai.core.mcp.common import MCPConfig
30
30
 
31
31
  async def load_mcp_tools(
32
32
  authorization_context: dict[str, Any] | None = None,
33
+ forwarded_headers: dict[str, str] | None = None,
33
34
  ) -> list[Any]:
34
35
  """
35
36
  Asynchronously load MCP tools for LlamaIndex.
36
37
 
37
38
  Args:
38
39
  authorization_context: Optional authorization context for MCP connections
40
+ forwarded_headers: Optional forwarded headers, e.g. x-datarobot-api-key for MCP auth
39
41
 
40
42
  Returns
41
43
  -------
42
44
  List of MCP tools, or empty list if no MCP configuration is present.
43
45
  """
44
- config = MCPConfig(authorization_context=authorization_context)
46
+ config = MCPConfig(
47
+ authorization_context=authorization_context,
48
+ forwarded_headers=forwarded_headers,
49
+ )
45
50
  server_params = config.server_config
46
51
 
47
52
  if not server_params:
@@ -23,6 +23,7 @@ from nat.builder.builder import Builder
23
23
  from nat.builder.framework_enum import LLMFrameworkEnum
24
24
  from nat.cli.register_workflow import register_llm_client
25
25
 
26
+ from ..nat.datarobot_llm_providers import DataRobotLLMComponentModelConfig
26
27
  from ..nat.datarobot_llm_providers import DataRobotLLMDeploymentModelConfig
27
28
  from ..nat.datarobot_llm_providers import DataRobotLLMGatewayModelConfig
28
29
  from ..nat.datarobot_llm_providers import DataRobotNIMModelConfig
@@ -75,6 +76,7 @@ async def datarobot_llm_gateway_langchain(
75
76
  config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
76
77
  config["base_url"] = config["base_url"] + "/genai/llmgw"
77
78
  config["stream_options"] = {"include_usage": True}
79
+ config["model"] = config["model"].removeprefix("datarobot/")
78
80
  yield DataRobotChatOpenAI(**config)
79
81
 
80
82
 
@@ -85,7 +87,8 @@ async def datarobot_llm_gateway_crewai(
85
87
  llm_config: DataRobotLLMGatewayModelConfig, builder: Builder
86
88
  ) -> AsyncGenerator[LLM]:
87
89
  config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
88
- config["model"] = "datarobot/" + config["model"]
90
+ if not config["model"].startswith("datarobot/"):
91
+ config["model"] = "datarobot/" + config["model"]
89
92
  config["base_url"] = config["base_url"].removesuffix("/api/v2")
90
93
  yield LLM(**config)
91
94
 
@@ -97,7 +100,8 @@ async def datarobot_llm_gateway_llamaindex(
97
100
  llm_config: DataRobotLLMGatewayModelConfig, builder: Builder
98
101
  ) -> AsyncGenerator[LLM]:
99
102
  config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
100
- config["model"] = "datarobot/" + config["model"]
103
+ if not config["model"].startswith("datarobot/"):
104
+ config["model"] = "datarobot/" + config["model"]
101
105
  config["api_base"] = config.pop("base_url").removesuffix("/api/v2")
102
106
  yield DataRobotLiteLLM(**config)
103
107
 
@@ -109,11 +113,12 @@ async def datarobot_llm_deployment_langchain(
109
113
  llm_config: DataRobotLLMDeploymentModelConfig, builder: Builder
110
114
  ) -> AsyncGenerator[ChatOpenAI]:
111
115
  config = llm_config.model_dump(
112
- exclude={"type", "thinking", "datarobot_endpoint", "llm_deployment_id"},
116
+ exclude={"type", "thinking"},
113
117
  by_alias=True,
114
118
  exclude_none=True,
115
119
  )
116
120
  config["stream_options"] = {"include_usage": True}
121
+ config["model"] = config["model"].removeprefix("datarobot/")
117
122
  yield DataRobotChatOpenAI(**config)
118
123
 
119
124
 
@@ -128,7 +133,8 @@ async def datarobot_llm_deployment_crewai(
128
133
  by_alias=True,
129
134
  exclude_none=True,
130
135
  )
131
- config["model"] = "datarobot/" + config["model"]
136
+ if not config["model"].startswith("datarobot/"):
137
+ config["model"] = "datarobot/" + config["model"]
132
138
  config["api_base"] = config.pop("base_url") + "/chat/completions"
133
139
  yield LLM(**config)
134
140
 
@@ -144,7 +150,8 @@ async def datarobot_llm_deployment_llamaindex(
144
150
  by_alias=True,
145
151
  exclude_none=True,
146
152
  )
147
- config["model"] = "datarobot/" + config["model"]
153
+ if not config["model"].startswith("datarobot/"):
154
+ config["model"] = "datarobot/" + config["model"]
148
155
  config["api_base"] = config.pop("base_url") + "/chat/completions"
149
156
  yield DataRobotLiteLLM(**config)
150
157
 
@@ -159,6 +166,7 @@ async def datarobot_nim_langchain(
159
166
  exclude_none=True,
160
167
  )
161
168
  config["stream_options"] = {"include_usage": True}
169
+ config["model"] = config["model"].removeprefix("datarobot/")
162
170
  yield DataRobotChatOpenAI(**config)
163
171
 
164
172
 
@@ -171,7 +179,8 @@ async def datarobot_nim_crewai(
171
179
  by_alias=True,
172
180
  exclude_none=True,
173
181
  )
174
- config["model"] = "datarobot/" + config["model"]
182
+ if not config["model"].startswith("datarobot/"):
183
+ config["model"] = "datarobot/" + config["model"]
175
184
  config["api_base"] = config.pop("base_url") + "/chat/completions"
176
185
  yield LLM(**config)
177
186
 
@@ -185,6 +194,56 @@ async def datarobot_nim_llamaindex(
185
194
  by_alias=True,
186
195
  exclude_none=True,
187
196
  )
188
- config["model"] = "datarobot/" + config["model"]
197
+ if not config["model"].startswith("datarobot/"):
198
+ config["model"] = "datarobot/" + config["model"]
189
199
  config["api_base"] = config.pop("base_url") + "/chat/completions"
190
200
  yield DataRobotLiteLLM(**config)
201
+
202
+
203
+ @register_llm_client(
204
+ config_type=DataRobotLLMComponentModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN
205
+ )
206
+ async def datarobot_llm_component_langchain(
207
+ llm_config: DataRobotLLMComponentModelConfig, builder: Builder
208
+ ) -> AsyncGenerator[ChatOpenAI]:
209
+ config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
210
+ if config["use_datarobot_llm_gateway"]:
211
+ config["base_url"] = config["base_url"] + "/genai/llmgw"
212
+ config["stream_options"] = {"include_usage": True}
213
+ config["model"] = config["model"].removeprefix("datarobot/")
214
+ config.pop("use_datarobot_llm_gateway")
215
+ yield DataRobotChatOpenAI(**config)
216
+
217
+
218
+ @register_llm_client(
219
+ config_type=DataRobotLLMComponentModelConfig, wrapper_type=LLMFrameworkEnum.CREWAI
220
+ )
221
+ async def datarobot_llm_component_crewai(
222
+ llm_config: DataRobotLLMComponentModelConfig, builder: Builder
223
+ ) -> AsyncGenerator[LLM]:
224
+ config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
225
+ if not config["model"].startswith("datarobot/"):
226
+ config["model"] = "datarobot/" + config["model"]
227
+ if config["use_datarobot_llm_gateway"]:
228
+ config["base_url"] = config["base_url"].removesuffix("/api/v2")
229
+ else:
230
+ config["api_base"] = config.pop("base_url") + "/chat/completions"
231
+ config.pop("use_datarobot_llm_gateway")
232
+ yield LLM(**config)
233
+
234
+
235
+ @register_llm_client(
236
+ config_type=DataRobotLLMComponentModelConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX
237
+ )
238
+ async def datarobot_llm_component_llamaindex(
239
+ llm_config: DataRobotLLMComponentModelConfig, builder: Builder
240
+ ) -> AsyncGenerator[LLM]:
241
+ config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
242
+ if not config["model"].startswith("datarobot/"):
243
+ config["model"] = "datarobot/" + config["model"]
244
+ if config["use_datarobot_llm_gateway"]:
245
+ config["api_base"] = config.pop("base_url").removesuffix("/api/v2")
246
+ else:
247
+ config["api_base"] = config.pop("base_url") + "/chat/completions"
248
+ config.pop("use_datarobot_llm_gateway")
249
+ yield DataRobotLiteLLM(**config)
@@ -32,11 +32,43 @@ class Config(DataRobotAppFrameworkBaseSettings):
32
32
  datarobot_api_token: str | None = None
33
33
  llm_deployment_id: str | None = None
34
34
  nim_deployment_id: str | None = None
35
+ use_datarobot_llm_gateway: bool = False
36
+ llm_default_model: str | None = None
35
37
 
36
38
 
37
39
  config = Config()
38
40
 
39
41
 
42
+ class DataRobotLLMComponentModelConfig(OpenAIModelConfig, name="datarobot-llm-component"): # type: ignore[call-arg]
43
+ """A DataRobot LLM provider to be used with an LLM client."""
44
+
45
+ api_key: str | None = Field(
46
+ default=config.datarobot_api_token, description="DataRobot API key."
47
+ )
48
+ base_url: str | None = Field(
49
+ default=config.datarobot_endpoint.rstrip("/")
50
+ if config.use_datarobot_llm_gateway
51
+ else config.datarobot_endpoint + f"/deployments/{config.llm_deployment_id}",
52
+ description="DataRobot LLM URL.",
53
+ )
54
+ model_name: str = Field(
55
+ validation_alias=AliasChoices("model_name", "model"),
56
+ serialization_alias="model",
57
+ description="The model name.",
58
+ default=config.llm_default_model or "datarobot-deployed-llm",
59
+ )
60
+ use_datarobot_llm_gateway: bool = config.use_datarobot_llm_gateway
61
+
62
+
63
+ @register_llm_provider(config_type=DataRobotLLMComponentModelConfig)
64
+ async def datarobot_llm_component(
65
+ config: DataRobotLLMComponentModelConfig, _builder: Builder
66
+ ) -> LLMProviderInfo:
67
+ yield LLMProviderInfo(
68
+ config=config, description="DataRobot LLM Component for use with an LLM client."
69
+ )
70
+
71
+
40
72
  class DataRobotLLMGatewayModelConfig(OpenAIModelConfig, name="datarobot-llm-gateway"): # type: ignore[call-arg]
41
73
  """A DataRobot LLM provider to be used with an LLM client."""
42
74