datarobot-genai 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. datarobot_genai/__init__.py +19 -0
  2. datarobot_genai/core/__init__.py +0 -0
  3. datarobot_genai/core/agents/__init__.py +43 -0
  4. datarobot_genai/core/agents/base.py +195 -0
  5. datarobot_genai/core/chat/__init__.py +19 -0
  6. datarobot_genai/core/chat/auth.py +146 -0
  7. datarobot_genai/core/chat/client.py +178 -0
  8. datarobot_genai/core/chat/responses.py +297 -0
  9. datarobot_genai/core/cli/__init__.py +18 -0
  10. datarobot_genai/core/cli/agent_environment.py +47 -0
  11. datarobot_genai/core/cli/agent_kernel.py +211 -0
  12. datarobot_genai/core/custom_model.py +141 -0
  13. datarobot_genai/core/mcp/__init__.py +0 -0
  14. datarobot_genai/core/mcp/common.py +218 -0
  15. datarobot_genai/core/telemetry_agent.py +126 -0
  16. datarobot_genai/core/utils/__init__.py +3 -0
  17. datarobot_genai/core/utils/auth.py +234 -0
  18. datarobot_genai/core/utils/urls.py +64 -0
  19. datarobot_genai/crewai/__init__.py +24 -0
  20. datarobot_genai/crewai/agent.py +42 -0
  21. datarobot_genai/crewai/base.py +159 -0
  22. datarobot_genai/crewai/events.py +117 -0
  23. datarobot_genai/crewai/mcp.py +59 -0
  24. datarobot_genai/drmcp/__init__.py +78 -0
  25. datarobot_genai/drmcp/core/__init__.py +13 -0
  26. datarobot_genai/drmcp/core/auth.py +165 -0
  27. datarobot_genai/drmcp/core/clients.py +180 -0
  28. datarobot_genai/drmcp/core/config.py +250 -0
  29. datarobot_genai/drmcp/core/config_utils.py +174 -0
  30. datarobot_genai/drmcp/core/constants.py +18 -0
  31. datarobot_genai/drmcp/core/credentials.py +190 -0
  32. datarobot_genai/drmcp/core/dr_mcp_server.py +316 -0
  33. datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
  34. datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
  35. datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
  36. datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +128 -0
  37. datarobot_genai/drmcp/core/dynamic_prompts/register.py +206 -0
  38. datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
  39. datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
  40. datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
  41. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
  42. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
  43. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
  44. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
  45. datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
  46. datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
  47. datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
  48. datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
  49. datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
  50. datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
  51. datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
  52. datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
  53. datarobot_genai/drmcp/core/exceptions.py +25 -0
  54. datarobot_genai/drmcp/core/logging.py +98 -0
  55. datarobot_genai/drmcp/core/mcp_instance.py +542 -0
  56. datarobot_genai/drmcp/core/mcp_server_tools.py +129 -0
  57. datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
  58. datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
  59. datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
  60. datarobot_genai/drmcp/core/routes.py +436 -0
  61. datarobot_genai/drmcp/core/routes_utils.py +30 -0
  62. datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
  63. datarobot_genai/drmcp/core/telemetry.py +424 -0
  64. datarobot_genai/drmcp/core/tool_filter.py +108 -0
  65. datarobot_genai/drmcp/core/utils.py +131 -0
  66. datarobot_genai/drmcp/server.py +19 -0
  67. datarobot_genai/drmcp/test_utils/__init__.py +13 -0
  68. datarobot_genai/drmcp/test_utils/integration_mcp_server.py +102 -0
  69. datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +96 -0
  70. datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +94 -0
  71. datarobot_genai/drmcp/test_utils/openai_llm_mcp_client.py +234 -0
  72. datarobot_genai/drmcp/test_utils/tool_base_ete.py +151 -0
  73. datarobot_genai/drmcp/test_utils/utils.py +91 -0
  74. datarobot_genai/drmcp/tools/__init__.py +14 -0
  75. datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
  76. datarobot_genai/drmcp/tools/predictive/data.py +97 -0
  77. datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
  78. datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
  79. datarobot_genai/drmcp/tools/predictive/model.py +148 -0
  80. datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
  81. datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
  82. datarobot_genai/drmcp/tools/predictive/project.py +72 -0
  83. datarobot_genai/drmcp/tools/predictive/training.py +651 -0
  84. datarobot_genai/langgraph/__init__.py +0 -0
  85. datarobot_genai/langgraph/agent.py +341 -0
  86. datarobot_genai/langgraph/mcp.py +73 -0
  87. datarobot_genai/llama_index/__init__.py +16 -0
  88. datarobot_genai/llama_index/agent.py +50 -0
  89. datarobot_genai/llama_index/base.py +299 -0
  90. datarobot_genai/llama_index/mcp.py +79 -0
  91. datarobot_genai/nat/__init__.py +0 -0
  92. datarobot_genai/nat/agent.py +258 -0
  93. datarobot_genai/nat/datarobot_llm_clients.py +249 -0
  94. datarobot_genai/nat/datarobot_llm_providers.py +130 -0
  95. datarobot_genai/py.typed +0 -0
  96. datarobot_genai-0.2.0.dist-info/METADATA +139 -0
  97. datarobot_genai-0.2.0.dist-info/RECORD +101 -0
  98. datarobot_genai-0.2.0.dist-info/WHEEL +4 -0
  99. datarobot_genai-0.2.0.dist-info/entry_points.txt +3 -0
  100. datarobot_genai-0.2.0.dist-info/licenses/AUTHORS +2 -0
  101. datarobot_genai-0.2.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,131 @@
1
+ # Copyright 2025 DataRobot, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import base64
15
+ import uuid
16
+ from typing import Any
17
+
18
+ import boto3
19
+ from fastmcp.resources import HttpResource
20
+ from fastmcp.tools.tool import ToolResult
21
+ from pydantic import BaseModel
22
+
23
+ from .constants import MAX_INLINE_SIZE
24
+ from .mcp_instance import mcp
25
+
26
+
27
def generate_presigned_url(bucket: str, key: str, expires_in: int = 2592000) -> str:
    """
    Create a presigned ``get_object`` URL for an S3 object.

    Args:
        bucket (str): S3 bucket name.
        key (str): S3 object key.
        expires_in (int): Lifetime of the URL in seconds (default 30 days).

    Returns
    -------
    str: Presigned S3 URL for get_object.
    """
    client = boto3.client("s3")
    url = client.generate_presigned_url(
        "get_object",
        Params={"Bucket": bucket, "Key": key},
        ExpiresIn=expires_in,
    )
    # boto3 returns a str already; the cast keeps the annotation honest.
    return str(url)
44
+
45
+
46
class PredictionResponse(BaseModel):
    """Result envelope produced by predictions_result_response.

    Exactly one of the two shapes is populated:
    inline results carry ``data``; resource-backed results carry
    ``resource_id`` and ``s3_url``.
    """

    # Either "inline" (CSV embedded in `data`) or "resource" (uploaded to S3).
    type: str
    # CSV payload; set only when type == "inline".
    data: str | None = None
    # URI of the registered MCP resource; set only when type == "resource".
    resource_id: str | None = None
    # Presigned S3 download URL; set only when type == "resource".
    s3_url: str | None = None
    # Echoes whether prediction explanations were requested by the caller.
    show_explanations: bool | None = None
52
+
53
+
54
def predictions_result_response(
    df: Any, bucket: str, key: str, resource_name: str, show_explanations: bool = False
) -> PredictionResponse:
    """Package a predictions DataFrame as an inline or resource-backed response.

    Results whose CSV encoding is smaller than MAX_INLINE_SIZE bytes are
    returned inline; larger ones are uploaded to S3 and registered as an
    MCP resource, and the response carries the resource id and presigned URL.
    """
    csv_payload = df.to_csv(index=False)
    fits_inline = len(csv_payload.encode("utf-8")) < MAX_INLINE_SIZE

    if fits_inline:
        return PredictionResponse(
            type="inline",
            data=csv_payload,
            show_explanations=show_explanations,
        )

    resource = save_df_to_s3_and_register_resource(df, bucket, key, resource_name)
    return PredictionResponse(
        type="resource",
        resource_id=str(resource.uri),
        s3_url=resource.url,
        show_explanations=show_explanations,
    )
68
+
69
+
70
def save_df_to_s3_and_register_resource(
    df: Any, bucket: str, key: str, resource_name: str, mime_type: str = "text/csv"
) -> HttpResource:
    """
    Save a DataFrame to a temp CSV, upload to S3, register as a resource, and return it.

    Args:
        df (pd.DataFrame): DataFrame to save and upload.
        bucket (str): S3 bucket name.
        key (str): S3 object key.
        resource_name (str): Name for the registered resource.
        mime_type (str): MIME type for the resource (default 'text/csv').

    Returns
    -------
    HttpResource: The registered resource; its ``url`` is a presigned S3 URL
    for the uploaded file.
    """
    import os
    import tempfile

    # Use a real temporary file instead of a hard-coded /tmp path (portable,
    # respects TMPDIR) and always remove it after the upload so repeated
    # calls do not leak disk space.
    tmp = tempfile.NamedTemporaryFile(suffix=".csv", delete=False)
    tmp.close()
    try:
        df.to_csv(tmp.name, index=False)
        s3 = boto3.client("s3")
        s3.upload_file(tmp.name, bucket, key)
    finally:
        os.unlink(tmp.name)

    s3_url = generate_presigned_url(bucket, key)
    resource = HttpResource(
        uri="predictions://" + uuid.uuid4().hex,  # type: ignore[arg-type]
        url=s3_url,
        name=resource_name,
        mime_type=mime_type,
    )
    mcp.add_resource(resource)
    return resource
100
+
101
+
102
def format_response_as_tool_result(data: bytes, content_type: str, charset: str) -> ToolResult:
    """Format the deployment response into a ToolResult.

    Using structured_content, to return as much information about
    the response as possible, for LLMs to correctly interpret the
    response.

    Args:
        data: Raw response body bytes.
        content_type: Response Content-Type (may be empty/None).
        charset: Character set of textual responses; defaults to utf-8.
    """
    charset = charset or "utf-8"
    content_type = content_type.lower() if content_type else ""

    if content_type.startswith("text/") or content_type == "application/json":
        payload = {
            "type": "text",
            "mime_type": content_type,
            "data": data.decode(charset),
        }
    else:
        # Base64 output is pure ASCII, so decode it as ASCII rather than with
        # the response charset: a non-ASCII-compatible charset (e.g. utf-16)
        # would garble or reject the encoded payload.
        kind = "image" if content_type.startswith("image/") else "binary"
        payload = {
            "type": kind,
            "mime_type": content_type,
            "data_base64": base64.b64encode(data).decode("ascii"),
        }

    return ToolResult(structured_content=payload)
@@ -0,0 +1,19 @@
1
+ # Copyright 2025 DataRobot, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
from datarobot_genai.drmcp import create_mcp_server


def main() -> None:
    """Build the DataRobot MCP server and run it with its startup banner."""
    mcp_server = create_mcp_server()
    mcp_server.run(show_banner=True)


if __name__ == "__main__":
    main()
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 DataRobot, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,102 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2025 DataRobot, Inc.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ Integration test MCP server.
19
+
20
+ This server works standalone (base tools only) or detects and loads
21
+ user modules if they exist in the project structure.
22
+ """
23
+
24
+ from pathlib import Path
25
+ from typing import Any
26
+
27
+ from datarobot_genai.drmcp import create_mcp_server
28
+
29
+ # Import user components (will be used conditionally)
30
+ try:
31
+ from app.core.server_lifecycle import ServerLifecycle # type: ignore # noqa: F401
32
+ from app.core.user_config import get_user_config # type: ignore # noqa: F401
33
+ from app.core.user_credentials import get_user_credentials # type: ignore # noqa: F401
34
+
35
+ except ImportError:
36
+ # These imports will fail when running from library without user modules
37
+ pass
38
+
39
+
40
def detect_user_modules() -> Any:
    """
    Detect if user modules exist in the project.

    Returns
    -------
    Tuple of (config_factory, credentials_factory, lifecycle, module_paths) or None
    """
    # Try to find the "app" directory:
    # - when run from the library, it won't be found (returns None);
    # - when run from a user project, app/core holds the user's config,
    #   credentials, and lifecycle modules.
    current_dir = Path.cwd()

    # Look for app in the current directory or up to two parent directories.
    for search_dir in (current_dir, current_dir.parent, current_dir.parent.parent):
        app_dir = search_dir / "app"
        if (app_dir / "core").exists():
            try:
                module_paths = [
                    (str(app_dir / "tools"), "app.tools"),
                    (str(app_dir / "prompts"), "app.prompts"),
                    (str(app_dir / "resources"), "app.resources"),
                ]

                # If the optional top-of-file imports of get_user_config /
                # get_user_credentials / ServerLifecycle failed, referencing
                # them here raises NameError (not ImportError), so catch both.
                return (
                    get_user_config,
                    get_user_credentials,
                    ServerLifecycle(),
                    module_paths,
                )
            except (ImportError, NameError):
                # User modules don't exist or can't be imported.
                pass

    return None
77
+
78
+
79
def main() -> None:
    """Run the integration test MCP server."""
    # Prefer user-provided config/credentials/lifecycle when a project
    # layout is detected; otherwise start with the base tools only.
    user_components = detect_user_modules()

    if user_components is None:
        server = create_mcp_server(transport="stdio")
    else:
        config_factory, credentials_factory, lifecycle, module_paths = user_components
        server = create_mcp_server(
            config_factory=config_factory,
            credentials_factory=credentials_factory,
            lifecycle=lifecycle,
            additional_module_paths=module_paths,
            transport="stdio",
        )

    server.run()


if __name__ == "__main__":
    main()
@@ -0,0 +1,96 @@
1
+ # Copyright 2025 DataRobot, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import asyncio
15
+ import os
16
+ from collections.abc import AsyncGenerator
17
+ from contextlib import asynccontextmanager
18
+
19
+ from mcp import ClientSession
20
+ from mcp.client.streamable_http import streamablehttp_client
21
+
22
+ from .utils import load_env
23
+
24
+ load_env()
25
+
26
+
27
+ def get_dr_mcp_server_url() -> str | None:
28
+ """Get DataRobot MCP server URL."""
29
+ return os.environ.get("DR_MCP_SERVER_URL")
30
+
31
+
32
def get_openai_llm_client_config() -> dict[str, str]:
    """Build the OpenAI / Azure OpenAI client configuration from environment variables.

    Raises
    ------
    ValueError
        If OPENAI_API_KEY is missing, or if OPENAI_API_BASE is set (Azure)
        without OPENAI_API_DEPLOYMENT_ID.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    api_base = os.environ.get("OPENAI_API_BASE")
    deployment_id = os.environ.get("OPENAI_API_DEPLOYMENT_ID")
    api_version = os.environ.get("OPENAI_API_VERSION")
    save_responses = os.environ.get("SAVE_LLM_RESPONSES", "false").lower() == "true"

    # Validate required configuration up front.
    if not api_key:
        raise ValueError("Missing required environment variable: OPENAI_API_KEY")
    # OPENAI_API_BASE signals Azure OpenAI, which also needs a deployment id.
    if api_base and not deployment_id:
        raise ValueError("Missing required environment variable: OPENAI_API_DEPLOYMENT_ID")

    config: dict[str, str] = {"openai_api_key": api_key}

    # Only include optional keys that are actually set.
    for name, value in (
        ("openai_api_base", api_base),
        ("openai_api_deployment_id", deployment_id),
        ("openai_api_version", api_version),
    ):
        if value:
            config[name] = value
    config["save_llm_responses"] = str(save_responses)

    return config
61
+
62
+
63
def get_headers() -> dict[str, str]:
    """Build auth headers for a DataRobot-deployed MCP server.

    When the MCP server is deployed in DataRobot, the API token has to be
    included in the headers for authentication.
    """
    token = os.getenv("DATAROBOT_API_TOKEN")
    return {"Authorization": f"Bearer {token}"}
69
+
70
+
71
@asynccontextmanager
async def ete_test_mcp_session(
    additional_headers: dict[str, str] | None = None,
) -> AsyncGenerator[ClientSession, None]:
    """Create an MCP session for each test.

    Connects to the server given by DR_MCP_SERVER_URL over streamable HTTP,
    sending the DataRobot bearer-token headers, and yields an initialized
    ClientSession.

    Parameters
    ----------
    additional_headers : dict[str, str], optional
        Additional headers to include in the MCP session (e.g., auth headers for testing).

    Raises
    ------
    TimeoutError
        If the initialize handshake does not complete within 5 seconds.
    """
    try:
        headers = get_headers()
        if additional_headers:
            headers.update(additional_headers)

        # The transport yields (read, write, get_session_id); the session-id
        # callback is unused here.
        async with streamablehttp_client(url=get_dr_mcp_server_url(), headers=headers) as (
            read_stream,
            write_stream,
            _,
        ):
            async with ClientSession(read_stream, write_stream) as session:
                # Fail fast if the server never answers the handshake.
                await asyncio.wait_for(session.initialize(), timeout=5)
                yield session
    except asyncio.TimeoutError:
        # NOTE(review): the try spans the yield, so an asyncio.TimeoutError
        # raised inside the test body is also converted here, not only a
        # handshake timeout — confirm this is intended.
        raise TimeoutError(f"Check if the MCP server is running at {get_dr_mcp_server_url()}")
@@ -0,0 +1,94 @@
1
+ # Copyright 2025 DataRobot, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import contextlib
17
+ import os
18
+ from collections.abc import AsyncGenerator
19
+ from pathlib import Path
20
+
21
+ from mcp import ClientSession
22
+ from mcp.client.stdio import StdioServerParameters
23
+ from mcp.client.stdio import stdio_client
24
+
25
+ from .utils import load_env
26
+
27
+ load_env()
28
+
29
+
30
def integration_test_mcp_server_params() -> StdioServerParameters:
    """Build stdio launch parameters for the integration test MCP server.

    Environment variables override the defaults; unset or empty variables
    fall back to quiet, network-free test settings.
    """

    def env_or(name: str, fallback: str) -> str:
        # `or` (rather than a .get default) so empty strings also fall back.
        return os.environ.get(name) or fallback

    env = {
        "DATAROBOT_API_TOKEN": env_or("DATAROBOT_API_TOKEN", "test-token"),
        "DATAROBOT_ENDPOINT": env_or("DATAROBOT_ENDPOINT", "https://test.datarobot.com/api/v2"),
        "MCP_SERVER_LOG_LEVEL": env_or("MCP_SERVER_LOG_LEVEL", "WARNING"),
        "APP_LOG_LEVEL": env_or("APP_LOG_LEVEL", "WARNING"),
        "OTEL_ENABLED": env_or("OTEL_ENABLED", "false"),
        "MCP_SERVER_REGISTER_DYNAMIC_TOOLS_ON_STARTUP": env_or(
            "MCP_SERVER_REGISTER_DYNAMIC_TOOLS_ON_STARTUP", "false"
        ),
        "MCP_SERVER_REGISTER_DYNAMIC_PROMPTS_ON_STARTUP": env_or(
            "MCP_SERVER_REGISTER_DYNAMIC_PROMPTS_ON_STARTUP", "true"
        ),
    }

    script_dir = Path(__file__).resolve().parent
    server_script = str(script_dir / "integration_mcp_server.py")
    # Put src/ on PYTHONPATH so the spawned server can import datarobot_genai.
    src_dir = script_dir.parent.parent.parent

    return StdioServerParameters(
        command="uv",
        args=["run", server_script],
        env={
            "PYTHONPATH": str(src_dir),
            "MCP_SERVER_NAME": "integration",
            "MCP_SERVER_PORT": "8081",
            **env,
        },
    )
63
+
64
+
65
@contextlib.asynccontextmanager
async def integration_test_mcp_session(
    server_params: StdioServerParameters | None = None, timeout: int = 30
) -> AsyncGenerator[ClientSession, None]:
    """
    Create and connect a client for the MCP server as a context manager.

    Args:
        server_params: Parameters for configuring the server connection;
            defaults to integration_test_mcp_server_params()
        timeout: Timeout in seconds for session initialization

    Yields
    ------
    ClientSession: Connected MCP client session

    Raises
    ------
    ConnectionError: If session initialization fails
    TimeoutError: If session initialization exceeds timeout
    """
    server_params = server_params or integration_test_mcp_server_params()

    try:
        # Spawn the server as a stdio subprocess and talk to it over its
        # stdin/stdout streams.
        async with stdio_client(server_params) as (read_stream, write_stream):
            async with ClientSession(read_stream, write_stream) as session:
                # Bound the initialize handshake so a hung server fails fast.
                await asyncio.wait_for(session.initialize(), timeout=timeout)
                yield session

    except asyncio.TimeoutError:
        # NOTE(review): the try spans the yield, so asyncio.TimeoutError from
        # the caller's body is converted too, not only handshake timeouts.
        raise TimeoutError(f"Session initialization timed out after {timeout} seconds")
@@ -0,0 +1,234 @@
1
+ # Copyright 2025 DataRobot, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ from typing import Any
17
+
18
+ import openai
19
+ from mcp import ClientSession
20
+ from mcp.types import CallToolResult
21
+ from mcp.types import ListToolsResult
22
+ from mcp.types import TextContent
23
+ from openai.types.chat.chat_completion import ChatCompletion
24
+
25
+ from .utils import save_response_to_file
26
+
27
+
28
class ToolCall:
    """Record of a single tool invocation: name, arguments, and why it was chosen."""

    def __init__(self, tool_name: str, parameters: dict[str, Any], reasoning: str):
        """Store the tool's name, its call arguments, and the selection rationale."""
        self.tool_name = tool_name
        self.parameters = parameters
        self.reasoning = reasoning
35
+
36
+
37
+ class LLMResponse:
38
+ """Represents an LLM response with content and tool calls."""
39
+
40
+ def __init__(self, content: str, tool_calls: list[ToolCall], tool_results: list[str]):
41
+ self.content = content
42
+ self.tool_calls = tool_calls
43
+ self.tool_results = tool_results
44
+
45
+
46
class LLMMCPClient:
    """Client for interacting with LLMs via MCP.

    Drives an OpenAI (or Azure OpenAI) chat-completion loop in which the
    model may call tools exposed by an MCP server; tool calls are executed
    against the MCP session and their results fed back to the model until a
    final text answer is produced.
    """

    def __init__(self, config: str):
        """Initialize the LLM MCP client.

        Parameters
        ----------
        config : str
            A dict — or its string representation — with keys such as
            ``openai_api_key``, ``openai_api_base``,
            ``openai_api_deployment_id``, ``openai_api_version``, ``model``,
            and ``save_llm_responses``.
        """
        import ast

        # Parse config string to extract parameters. Use ast.literal_eval
        # instead of eval: the config may originate outside the process and
        # must not be able to execute arbitrary code.
        config_dict = ast.literal_eval(config) if isinstance(config, str) else config

        openai_api_key = config_dict.get("openai_api_key")
        openai_api_base = config_dict.get("openai_api_base")
        openai_api_deployment_id = config_dict.get("openai_api_deployment_id")
        model = config_dict.get("model", "gpt-3.5-turbo")
        save_llm_responses = config_dict.get("save_llm_responses", True)
        # get_openai_llm_client_config serializes this flag as "True"/"False";
        # any non-empty string is truthy, so convert explicitly.
        if isinstance(save_llm_responses, str):
            save_llm_responses = save_llm_responses.strip().lower() == "true"

        if openai_api_base and openai_api_deployment_id:
            # Azure OpenAI
            self.openai_client = openai.AzureOpenAI(
                api_key=openai_api_key,
                azure_endpoint=openai_api_base,
                api_version=config_dict.get("openai_api_version", "2024-02-15-preview"),
            )
            self.model = openai_api_deployment_id
        else:
            # Regular OpenAI
            self.openai_client = openai.OpenAI(api_key=openai_api_key)  # type: ignore[assignment]
            self.model = model

        self.save_llm_responses = save_llm_responses
        self.available_tools: list[dict[str, Any]] = []
        self.available_prompts: list[dict[str, Any]] = []
        self.available_resources: list[dict[str, Any]] = []

    async def _add_mcp_tool_to_available_tools(self, mcp_session: ClientSession) -> None:
        """Refresh self.available_tools with the MCP server's tool list in
        OpenAI function-calling format."""
        tools_result: ListToolsResult = await mcp_session.list_tools()
        self.available_tools = [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.inputSchema,
                },
            }
            for tool in tools_result.tools
        ]

    async def _call_mcp_tool(
        self, tool_name: str, parameters: dict[str, Any], mcp_session: ClientSession
    ) -> str:
        """Call an MCP tool and return the result as a string."""
        result: CallToolResult = await mcp_session.call_tool(tool_name, parameters)
        # Prefer the first text content block; otherwise stringify everything.
        return (
            result.content[0].text
            if result.content and isinstance(result.content[0], TextContent)
            else str(result.content)
        )

    async def _process_tool_calls(
        self,
        response: ChatCompletion,
        messages: list[Any],
        mcp_session: ClientSession,
    ) -> tuple[list[ToolCall], list[str]]:
        """Execute every tool call in *response*, appending the assistant
        message and per-tool result messages to *messages*.

        Returns the ToolCall records and the raw tool result strings.
        """
        tool_calls = []
        tool_results = []

        if response.choices[0].message.tool_calls:
            messages.append(response.choices[0].message)  # Add assistant's message with tool calls

            for tool_call in response.choices[0].message.tool_calls:
                tool_name = tool_call.function.name  # type: ignore[union-attr]
                parameters = json.loads(tool_call.function.arguments)  # type: ignore[union-attr]

                tool_calls.append(
                    ToolCall(
                        tool_name=tool_name,
                        parameters=parameters,
                        reasoning="Tool selected by LLM",
                    )
                )

                try:
                    result_text = await self._call_mcp_tool(tool_name, parameters, mcp_session)
                    tool_results.append(result_text)
                    messages.append(
                        {
                            "role": "tool",
                            "content": result_text,
                            "tool_call_id": tool_call.id,
                            "name": tool_name,
                        }
                    )
                except Exception as e:
                    # Report the failure back to the model instead of aborting
                    # the whole conversation.
                    error_msg = f"Error calling {tool_name}: {str(e)}"
                    tool_results.append(error_msg)
                    messages.append(
                        {
                            "role": "tool",
                            "content": error_msg,
                            "tool_call_id": tool_call.id,
                            "name": tool_name,
                        }
                    )

        return tool_calls, tool_results

    async def _get_llm_response(
        self, messages: list[dict[str, Any]], allow_tool_calls: bool = True
    ) -> Any:
        """Get a chat completion, optionally exposing the MCP tools to the model."""
        kwargs = {
            "model": self.model,
            "messages": messages,
        }

        if allow_tool_calls and self.available_tools:
            kwargs["tools"] = self.available_tools
            kwargs["tool_choice"] = "auto"

        return self.openai_client.chat.completions.create(**kwargs)

    async def process_prompt_with_mcp_support(
        self, prompt: str, mcp_session: ClientSession, output_file_name: str = ""
    ) -> LLMResponse:
        """Process a prompt with MCP tool support.

        Loops: request a completion; if it contains tool calls, execute them
        and feed the results back; otherwise the completion's content is the
        final answer.
        """
        # Add MCP tools to available tools.
        await self._add_mcp_tool_to_available_tools(mcp_session)

        if output_file_name:
            print(f"Processing prompt for test: {output_file_name}")

        # Initialize conversation.
        messages = [
            {
                "role": "system",
                "content": (
                    "You are a helpful AI assistant that can use tools to help users. "
                    "If you need more information to provide a complete response, you can make "
                    "multiple tool calls. When dealing with file paths, use them as raw paths "
                    "without converting to file:// URLs."
                ),
            },
            {"role": "user", "content": prompt},
        ]

        all_tool_calls = []
        all_tool_results = []

        # One completion per iteration: previously a second completion was
        # requested after processing tool calls and, if it also contained
        # tool calls, they were silently dropped and a third completion was
        # requested with no tool results in the transcript.
        while True:
            response = await self._get_llm_response(messages)

            if not response.choices[0].message.tool_calls:
                # No tool calls: this is the final answer. content may be
                # None for some responses, so fall back to an empty string.
                final_response = response.choices[0].message.content or ""
                break

            tool_calls, tool_results = await self._process_tool_calls(
                response, messages, mcp_session
            )
            all_tool_calls.extend(tool_calls)
            all_tool_results.extend(tool_results)

        clean_content = final_response.replace("*", "").lower()

        llm_response = LLMResponse(
            content=clean_content,
            tool_calls=all_tool_calls,
            tool_results=all_tool_results,
        )

        if self.save_llm_responses:
            save_response_to_file(llm_response, name=output_file_name)

        return llm_response