datarobot-genai 0.2.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_genai/__init__.py +19 -0
- datarobot_genai/core/__init__.py +0 -0
- datarobot_genai/core/agents/__init__.py +43 -0
- datarobot_genai/core/agents/base.py +195 -0
- datarobot_genai/core/chat/__init__.py +19 -0
- datarobot_genai/core/chat/auth.py +146 -0
- datarobot_genai/core/chat/client.py +178 -0
- datarobot_genai/core/chat/responses.py +297 -0
- datarobot_genai/core/cli/__init__.py +18 -0
- datarobot_genai/core/cli/agent_environment.py +47 -0
- datarobot_genai/core/cli/agent_kernel.py +211 -0
- datarobot_genai/core/custom_model.py +141 -0
- datarobot_genai/core/mcp/__init__.py +0 -0
- datarobot_genai/core/mcp/common.py +218 -0
- datarobot_genai/core/telemetry_agent.py +126 -0
- datarobot_genai/core/utils/__init__.py +3 -0
- datarobot_genai/core/utils/auth.py +234 -0
- datarobot_genai/core/utils/urls.py +64 -0
- datarobot_genai/crewai/__init__.py +24 -0
- datarobot_genai/crewai/agent.py +42 -0
- datarobot_genai/crewai/base.py +159 -0
- datarobot_genai/crewai/events.py +117 -0
- datarobot_genai/crewai/mcp.py +59 -0
- datarobot_genai/drmcp/__init__.py +78 -0
- datarobot_genai/drmcp/core/__init__.py +13 -0
- datarobot_genai/drmcp/core/auth.py +165 -0
- datarobot_genai/drmcp/core/clients.py +180 -0
- datarobot_genai/drmcp/core/config.py +364 -0
- datarobot_genai/drmcp/core/config_utils.py +174 -0
- datarobot_genai/drmcp/core/constants.py +18 -0
- datarobot_genai/drmcp/core/credentials.py +190 -0
- datarobot_genai/drmcp/core/dr_mcp_server.py +350 -0
- datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
- datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
- datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
- datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +70 -0
- datarobot_genai/drmcp/core/dynamic_prompts/register.py +205 -0
- datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
- datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
- datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
- datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
- datarobot_genai/drmcp/core/exceptions.py +25 -0
- datarobot_genai/drmcp/core/logging.py +98 -0
- datarobot_genai/drmcp/core/mcp_instance.py +515 -0
- datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
- datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
- datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
- datarobot_genai/drmcp/core/routes.py +439 -0
- datarobot_genai/drmcp/core/routes_utils.py +30 -0
- datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
- datarobot_genai/drmcp/core/telemetry.py +424 -0
- datarobot_genai/drmcp/core/tool_config.py +111 -0
- datarobot_genai/drmcp/core/tool_filter.py +117 -0
- datarobot_genai/drmcp/core/utils.py +138 -0
- datarobot_genai/drmcp/server.py +19 -0
- datarobot_genai/drmcp/test_utils/__init__.py +13 -0
- datarobot_genai/drmcp/test_utils/clients/__init__.py +0 -0
- datarobot_genai/drmcp/test_utils/clients/anthropic.py +68 -0
- datarobot_genai/drmcp/test_utils/clients/base.py +300 -0
- datarobot_genai/drmcp/test_utils/clients/dr_gateway.py +58 -0
- datarobot_genai/drmcp/test_utils/clients/openai.py +68 -0
- datarobot_genai/drmcp/test_utils/elicitation_test_tool.py +89 -0
- datarobot_genai/drmcp/test_utils/integration_mcp_server.py +109 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +133 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +107 -0
- datarobot_genai/drmcp/test_utils/test_interactive.py +205 -0
- datarobot_genai/drmcp/test_utils/tool_base_ete.py +220 -0
- datarobot_genai/drmcp/test_utils/utils.py +91 -0
- datarobot_genai/drmcp/tools/__init__.py +14 -0
- datarobot_genai/drmcp/tools/clients/__init__.py +14 -0
- datarobot_genai/drmcp/tools/clients/atlassian.py +188 -0
- datarobot_genai/drmcp/tools/clients/confluence.py +584 -0
- datarobot_genai/drmcp/tools/clients/gdrive.py +832 -0
- datarobot_genai/drmcp/tools/clients/jira.py +334 -0
- datarobot_genai/drmcp/tools/clients/microsoft_graph.py +479 -0
- datarobot_genai/drmcp/tools/clients/s3.py +28 -0
- datarobot_genai/drmcp/tools/confluence/__init__.py +14 -0
- datarobot_genai/drmcp/tools/confluence/tools.py +321 -0
- datarobot_genai/drmcp/tools/gdrive/__init__.py +0 -0
- datarobot_genai/drmcp/tools/gdrive/tools.py +347 -0
- datarobot_genai/drmcp/tools/jira/__init__.py +14 -0
- datarobot_genai/drmcp/tools/jira/tools.py +243 -0
- datarobot_genai/drmcp/tools/microsoft_graph/__init__.py +13 -0
- datarobot_genai/drmcp/tools/microsoft_graph/tools.py +198 -0
- datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
- datarobot_genai/drmcp/tools/predictive/data.py +133 -0
- datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
- datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
- datarobot_genai/drmcp/tools/predictive/model.py +148 -0
- datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
- datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
- datarobot_genai/drmcp/tools/predictive/project.py +90 -0
- datarobot_genai/drmcp/tools/predictive/training.py +661 -0
- datarobot_genai/langgraph/__init__.py +0 -0
- datarobot_genai/langgraph/agent.py +341 -0
- datarobot_genai/langgraph/mcp.py +73 -0
- datarobot_genai/llama_index/__init__.py +16 -0
- datarobot_genai/llama_index/agent.py +50 -0
- datarobot_genai/llama_index/base.py +299 -0
- datarobot_genai/llama_index/mcp.py +79 -0
- datarobot_genai/nat/__init__.py +0 -0
- datarobot_genai/nat/agent.py +275 -0
- datarobot_genai/nat/datarobot_auth_provider.py +110 -0
- datarobot_genai/nat/datarobot_llm_clients.py +318 -0
- datarobot_genai/nat/datarobot_llm_providers.py +130 -0
- datarobot_genai/nat/datarobot_mcp_client.py +266 -0
- datarobot_genai/nat/helpers.py +87 -0
- datarobot_genai/py.typed +0 -0
- datarobot_genai-0.2.31.dist-info/METADATA +145 -0
- datarobot_genai-0.2.31.dist-info/RECORD +125 -0
- datarobot_genai-0.2.31.dist-info/WHEEL +4 -0
- datarobot_genai-0.2.31.dist-info/entry_points.txt +5 -0
- datarobot_genai-0.2.31.dist-info/licenses/AUTHORS +2 -0
- datarobot_genai-0.2.31.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc. and its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import logging
|
|
15
|
+
import warnings
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
import jwt
|
|
19
|
+
from datarobot.auth.datarobot.oauth import AsyncOAuth as DatarobotAsyncOAuthClient
|
|
20
|
+
from datarobot.auth.identity import Identity
|
|
21
|
+
from datarobot.auth.oauth import AsyncOAuthComponent
|
|
22
|
+
from datarobot.auth.session import AuthCtx
|
|
23
|
+
from datarobot.core.config import DataRobotAppFrameworkBaseSettings
|
|
24
|
+
from datarobot.models.genai.agent.auth import ToolAuth
|
|
25
|
+
from datarobot.models.genai.agent.auth import get_authorization_context
|
|
26
|
+
from pydantic import BaseModel
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AuthContextConfig(DataRobotAppFrameworkBaseSettings):
    """Settings holder for the secret used to sign authorization-context JWTs.

    ``AuthContextHeaderHandler`` instantiates this class when no explicit
    secret is passed and reads ``session_secret_key`` from it.
    """

    # Empty string means "no secret configured".
    # NOTE(review): presumably populated from the SESSION_SECRET_KEY
    # environment variable by the base settings class - confirm.
    session_secret_key: str = ""
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class DRAppCtx(BaseModel):
    """DataRobot application context from authorization metadata.

    Both fields are optional and default to ``None``.
    """

    # E-mail address carried in the authorization metadata, if any.
    email: str | None = None
    # API key carried in the authorization metadata, if any.
    api_key: str | None = None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class AuthContextHeaderHandler:
    """Manages encoding and decoding of authorization context into JWT tokens.

    This class provides a consistent interface for encoding auth context into JWT tokens
    and exchanging them via HTTP headers across multiple applications.
    """

    HEADER_NAME = "X-DataRobot-Authorization-Context"
    DEFAULT_ALGORITHM = "HS256"

    def __init__(
        self,
        secret_key: str | None = None,
        algorithm: str = DEFAULT_ALGORITHM,
        validate_signature: bool = True,
    ) -> None:
        """Initialize the handler.

        Parameters
        ----------
        secret_key : Optional[str]
            Secret key for JWT encoding/decoding. When omitted, the key is read from
            ``AuthContextConfig``; if that yields nothing the key is the empty string,
            which makes ``encode`` emit a warning and ``decode`` refuse to validate.
        algorithm : str
            JWT algorithm. Default is "HS256".
        validate_signature : bool
            Whether to validate JWT signatures. Default is True.

        Raises
        ------
        ValueError
            If algorithm is ``None`` or the string 'none' in any casing (insecure).
        """
        # Reject both a missing algorithm and the JWT "none" algorithm, which
        # would produce tokens without any signature.
        if algorithm is None or str(algorithm).strip().lower() == "none":
            raise ValueError("Algorithm None is not allowed. Use a secure algorithm like HS256.")

        # Get secret key from parameter, config, or environment variable.
        if secret_key:
            self.secret_key = secret_key
        else:
            try:
                config = AuthContextConfig()
                self.secret_key = config.session_secret_key or ""
            except Exception:
                # Best-effort fallback: per the original authors, constructing the
                # config can fail inside the datarobot package when
                # SESSION_SECRET_KEY is not set. Treat that as "no secret".
                self.secret_key = ""

        self.algorithm = algorithm
        self.validate_signature = validate_signature

    @property
    def header(self) -> str:
        """Get the header name for authorization context."""
        return self.HEADER_NAME

    def get_header(self, authorization_context: dict[str, Any] | None = None) -> dict[str, str]:
        """Get the authorization context header with encoded JWT token.

        Returns an empty dict when there is no context to encode.
        """
        token = self.encode(authorization_context)
        if not token:
            return {}

        return {self.header: token}

    def encode(self, authorization_context: dict[str, Any] | None = None) -> str | None:
        """Encode the current authorization context into a JWT token.

        Falls back to ``get_authorization_context()`` when no context is passed.
        Returns ``None`` when the resulting context is empty.
        """
        auth_context = authorization_context or get_authorization_context()
        if not auth_context:
            return None

        if not self.secret_key:
            warnings.warn(
                "No secret key provided. Please make sure SESSION_SECRET_KEY is set. "
                "JWT tokens will be signed with an empty key. This is insecure and should "
                "only be used for testing."
            )

        return jwt.encode(auth_context, self.secret_key, algorithm=self.algorithm)

    def decode(self, token: str) -> dict[str, Any] | None:
        """Decode a JWT token into the authorization context.

        Returns ``None`` for an empty token, when signature validation is requested
        but no secret key is available, when the token is expired or invalid, or
        when the decoded payload is not a dictionary.
        """
        if not token:
            return None

        if not self.secret_key and self.validate_signature:
            logger.error(
                "No secret key provided. Cannot validate signature. "
                "Provide a secret key or set validate_signature to False."
            )
            return None

        try:
            decoded = jwt.decode(
                jwt=token,
                key=self.secret_key,
                algorithms=[self.algorithm],
                options={"verify_signature": self.validate_signature},
            )
        except jwt.ExpiredSignatureError:
            logger.info("JWT token has expired.")
            return None
        except jwt.InvalidTokenError:
            logger.warning("JWT token is invalid or malformed.")
            return None

        if not isinstance(decoded, dict):
            logger.warning("Decoded JWT token is not a dictionary.")
            return None

        return decoded

    def get_context(self, headers: dict[str, str]) -> AuthCtx | None:
        """Extract and validate authorization context from headers.

        Parameters
        ----------
        headers : Dict[str, str]
            HTTP headers containing the authorization context. The header name is
            matched case-insensitively.

        Returns
        -------
        Optional[AuthCtx]
            Validated authorization context or None if validation fails.
        """
        token = headers.get(self.header) or headers.get(self.header.lower())
        if not token:
            # HTTP header names are case-insensitive; a plain dict is not, so scan
            # for any other casing before giving up.
            wanted = self.header.lower()
            token = next(
                (value for name, value in headers.items() if str(name).lower() == wanted),
                None,
            )
        if not token:
            logger.debug("No authorization context header found")
            return None

        auth_ctx_dict = self.decode(token)
        if not auth_ctx_dict:
            logger.debug("Failed to decode auth context from token")
            return None

        try:
            return AuthCtx(**auth_ctx_dict)
        except Exception as e:
            logger.error("Failed to create AuthCtx from decoded token: %s", e, exc_info=True)
            return None
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class AsyncOAuthTokenProvider:
    """Manages OAuth access tokens using generic OAuth client."""

    def __init__(self, auth_ctx: AuthCtx) -> None:
        """Store the authorization context and build the OAuth client."""
        self.auth_ctx = auth_ctx
        self.oauth_client = self._create_oauth_client()

    def _get_identity(self, provider_type: str | None) -> Identity:
        """Pick the identity to use from the authorization context.

        Only identities with a non-``None`` ``provider_identity_id`` are considered.

        Raises
        ------
        ValueError
            If no usable identity exists, if ``provider_type`` is ``None`` while
            several identities are available, or if no identity matches
            ``provider_type``.
        """
        candidates = [
            entry for entry in self.auth_ctx.identities if entry.provider_identity_id is not None
        ]
        if not candidates:
            raise ValueError("No identities found in authorization context.")

        if provider_type is not None:
            # First identity of the requested provider wins.
            for candidate in candidates:
                if candidate.provider_type == provider_type:
                    return candidate
            raise ValueError(f"No identity found for provider '{provider_type}'.")

        # No provider requested: only unambiguous when exactly one identity exists.
        if len(candidates) > 1:
            raise ValueError(
                "Multiple identities found. Please specify 'provider_type' parameter."
            )
        return candidates[0]

    async def get_token(self, auth_type: ToolAuth, provider_type: str | None = None) -> str:
        """Get OAuth access token using the specified method.

        Only ``ToolAuth.OBO`` is accepted; any other ``auth_type`` raises
        ``ValueError``. The token is obtained by refreshing the access token of
        the identity selected for ``provider_type``.
        """
        if auth_type != ToolAuth.OBO:
            raise ValueError(
                f"Unsupported auth type: {auth_type}. Only {ToolAuth.OBO} is supported."
            )

        selected = self._get_identity(provider_type)
        refreshed = await self.oauth_client.refresh_access_token(
            identity_id=selected.provider_identity_id
        )
        return refreshed.access_token

    def _create_oauth_client(self) -> AsyncOAuthComponent:
        """Create the OAuth client used to refresh tokens.

        Note: at the moment, only DataRobot OAuth client is supported.
        """
        logger.debug("Using DataRobot OAuth client")
        return DatarobotAsyncOAuthClient()
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc. and its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import re
|
|
16
|
+
from urllib.parse import urlparse
|
|
17
|
+
from urllib.parse import urlunparse
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_api_base(api_base: str, deployment_id: str | None) -> str:
|
|
21
|
+
"""
|
|
22
|
+
Construct the LiteLLM API base URL for a deployment.
|
|
23
|
+
|
|
24
|
+
Parameters
|
|
25
|
+
----------
|
|
26
|
+
api_base : str
|
|
27
|
+
Base URL for the LiteLLM API.
|
|
28
|
+
deployment_id : str | None
|
|
29
|
+
Deployment identifier. When provided, a chat/completions URL is produced.
|
|
30
|
+
|
|
31
|
+
Returns
|
|
32
|
+
-------
|
|
33
|
+
str
|
|
34
|
+
Normalized URL for the given deployment. Ensures a trailing slash unless the path
|
|
35
|
+
ends with "chat/completions" or already has a meaningful path component.
|
|
36
|
+
"""
|
|
37
|
+
# Normalize the URL and drop a trailing /api/v2 if present
|
|
38
|
+
parsed = urlparse(api_base)
|
|
39
|
+
path = re.sub(r"/api/v2/?$", "", parsed.path)
|
|
40
|
+
base_url = urlunparse(
|
|
41
|
+
(
|
|
42
|
+
parsed.scheme,
|
|
43
|
+
parsed.netloc,
|
|
44
|
+
path,
|
|
45
|
+
parsed.params,
|
|
46
|
+
parsed.query,
|
|
47
|
+
parsed.fragment,
|
|
48
|
+
)
|
|
49
|
+
)
|
|
50
|
+
base_url = base_url.rstrip("/")
|
|
51
|
+
|
|
52
|
+
# If the base_url already ends with chat/completions, return it.
|
|
53
|
+
if base_url.endswith("chat/completions"):
|
|
54
|
+
return base_url
|
|
55
|
+
|
|
56
|
+
# If the path contains deployments or genai, it's already a complete API path preserve it.
|
|
57
|
+
if path and ("deployments" in path or "genai" in path):
|
|
58
|
+
return f"{base_url}/" if not base_url.endswith("/") else base_url
|
|
59
|
+
|
|
60
|
+
# For all other cases (including custom base paths), apply deployment logic if needed.
|
|
61
|
+
if deployment_id:
|
|
62
|
+
return f"{base_url}/api/v2/deployments/{deployment_id}/chat/completions"
|
|
63
|
+
# Otherwise, just return the base URL with a trailing slash for normalization.
|
|
64
|
+
return f"{base_url}/"
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""CrewAI utilities and helpers.
|
|
2
|
+
|
|
3
|
+
Public API:
|
|
4
|
+
- mcp_tools_context: Context manager returning available MCP tools for CrewAI.
|
|
5
|
+
- build_llm: Construct a CrewAI LLM configured for DataRobot endpoints.
|
|
6
|
+
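- create_pipeline_interactions_from_messages: Convert messages to MultiTurnSample.
- CrewAIAgent: Abstract base agent for CrewAI workflows.
- CrewAIEventListener: Collects CrewAI events into Ragas messages.
- MCPConfig: Re-exported from datarobot_genai.core.mcp.common.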
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from datarobot_genai.core.mcp.common import MCPConfig
|
|
10
|
+
|
|
11
|
+
from .agent import build_llm
|
|
12
|
+
from .agent import create_pipeline_interactions_from_messages
|
|
13
|
+
from .base import CrewAIAgent
|
|
14
|
+
from .events import CrewAIEventListener
|
|
15
|
+
from .mcp import mcp_tools_context
|
|
16
|
+
|
|
17
|
+
__all__ = [
|
|
18
|
+
"mcp_tools_context",
|
|
19
|
+
"CrewAIAgent",
|
|
20
|
+
"build_llm",
|
|
21
|
+
"create_pipeline_interactions_from_messages",
|
|
22
|
+
"CrewAIEventListener",
|
|
23
|
+
"MCPConfig",
|
|
24
|
+
]
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc. and its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from crewai import LLM
|
|
16
|
+
from ragas import MultiTurnSample
|
|
17
|
+
from ragas.messages import AIMessage
|
|
18
|
+
from ragas.messages import HumanMessage
|
|
19
|
+
from ragas.messages import ToolMessage
|
|
20
|
+
|
|
21
|
+
from datarobot_genai.core.utils.urls import get_api_base
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def build_llm(
    *,
    api_base: str,
    api_key: str | None,
    model: str,
    deployment_id: str | None,
    timeout: int,
) -> LLM:
    """Create a CrewAI LLM configured for DataRobot LLM Gateway or deployment.

    ``api_base`` and ``deployment_id`` are first normalized through
    ``get_api_base``; the remaining arguments are forwarded to ``LLM`` unchanged.
    """
    return LLM(
        model=model,
        api_base=get_api_base(api_base, deployment_id),
        api_key=api_key,
        timeout=timeout,
    )
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def create_pipeline_interactions_from_messages(
    messages: list[HumanMessage | AIMessage | ToolMessage] | None,
) -> MultiTurnSample | None:
    """Wrap collected messages in a ``MultiTurnSample``.

    Returns ``None`` when ``messages`` is ``None`` or empty; otherwise the
    messages become the sample's ``user_input``.
    """
    return MultiTurnSample(user_input=messages) if messages else None
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc. and its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
Base class for CrewAI-based agents.
|
|
17
|
+
|
|
18
|
+
Manages MCP tool lifecycle and standardizes kickoff flow.
|
|
19
|
+
|
|
20
|
+
Note: This base does not capture pipeline interactions; it returns None by
|
|
21
|
+
default. Subclasses may implement message capture if they need interactions.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
import abc
|
|
27
|
+
import asyncio
|
|
28
|
+
from collections.abc import AsyncGenerator
|
|
29
|
+
from typing import Any
|
|
30
|
+
|
|
31
|
+
from crewai import Crew
|
|
32
|
+
from crewai.events.event_bus import CrewAIEventsBus
|
|
33
|
+
from crewai.tools import BaseTool
|
|
34
|
+
from openai.types.chat import CompletionCreateParams
|
|
35
|
+
from ragas import MultiTurnSample
|
|
36
|
+
|
|
37
|
+
from datarobot_genai.core.agents.base import BaseAgent
|
|
38
|
+
from datarobot_genai.core.agents.base import InvokeReturn
|
|
39
|
+
from datarobot_genai.core.agents.base import UsageMetrics
|
|
40
|
+
from datarobot_genai.core.agents.base import default_usage_metrics
|
|
41
|
+
from datarobot_genai.core.agents.base import extract_user_prompt_content
|
|
42
|
+
from datarobot_genai.core.agents.base import is_streaming
|
|
43
|
+
|
|
44
|
+
from .agent import create_pipeline_interactions_from_messages
|
|
45
|
+
from .mcp import mcp_tools_context
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class CrewAIAgent(BaseAgent[BaseTool], abc.ABC):
    """Abstract base agent for CrewAI workflows.

    Subclasses should define the ``agents`` and ``tasks`` properties
    and may override ``build_crewai_workflow`` to customize the workflow
    construction.
    """

    @property
    @abc.abstractmethod
    def agents(self) -> list[Any]:  # CrewAI Agent list
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def tasks(self) -> list[Any]:  # CrewAI Task list
        raise NotImplementedError

    def build_crewai_workflow(self) -> Any:
        """Create a CrewAI workflow instance.

        Default implementation constructs a Crew with provided agents and tasks.
        Subclasses can override to customize Crew options.
        """
        return Crew(agents=self.agents, tasks=self.tasks, verbose=self.verbose)

    @abc.abstractmethod
    def make_kickoff_inputs(self, user_prompt_content: str) -> dict[str, Any]:
        """Build the inputs dict for ``Crew.kickoff``.

        Subclasses must implement this to provide the exact inputs required
        by their CrewAI tasks.
        """
        raise NotImplementedError

    def _extract_pipeline_interactions(self) -> MultiTurnSample | None:
        """Extract pipeline interactions from event listener if available.

        Returns ``None`` when the instance has no ``event_listener``, when the
        listener has no messages, or when building the sample fails.
        """
        listener = getattr(self, "event_listener", None)
        if listener is None:
            return None
        try:
            return create_pipeline_interactions_from_messages(getattr(listener, "messages", None))
        except Exception:
            # Interactions are optional; never fail the invocation because of them.
            return None

    def _extract_usage_metrics(self, crew_output: Any) -> UsageMetrics:
        """Extract usage metrics from crew output.

        Missing or ``None`` counters are reported as 0. When the output carries
        no ``token_usage`` at all, the default usage metrics are returned.
        """
        token_usage = getattr(crew_output, "token_usage", None)
        if token_usage is None:
            return default_usage_metrics()

        def _count(field: str) -> int:
            # ``or 0`` guards against a counter that is present but set to None.
            return int(getattr(token_usage, field, 0) or 0)

        return {
            "completion_tokens": _count("completion_tokens"),
            "prompt_tokens": _count("prompt_tokens"),
            "total_tokens": _count("total_tokens"),
        }

    def _process_crew_output(
        self, crew_output: Any
    ) -> tuple[str, MultiTurnSample | None, UsageMetrics]:
        """Process crew output into response tuple."""
        response_text = str(crew_output.raw)
        pipeline_interactions = self._extract_pipeline_interactions()
        usage_metrics = self._extract_usage_metrics(crew_output)
        return response_text, pipeline_interactions, usage_metrics

    def _register_event_listener(self) -> None:
        """Register the optional ``event_listener`` with the CrewAI event bus."""
        if not hasattr(self, "event_listener") or CrewAIEventsBus is None:
            return
        try:
            listener = getattr(self, "event_listener")
            setup_fn = getattr(listener, "setup_listeners", None)
            if callable(setup_fn):
                # NOTE(review): this passes the CrewAIEventsBus class itself, not
                # an instance - confirm that is what setup_listeners expects.
                setup_fn(CrewAIEventsBus)
        except Exception:
            # Listener is optional best-effort; proceed without failing invoke
            pass

    def _prepare_crew(self, mcp_tools: Any) -> Any:
        """Attach MCP tools, register the listener and build the workflow."""
        # Set MCP tools for all agents; if MCP is not configured this is effectively a no-op
        self.set_mcp_tools(mcp_tools)
        self._register_event_listener()
        return self.build_crewai_workflow()

    async def invoke(self, completion_create_params: CompletionCreateParams) -> InvokeReturn:
        """Run the CrewAI workflow with the provided completion parameters.

        For streaming requests an async generator is returned that yields a
        single ``(response_text, pipeline_interactions, usage_metrics)`` tuple;
        otherwise that tuple is returned directly.
        """
        user_prompt_content = extract_user_prompt_content(completion_create_params)
        # Preserve prior template startup print for CLI parity
        try:
            print("Running agent with user prompt:", user_prompt_content, flush=True)
        except Exception:
            # Printing is best-effort; proceed regardless
            pass

        if is_streaming(completion_create_params):

            async def _gen() -> AsyncGenerator[
                tuple[str, MultiTurnSample | None, UsageMetrics]
            ]:
                # The generator body only runs once the caller iterates, i.e.
                # after invoke() has returned. The MCP context must therefore be
                # opened here, otherwise it would already be closed when the crew
                # is kicked off.
                with mcp_tools_context(
                    authorization_context=self._authorization_context,
                    forwarded_headers=self.forwarded_headers,
                ) as mcp_tools:
                    crew = self._prepare_crew(mcp_tools)
                    crew_output = await asyncio.to_thread(
                        crew.kickoff,
                        inputs=self.make_kickoff_inputs(user_prompt_content),
                    )
                    result = self._process_crew_output(crew_output)
                # Yield after leaving the context so an abandoned generator
                # cannot keep the MCP connection open.
                yield result

            return _gen()

        # Use MCP context manager to handle connection lifecycle
        with mcp_tools_context(
            authorization_context=self._authorization_context,
            forwarded_headers=self.forwarded_headers,
        ) as mcp_tools:
            crew = self._prepare_crew(mcp_tools)
            # NOTE(review): kickoff runs synchronously here and blocks the event
            # loop, unlike the streaming path which uses asyncio.to_thread.
            crew_output = crew.kickoff(inputs=self.make_kickoff_inputs(user_prompt_content))
            return self._process_crew_output(crew_output)
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import importlib
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from ragas.messages import AIMessage
|
|
22
|
+
from ragas.messages import HumanMessage
|
|
23
|
+
from ragas.messages import ToolCall
|
|
24
|
+
from ragas.messages import ToolMessage
|
|
25
|
+
|
|
26
|
+
# Resolve crewai symbols at runtime to avoid mypy issues with untyped packages
|
|
27
|
+
try:
    # Preferred layout: event types, bus and base listener live under
    # ``crewai.events``.
    _events_mod = importlib.import_module("crewai.events.event_types")
    AgentExecutionCompletedEvent = getattr(_events_mod, "AgentExecutionCompletedEvent")
    AgentExecutionStartedEvent = getattr(_events_mod, "AgentExecutionStartedEvent")
    CrewKickoffStartedEvent = getattr(_events_mod, "CrewKickoffStartedEvent")
    ToolUsageFinishedEvent = getattr(_events_mod, "ToolUsageFinishedEvent")
    ToolUsageStartedEvent = getattr(_events_mod, "ToolUsageStartedEvent")

    _bus_mod = importlib.import_module("crewai.events.event_bus")
    CrewAIEventsBus = getattr(_bus_mod, "CrewAIEventsBus")

    _base_mod = importlib.import_module("crewai.events.base_event_listener")
    _RuntimeBaseEventListener = getattr(_base_mod, "BaseEventListener")
except Exception:
    # Any failure above (missing module or missing attribute) triggers the
    # fallback to the ``crewai.utilities.events`` layout.
    try:  # pragma: no cover - compatibility for older crewai
        _events_mod = importlib.import_module("crewai.utilities.events")
        AgentExecutionCompletedEvent = getattr(_events_mod, "AgentExecutionCompletedEvent")
        AgentExecutionStartedEvent = getattr(_events_mod, "AgentExecutionStartedEvent")
        CrewKickoffStartedEvent = getattr(_events_mod, "CrewKickoffStartedEvent")
        ToolUsageFinishedEvent = getattr(_events_mod, "ToolUsageFinishedEvent")
        ToolUsageStartedEvent = getattr(_events_mod, "ToolUsageStartedEvent")

        _bus_mod = importlib.import_module("crewai.utilities.events")
        CrewAIEventsBus = getattr(_bus_mod, "CrewAIEventsBus")
        _base_mod = importlib.import_module("crewai.utilities.events.base_event_listener")
        _RuntimeBaseEventListener = getattr(_base_mod, "BaseEventListener")
    except Exception:
        # Neither layout could be resolved: surface a single actionable
        # ImportError instead of the underlying failure.
        raise ImportError(
            "CrewAI is required for datarobot_genai.crewai.* modules. "
            "Install with the CrewAI extra:\n"
            "  install 'datarobot-genai[crewai]'"
        )
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class CrewAIEventListener:
    """Collects CrewAI events into Ragas messages for pipeline interactions.

    Handlers registered through ``setup_listeners`` append to ``messages`` in
    the order events arrive: a ``HumanMessage`` on crew kickoff, an
    ``AIMessage`` when an agent starts and completes, a ``ToolCall`` attached
    to the latest ``AIMessage`` when a tool starts, and a ``ToolMessage`` when
    a tool finishes.
    """

    # Module-scoped logger so warnings are attributed to this module rather
    # than emitted through the root logger.
    _log = logging.getLogger(__name__)

    def __init__(self) -> None:
        self.messages: list[HumanMessage | AIMessage | ToolMessage] = []

    def _last_ai_message(self) -> Any:
        """Return the most recent message if it is an ``AIMessage``.

        Logs a warning and returns ``None`` when no message has been recorded
        yet or the latest one is not an ``AIMessage``.
        """
        if not self.messages:
            self._log.warning("Direct tool usage without agent invocation")
            return None
        last_message = self.messages[-1]
        if not isinstance(last_message, AIMessage):
            self._log.warning(
                "Tool call must be preceded by an AIMessage somewhere in the conversation."
            )
            return None
        return last_message

    @classmethod
    def _parse_tool_args(cls, raw_args: Any) -> Any:
        """Turn tool arguments from an event into a value usable for ``ToolCall``.

        Non-text values are returned unchanged. Text is parsed as JSON; text that
        is not valid JSON, or JSON that is not an object, is wrapped as
        ``{"raw_args": ...}`` so a malformed payload cannot break the handler.
        """
        if not isinstance(raw_args, (str, bytes, bytearray)):
            return raw_args
        try:
            parsed = json.loads(raw_args)
        except ValueError:  # JSONDecodeError and UnicodeDecodeError both derive from it
            cls._log.warning("Tool arguments are not valid JSON; keeping raw value")
            if isinstance(raw_args, str):
                return {"raw_args": raw_args}
            return {"raw_args": bytes(raw_args).decode("utf-8", errors="replace")}
        return parsed if isinstance(parsed, dict) else {"raw_args": parsed}

    def setup_listeners(self, crewai_event_bus: Any) -> None:
        """Register this listener's handlers on ``crewai_event_bus``.

        ``crewai_event_bus`` must expose an ``on(event_type)`` decorator factory.
        """

        @crewai_event_bus.on(CrewKickoffStartedEvent)
        def on_crew_execution_started(_: Any, event: Any) -> None:
            # default=str keeps the handler alive when inputs hold values the
            # json module cannot serialize natively.
            self.messages.append(
                HumanMessage(content=f"Working on input '{json.dumps(event.inputs, default=str)}'")
            )

        @crewai_event_bus.on(AgentExecutionStartedEvent)
        def on_agent_execution_started(_: Any, event: Any) -> None:
            self.messages.append(AIMessage(content=event.task_prompt, tool_calls=[]))

        @crewai_event_bus.on(AgentExecutionCompletedEvent)
        def on_agent_execution_completed(_: Any, event: Any) -> None:
            self.messages.append(AIMessage(content=event.output, tool_calls=[]))

        @crewai_event_bus.on(ToolUsageStartedEvent)
        def on_tool_usage_started(_: Any, event: Any) -> None:
            # It's a tool call - add tool call to last AIMessage
            last_message = self._last_ai_message()
            if last_message is None:
                return
            tool_call = ToolCall(name=event.tool_name, args=self._parse_tool_args(event.tool_args))
            if last_message.tool_calls is None:
                last_message.tool_calls = []
            last_message.tool_calls.append(tool_call)

        @crewai_event_bus.on(ToolUsageFinishedEvent)
        def on_tool_usage_finished(_: Any, event: Any) -> None:
            last_message = self._last_ai_message()
            if last_message is None:
                return
            if not last_message.tool_calls:
                self._log.warning("No previous tool calls found")
                return
            output = event.output
            # Tool results can be arbitrary objects; store their text form.
            self.messages.append(
                ToolMessage(content=output if isinstance(output, str) else str(output))
            )
|