datarobot-genai 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_genai/__init__.py +19 -0
- datarobot_genai/core/__init__.py +0 -0
- datarobot_genai/core/agents/__init__.py +43 -0
- datarobot_genai/core/agents/base.py +195 -0
- datarobot_genai/core/chat/__init__.py +19 -0
- datarobot_genai/core/chat/auth.py +146 -0
- datarobot_genai/core/chat/client.py +178 -0
- datarobot_genai/core/chat/responses.py +297 -0
- datarobot_genai/core/cli/__init__.py +18 -0
- datarobot_genai/core/cli/agent_environment.py +47 -0
- datarobot_genai/core/cli/agent_kernel.py +211 -0
- datarobot_genai/core/custom_model.py +141 -0
- datarobot_genai/core/mcp/__init__.py +0 -0
- datarobot_genai/core/mcp/common.py +218 -0
- datarobot_genai/core/telemetry_agent.py +126 -0
- datarobot_genai/core/utils/__init__.py +3 -0
- datarobot_genai/core/utils/auth.py +234 -0
- datarobot_genai/core/utils/urls.py +64 -0
- datarobot_genai/crewai/__init__.py +24 -0
- datarobot_genai/crewai/agent.py +42 -0
- datarobot_genai/crewai/base.py +159 -0
- datarobot_genai/crewai/events.py +117 -0
- datarobot_genai/crewai/mcp.py +59 -0
- datarobot_genai/drmcp/__init__.py +78 -0
- datarobot_genai/drmcp/core/__init__.py +13 -0
- datarobot_genai/drmcp/core/auth.py +165 -0
- datarobot_genai/drmcp/core/clients.py +180 -0
- datarobot_genai/drmcp/core/config.py +250 -0
- datarobot_genai/drmcp/core/config_utils.py +174 -0
- datarobot_genai/drmcp/core/constants.py +18 -0
- datarobot_genai/drmcp/core/credentials.py +190 -0
- datarobot_genai/drmcp/core/dr_mcp_server.py +316 -0
- datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
- datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
- datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
- datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +128 -0
- datarobot_genai/drmcp/core/dynamic_prompts/register.py +206 -0
- datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
- datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
- datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
- datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
- datarobot_genai/drmcp/core/exceptions.py +25 -0
- datarobot_genai/drmcp/core/logging.py +98 -0
- datarobot_genai/drmcp/core/mcp_instance.py +542 -0
- datarobot_genai/drmcp/core/mcp_server_tools.py +129 -0
- datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
- datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
- datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
- datarobot_genai/drmcp/core/routes.py +436 -0
- datarobot_genai/drmcp/core/routes_utils.py +30 -0
- datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
- datarobot_genai/drmcp/core/telemetry.py +424 -0
- datarobot_genai/drmcp/core/tool_filter.py +108 -0
- datarobot_genai/drmcp/core/utils.py +131 -0
- datarobot_genai/drmcp/server.py +19 -0
- datarobot_genai/drmcp/test_utils/__init__.py +13 -0
- datarobot_genai/drmcp/test_utils/integration_mcp_server.py +102 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +96 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +94 -0
- datarobot_genai/drmcp/test_utils/openai_llm_mcp_client.py +234 -0
- datarobot_genai/drmcp/test_utils/tool_base_ete.py +151 -0
- datarobot_genai/drmcp/test_utils/utils.py +91 -0
- datarobot_genai/drmcp/tools/__init__.py +14 -0
- datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
- datarobot_genai/drmcp/tools/predictive/data.py +97 -0
- datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
- datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
- datarobot_genai/drmcp/tools/predictive/model.py +148 -0
- datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
- datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
- datarobot_genai/drmcp/tools/predictive/project.py +72 -0
- datarobot_genai/drmcp/tools/predictive/training.py +651 -0
- datarobot_genai/langgraph/__init__.py +0 -0
- datarobot_genai/langgraph/agent.py +341 -0
- datarobot_genai/langgraph/mcp.py +73 -0
- datarobot_genai/llama_index/__init__.py +16 -0
- datarobot_genai/llama_index/agent.py +50 -0
- datarobot_genai/llama_index/base.py +299 -0
- datarobot_genai/llama_index/mcp.py +79 -0
- datarobot_genai/nat/__init__.py +0 -0
- datarobot_genai/nat/agent.py +258 -0
- datarobot_genai/nat/datarobot_llm_clients.py +249 -0
- datarobot_genai/nat/datarobot_llm_providers.py +130 -0
- datarobot_genai/py.typed +0 -0
- datarobot_genai-0.2.0.dist-info/METADATA +139 -0
- datarobot_genai-0.2.0.dist-info/RECORD +101 -0
- datarobot_genai-0.2.0.dist-info/WHEEL +4 -0
- datarobot_genai-0.2.0.dist-info/entry_points.txt +3 -0
- datarobot_genai-0.2.0.dist-info/licenses/AUTHORS +2 -0
- datarobot_genai-0.2.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import base64
|
|
15
|
+
import uuid
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
import boto3
|
|
19
|
+
from fastmcp.resources import HttpResource
|
|
20
|
+
from fastmcp.tools.tool import ToolResult
|
|
21
|
+
from pydantic import BaseModel
|
|
22
|
+
|
|
23
|
+
from .constants import MAX_INLINE_SIZE
|
|
24
|
+
from .mcp_instance import mcp
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def generate_presigned_url(bucket: str, key: str, expires_in: int = 2592000) -> str:
    """Build a time-limited S3 ``get_object`` URL for ``bucket``/``key``.

    Args:
        bucket (str): Name of the S3 bucket.
        key (str): Object key inside the bucket.
        expires_in (int): Link lifetime in seconds (default 30 days).

    Returns
    -------
        str: Presigned URL granting temporary read access to the object.
    """
    client = boto3.client("s3")
    request_params = {"Bucket": bucket, "Key": key}
    url = client.generate_presigned_url(
        "get_object", Params=request_params, ExpiresIn=expires_in
    )
    return str(url)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class PredictionResponse(BaseModel):
    """Envelope returned by ``predictions_result_response``.

    Small results are returned inline (``type == "inline"``); larger ones are
    uploaded to S3 and referenced (``type == "resource"``).
    """

    # "inline" or "resource" — discriminates which optional fields are populated.
    type: str
    # Raw CSV text; set only for inline responses.
    data: str | None = None
    # URI of the registered MCP resource; set only for resource responses.
    resource_id: str | None = None
    # Presigned S3 URL of the uploaded CSV; set only for resource responses.
    s3_url: str | None = None
    # Echoes the caller's show_explanations flag.
    show_explanations: bool | None = None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def predictions_result_response(
    df: Any, bucket: str, key: str, resource_name: str, show_explanations: bool = False
) -> PredictionResponse:
    """Package a predictions DataFrame as an inline or S3-backed response.

    Results whose UTF-8 encoded CSV is smaller than ``MAX_INLINE_SIZE`` bytes
    are returned inline; larger ones are uploaded to S3 and registered as an
    MCP resource, with the response carrying the resource URI and presigned URL.
    """
    csv_text = df.to_csv(index=False)
    encoded_size = len(csv_text.encode("utf-8"))

    # Small payload: inline the CSV directly.
    if encoded_size < MAX_INLINE_SIZE:
        return PredictionResponse(
            type="inline", data=csv_text, show_explanations=show_explanations
        )

    # Large payload: offload to S3 and reference it as a resource.
    uploaded = save_df_to_s3_and_register_resource(df, bucket, key, resource_name)
    return PredictionResponse(
        type="resource",
        resource_id=str(uploaded.uri),
        s3_url=uploaded.url,
        show_explanations=show_explanations,
    )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def save_df_to_s3_and_register_resource(
    df: Any, bucket: str, key: str, resource_name: str, mime_type: str = "text/csv"
) -> HttpResource:
    """
    Save a DataFrame to a temp CSV, upload it to S3, register it as an MCP
    resource backed by a presigned URL, and return the resource.

    Args:
        df (pd.DataFrame): DataFrame to save and upload.
        bucket (str): S3 bucket name.
        key (str): S3 object key.
        resource_name (str): Name for the registered resource.
        mime_type (str): MIME type for the resource (default 'text/csv').

    Returns
    -------
    HttpResource: The registered resource; its ``url`` is the presigned S3 URL.
    """
    import os
    import tempfile

    # Use a real temp file (portable, unique) and always remove it: the
    # previous hard-coded /tmp/<uuid>.csv was never deleted, leaking one
    # file per call.
    fd, temp_csv = tempfile.mkstemp(suffix=".csv")
    os.close(fd)
    try:
        df.to_csv(temp_csv, index=False)
        s3 = boto3.client("s3")
        s3.upload_file(temp_csv, bucket, key)
    finally:
        os.unlink(temp_csv)

    s3_url = generate_presigned_url(bucket, key)
    resource = HttpResource(
        uri="predictions://" + uuid.uuid4().hex,  # type: ignore[arg-type]
        url=s3_url,
        name=resource_name,
        mime_type=mime_type,
    )
    # Register with the global MCP server so clients can fetch it by URI.
    mcp.add_resource(resource)
    return resource
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def format_response_as_tool_result(data: bytes, content_type: str, charset: str) -> ToolResult:
    """Format the deployment response into a ToolResult.

    Using structured_content, to return as much information about
    the response as possible, for LLMs to correctly interpret the
    response.

    Args:
        data: Raw response body bytes.
        content_type: Response Content-Type; may be empty or None.
        charset: Character set of a textual body; falls back to UTF-8.
    """
    charset = charset or "utf-8"
    content_type = content_type.lower() if content_type else ""

    if content_type.startswith("text/") or content_type == "application/json":
        # Textual bodies are decoded with the declared charset and inlined.
        payload = {
            "type": "text",
            "mime_type": content_type,
            "data": data.decode(charset),
        }
    else:
        # Binary bodies (including images) are base64-encoded. Base64 output
        # is pure ASCII, so decode it with "ascii" — using the response
        # charset (e.g. utf-16) would corrupt the encoded string.
        kind = "image" if content_type.startswith("image/") else "binary"
        payload = {
            "type": kind,
            "mime_type": content_type,
            "data_base64": base64.b64encode(data).decode("ascii"),
        }

    return ToolResult(structured_content=payload)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from datarobot_genai.drmcp import create_mcp_server


def _main() -> None:
    """Create the DataRobot MCP server and run it with the startup banner."""
    server = create_mcp_server()
    server.run(show_banner=True)


if __name__ == "__main__":
    _main()
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
# Copyright 2025 DataRobot, Inc.
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
|
|
17
|
+
"""
|
|
18
|
+
Integration test MCP server.
|
|
19
|
+
|
|
20
|
+
This server works standalone (base tools only) or detects and loads
|
|
21
|
+
user modules if they exist in the project structure.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
from typing import Any
|
|
26
|
+
|
|
27
|
+
from datarobot_genai.drmcp import create_mcp_server
|
|
28
|
+
|
|
29
|
+
# Import user components (will be used conditionally)
|
|
30
|
+
try:
|
|
31
|
+
from app.core.server_lifecycle import ServerLifecycle # type: ignore # noqa: F401
|
|
32
|
+
from app.core.user_config import get_user_config # type: ignore # noqa: F401
|
|
33
|
+
from app.core.user_credentials import get_user_credentials # type: ignore # noqa: F401
|
|
34
|
+
|
|
35
|
+
except ImportError:
|
|
36
|
+
# These imports will fail when running from library without user modules
|
|
37
|
+
pass
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def detect_user_modules() -> Any:
    """
    Detect if user modules exist in the project.

    Searches the current directory and up to two parent directories for an
    ``app/core`` directory; when found, returns the user-supplied factories,
    lifecycle, and module paths needed to extend the MCP server.

    Returns
    -------
    Tuple of (config_factory, credentials_factory, lifecycle, module_paths) or None
    """
    # Try to find app directory
    # When run from library: won't find it
    # When run from project: will find it
    current_dir = Path.cwd()

    # Look for app in current directory or parent directories
    for search_dir in [current_dir, current_dir.parent, current_dir.parent.parent]:
        app_dir = search_dir / "app"
        app_core_dir = app_dir / "core"
        if app_core_dir.exists():
            # Found user directory - load user modules
            try:
                module_paths = [
                    (str(app_dir / "tools"), "app.tools"),
                    (str(app_dir / "prompts"), "app.prompts"),
                    (str(app_dir / "resources"), "app.resources"),
                ]

                return (
                    get_user_config,
                    get_user_credentials,
                    ServerLifecycle(),
                    module_paths,
                )
            except (ImportError, NameError):
                # NameError (not ImportError) is what actually fires here when
                # the module-level `app.core` imports failed: the names were
                # simply never bound, so referencing them above raises NameError.
                pass

    return None
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def main() -> None:
    """Run the integration test MCP server.

    Loads user extensions (config/credentials factories, lifecycle, extra
    module paths) when detected; otherwise starts with base tools only.
    """
    components = detect_user_modules()

    if components is None:
        # No user modules - base tools only.
        server = create_mcp_server(transport="stdio")
    else:
        # User modules found - extend the server with them.
        config_factory, credentials_factory, lifecycle, module_paths = components
        server = create_mcp_server(
            config_factory=config_factory,
            credentials_factory=credentials_factory,
            lifecycle=lifecycle,
            additional_module_paths=module_paths,
            transport="stdio",
        )

    server.run()


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import asyncio
|
|
15
|
+
import os
|
|
16
|
+
from collections.abc import AsyncGenerator
|
|
17
|
+
from contextlib import asynccontextmanager
|
|
18
|
+
|
|
19
|
+
from mcp import ClientSession
|
|
20
|
+
from mcp.client.streamable_http import streamablehttp_client
|
|
21
|
+
|
|
22
|
+
from .utils import load_env
|
|
23
|
+
|
|
24
|
+
load_env()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_dr_mcp_server_url() -> str | None:
|
|
28
|
+
"""Get DataRobot MCP server URL."""
|
|
29
|
+
return os.environ.get("DR_MCP_SERVER_URL")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def get_openai_llm_client_config() -> dict[str, str]:
    """Assemble the OpenAI LLM client configuration from environment variables.

    Raises
    ------
    ValueError
        If ``OPENAI_API_KEY`` is missing, or if ``OPENAI_API_BASE`` is set
        (Azure OpenAI) without ``OPENAI_API_DEPLOYMENT_ID``.
    """
    env = os.environ
    api_key = env.get("OPENAI_API_KEY")
    api_base = env.get("OPENAI_API_BASE")
    deployment_id = env.get("OPENAI_API_DEPLOYMENT_ID")
    api_version = env.get("OPENAI_API_VERSION")
    save_responses = env.get("SAVE_LLM_RESPONSES", "false").lower() == "true"

    # OPENAI_API_KEY is always required.
    if not api_key:
        raise ValueError("Missing required environment variable: OPENAI_API_KEY")
    # Azure OpenAI (api_base set) additionally requires a deployment id.
    if api_base and not deployment_id:
        raise ValueError("Missing required environment variable: OPENAI_API_DEPLOYMENT_ID")

    config: dict[str, str] = {"openai_api_key": api_key}

    # Optional settings are only included when present.
    optional = {
        "openai_api_base": api_base,
        "openai_api_deployment_id": deployment_id,
        "openai_api_version": api_version,
    }
    for name, value in optional.items():
        if value:
            config[name] = value
    config["save_llm_responses"] = str(save_responses)

    return config
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def get_headers() -> dict[str, str]:
    """Build auth headers for a DataRobot-hosted MCP server.

    When the MCP server is deployed in DataRobot, the API token must be
    included in the request headers for authentication.
    """
    # NOTE(review): if DATAROBOT_API_TOKEN is unset this yields the literal
    # header "Bearer None" — presumably callers always set it; confirm.
    token = os.getenv("DATAROBOT_API_TOKEN")
    return {"Authorization": f"Bearer {token}"}
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@asynccontextmanager
async def ete_test_mcp_session(
    additional_headers: dict[str, str] | None = None,
) -> AsyncGenerator[ClientSession, None]:
    """Create an MCP session for each test.

    Parameters
    ----------
    additional_headers : dict[str, str], optional
        Additional headers to include in the MCP session (e.g., auth headers for testing).

    Raises
    ------
    TimeoutError
        If session initialization does not complete within 5 seconds —
        typically the server at DR_MCP_SERVER_URL is not running.
    """
    try:
        # Start from the standard auth headers, then layer on any extras.
        headers = get_headers()
        if additional_headers:
            headers.update(additional_headers)

        # streamablehttp_client yields (read, write, extra); the third
        # element is unused here.
        async with streamablehttp_client(url=get_dr_mcp_server_url(), headers=headers) as (
            read_stream,
            write_stream,
            _,
        ):
            async with ClientSession(read_stream, write_stream) as session:
                # Bound the MCP handshake so a dead server fails fast.
                await asyncio.wait_for(session.initialize(), timeout=5)
                yield session
    except asyncio.TimeoutError:
        raise TimeoutError(f"Check if the MCP server is running at {get_dr_mcp_server_url()}")
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
import contextlib
|
|
17
|
+
import os
|
|
18
|
+
from collections.abc import AsyncGenerator
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
|
|
21
|
+
from mcp import ClientSession
|
|
22
|
+
from mcp.client.stdio import StdioServerParameters
|
|
23
|
+
from mcp.client.stdio import stdio_client
|
|
24
|
+
|
|
25
|
+
from .utils import load_env
|
|
26
|
+
|
|
27
|
+
load_env()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def integration_test_mcp_server_params() -> StdioServerParameters:
    """Build stdio parameters for launching the integration MCP server via ``uv``.

    Environment values fall back to test-friendly defaults when unset or empty.
    """
    defaults = {
        "DATAROBOT_API_TOKEN": "test-token",
        "DATAROBOT_ENDPOINT": "https://test.datarobot.com/api/v2",
        "MCP_SERVER_LOG_LEVEL": "WARNING",
        "APP_LOG_LEVEL": "WARNING",
        "OTEL_ENABLED": "false",
        "MCP_SERVER_REGISTER_DYNAMIC_TOOLS_ON_STARTUP": "false",
        "MCP_SERVER_REGISTER_DYNAMIC_PROMPTS_ON_STARTUP": "true",
    }
    # `or` (rather than a get() default) so empty-string env values also
    # fall back to the default.
    env = {name: os.environ.get(name) or fallback for name, fallback in defaults.items()}

    script_dir = Path(__file__).resolve().parent
    server_script = str(script_dir / "integration_mcp_server.py")
    # Add src/ directory to Python path so datarobot_genai can be imported
    src_dir = script_dir.parent.parent.parent

    return StdioServerParameters(
        command="uv",
        args=["run", server_script],
        env={
            "PYTHONPATH": str(src_dir),
            "MCP_SERVER_NAME": "integration",
            "MCP_SERVER_PORT": "8081",
            **env,
        },
    )
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@contextlib.asynccontextmanager
async def integration_test_mcp_session(
    server_params: StdioServerParameters | None = None, timeout: int = 30
) -> AsyncGenerator[ClientSession, None]:
    """
    Create and connect a client for the MCP server as a context manager.

    Args:
        server_params: Parameters for configuring the server connection
        timeout: Timeout in seconds for session initialization

    Yields
    ------
    ClientSession: Connected MCP client session

    Raises
    ------
    ConnectionError: If session initialization fails
    TimeoutError: If session initialization exceeds timeout
    """
    # Default to spawning the bundled integration server over stdio.
    server_params = server_params or integration_test_mcp_server_params()

    try:
        async with stdio_client(server_params) as (read_stream, write_stream):
            async with ClientSession(read_stream, write_stream) as session:
                # Bound the MCP handshake so a hung subprocess fails the test.
                await asyncio.wait_for(session.initialize(), timeout=timeout)
                yield session

    except asyncio.TimeoutError:
        raise TimeoutError(f"Session initialization timed out after {timeout} seconds")
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
import openai
|
|
19
|
+
from mcp import ClientSession
|
|
20
|
+
from mcp.types import CallToolResult
|
|
21
|
+
from mcp.types import ListToolsResult
|
|
22
|
+
from mcp.types import TextContent
|
|
23
|
+
from openai.types.chat.chat_completion import ChatCompletion
|
|
24
|
+
|
|
25
|
+
from .utils import save_response_to_file
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ToolCall:
    """A single tool invocation: name, arguments, and the LLM's rationale."""

    def __init__(self, tool_name: str, parameters: dict[str, Any], reasoning: str):
        """Record the tool name, its call parameters, and the selection reasoning."""
        self.tool_name, self.parameters, self.reasoning = tool_name, parameters, reasoning
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class LLMResponse:
    """Final LLM answer together with the tool calls and results that produced it."""

    def __init__(self, content: str, tool_calls: list[ToolCall], tool_results: list[str]):
        """Store the response text alongside its tool-call trace."""
        self.content, self.tool_calls, self.tool_results = content, tool_calls, tool_results
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class LLMMCPClient:
    """Client for interacting with LLMs via MCP.

    Wraps an OpenAI (or Azure OpenAI) chat client and drives a tool-calling
    loop against the tools exposed by an MCP session.
    """

    def __init__(self, config: str | dict[str, Any]):
        """Initialize the LLM MCP client.

        Args:
            config: Configuration mapping (or its string representation) with
                keys such as ``openai_api_key``, ``openai_api_base``,
                ``openai_api_deployment_id``, ``openai_api_version``,
                ``model`` and ``save_llm_responses``.
        """
        import ast

        # Parse a stringified config with ast.literal_eval, which only accepts
        # Python literals — unlike eval(), which would execute arbitrary code
        # embedded in the config string.
        config_dict = ast.literal_eval(config) if isinstance(config, str) else config

        openai_api_key = config_dict.get("openai_api_key")
        openai_api_base = config_dict.get("openai_api_base")
        openai_api_deployment_id = config_dict.get("openai_api_deployment_id")
        model = config_dict.get("model", "gpt-3.5-turbo")
        save_llm_responses = config_dict.get("save_llm_responses", True)
        # Env-derived configs carry booleans as strings ("True"/"False");
        # the string "False" is truthy, so normalize explicitly.
        if isinstance(save_llm_responses, str):
            save_llm_responses = save_llm_responses.lower() == "true"

        if openai_api_base and openai_api_deployment_id:
            # Azure OpenAI: the deployment id plays the role of the model name.
            self.openai_client = openai.AzureOpenAI(
                api_key=openai_api_key,
                azure_endpoint=openai_api_base,
                api_version=config_dict.get("openai_api_version", "2024-02-15-preview"),
            )
            self.model = openai_api_deployment_id
        else:
            # Regular OpenAI
            self.openai_client = openai.OpenAI(api_key=openai_api_key)  # type: ignore[assignment]
            self.model = model

        self.save_llm_responses = save_llm_responses
        self.available_tools: list[dict[str, Any]] = []
        self.available_prompts: list[dict[str, Any]] = []
        self.available_resources: list[dict[str, Any]] = []

    async def _add_mcp_tool_to_available_tools(self, mcp_session: ClientSession) -> None:
        """Populate ``available_tools`` from the MCP session, in OpenAI function format."""
        tools_result: ListToolsResult = await mcp_session.list_tools()
        self.available_tools = [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.inputSchema,
                },
            }
            for tool in tools_result.tools
        ]

    async def _call_mcp_tool(
        self, tool_name: str, parameters: dict[str, Any], mcp_session: ClientSession
    ) -> str:
        """Call an MCP tool and return the result as a string."""
        result: CallToolResult = await mcp_session.call_tool(tool_name, parameters)
        return (
            result.content[0].text
            if result.content and isinstance(result.content[0], TextContent)
            else str(result.content)
        )

    async def _process_tool_calls(
        self,
        response: ChatCompletion,
        messages: list[Any],
        mcp_session: ClientSession,
    ) -> tuple[list[ToolCall], list[str]]:
        """Process tool calls from the response, and return the tool calls and tool results.

        Each tool result (or error text, if the call raised) is appended to
        ``messages`` as a "tool" role entry so the LLM can see it.
        """
        tool_calls = []
        tool_results = []

        # If the response has tool calls, process them
        if response.choices[0].message.tool_calls:
            messages.append(response.choices[0].message)  # Add assistant's message with tool calls

            for tool_call in response.choices[0].message.tool_calls:
                tool_name = tool_call.function.name  # type: ignore[union-attr]
                parameters = json.loads(tool_call.function.arguments)  # type: ignore[union-attr]

                tool_calls.append(
                    ToolCall(
                        tool_name=tool_name,
                        parameters=parameters,
                        reasoning="Tool selected by LLM",
                    )
                )

                try:
                    result_text = await self._call_mcp_tool(tool_name, parameters, mcp_session)
                    tool_results.append(result_text)
                except Exception as e:
                    # Feed the failure back to the LLM rather than aborting the run.
                    result_text = f"Error calling {tool_name}: {str(e)}"
                    tool_results.append(result_text)

                # Add tool result (or error text) to messages
                messages.append(
                    {
                        "role": "tool",
                        "content": result_text,
                        "tool_call_id": tool_call.id,
                        "name": tool_name,
                    }
                )

        return tool_calls, tool_results

    async def _get_llm_response(
        self, messages: list[dict[str, Any]], allow_tool_calls: bool = True
    ) -> Any:
        """Get a response from the LLM with optional tool calling capability."""
        kwargs = {
            "model": self.model,
            "messages": messages,
        }

        if allow_tool_calls and self.available_tools:
            kwargs["tools"] = self.available_tools
            kwargs["tool_choice"] = "auto"

        return self.openai_client.chat.completions.create(**kwargs)

    async def process_prompt_with_mcp_support(
        self, prompt: str, mcp_session: ClientSession, output_file_name: str = ""
    ) -> LLMResponse:
        """Process a prompt with MCP tool support.

        Loops: ask the LLM, execute any requested tools, feed results back,
        and stop at the first response containing no tool calls.
        """
        # Add MCP tools to available tools
        await self._add_mcp_tool_to_available_tools(mcp_session)

        if output_file_name:
            print(f"Processing prompt for test: {output_file_name}")

        # Initialize conversation
        messages = [
            {
                "role": "system",
                "content": (
                    "You are a helpful AI assistant that can use tools to help users. "
                    "If you need more information to provide a complete response, you can make "
                    "multiple tool calls. When dealing with file paths, use them as raw paths "
                    "without converting to file:// URLs."
                ),
            },
            {"role": "user", "content": prompt},
        ]

        all_tool_calls = []
        all_tool_results = []

        while True:
            # One LLM call per iteration; a response without tool calls is final.
            # (The previous version made a second, redundant call per iteration
            # whose tool calls were silently discarded.)
            response = await self._get_llm_response(messages)

            if not response.choices[0].message.tool_calls:
                final_response = response.choices[0].message.content
                break

            # Process tool calls; results are appended to messages so the next
            # iteration lets the LLM decide whether more calls are needed.
            tool_calls, tool_results = await self._process_tool_calls(
                response, messages, mcp_session
            )
            all_tool_calls.extend(tool_calls)
            all_tool_results.extend(tool_results)

        # content may legitimately be None; guard before string operations.
        clean_content = (final_response or "").replace("*", "").lower()

        llm_response = LLMResponse(
            content=clean_content,
            tool_calls=all_tool_calls,
            tool_results=all_tool_results,
        )

        if self.save_llm_responses:
            save_response_to_file(llm_response, name=output_file_name)

        return llm_response
|