datarobot-genai 0.2.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. datarobot_genai/__init__.py +19 -0
  2. datarobot_genai/core/__init__.py +0 -0
  3. datarobot_genai/core/agents/__init__.py +43 -0
  4. datarobot_genai/core/agents/base.py +195 -0
  5. datarobot_genai/core/chat/__init__.py +19 -0
  6. datarobot_genai/core/chat/auth.py +146 -0
  7. datarobot_genai/core/chat/client.py +178 -0
  8. datarobot_genai/core/chat/responses.py +297 -0
  9. datarobot_genai/core/cli/__init__.py +18 -0
  10. datarobot_genai/core/cli/agent_environment.py +47 -0
  11. datarobot_genai/core/cli/agent_kernel.py +211 -0
  12. datarobot_genai/core/custom_model.py +141 -0
  13. datarobot_genai/core/mcp/__init__.py +0 -0
  14. datarobot_genai/core/mcp/common.py +218 -0
  15. datarobot_genai/core/telemetry_agent.py +126 -0
  16. datarobot_genai/core/utils/__init__.py +3 -0
  17. datarobot_genai/core/utils/auth.py +234 -0
  18. datarobot_genai/core/utils/urls.py +64 -0
  19. datarobot_genai/crewai/__init__.py +24 -0
  20. datarobot_genai/crewai/agent.py +42 -0
  21. datarobot_genai/crewai/base.py +159 -0
  22. datarobot_genai/crewai/events.py +117 -0
  23. datarobot_genai/crewai/mcp.py +59 -0
  24. datarobot_genai/drmcp/__init__.py +78 -0
  25. datarobot_genai/drmcp/core/__init__.py +13 -0
  26. datarobot_genai/drmcp/core/auth.py +165 -0
  27. datarobot_genai/drmcp/core/clients.py +180 -0
  28. datarobot_genai/drmcp/core/config.py +364 -0
  29. datarobot_genai/drmcp/core/config_utils.py +174 -0
  30. datarobot_genai/drmcp/core/constants.py +18 -0
  31. datarobot_genai/drmcp/core/credentials.py +190 -0
  32. datarobot_genai/drmcp/core/dr_mcp_server.py +350 -0
  33. datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
  34. datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
  35. datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
  36. datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +70 -0
  37. datarobot_genai/drmcp/core/dynamic_prompts/register.py +205 -0
  38. datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
  39. datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
  40. datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
  41. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
  42. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
  43. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
  44. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
  45. datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
  46. datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
  47. datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
  48. datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
  49. datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
  50. datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
  51. datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
  52. datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
  53. datarobot_genai/drmcp/core/exceptions.py +25 -0
  54. datarobot_genai/drmcp/core/logging.py +98 -0
  55. datarobot_genai/drmcp/core/mcp_instance.py +515 -0
  56. datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
  57. datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
  58. datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
  59. datarobot_genai/drmcp/core/routes.py +439 -0
  60. datarobot_genai/drmcp/core/routes_utils.py +30 -0
  61. datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
  62. datarobot_genai/drmcp/core/telemetry.py +424 -0
  63. datarobot_genai/drmcp/core/tool_config.py +111 -0
  64. datarobot_genai/drmcp/core/tool_filter.py +117 -0
  65. datarobot_genai/drmcp/core/utils.py +138 -0
  66. datarobot_genai/drmcp/server.py +19 -0
  67. datarobot_genai/drmcp/test_utils/__init__.py +13 -0
  68. datarobot_genai/drmcp/test_utils/clients/__init__.py +0 -0
  69. datarobot_genai/drmcp/test_utils/clients/anthropic.py +68 -0
  70. datarobot_genai/drmcp/test_utils/clients/base.py +300 -0
  71. datarobot_genai/drmcp/test_utils/clients/dr_gateway.py +58 -0
  72. datarobot_genai/drmcp/test_utils/clients/openai.py +68 -0
  73. datarobot_genai/drmcp/test_utils/elicitation_test_tool.py +89 -0
  74. datarobot_genai/drmcp/test_utils/integration_mcp_server.py +109 -0
  75. datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +133 -0
  76. datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +107 -0
  77. datarobot_genai/drmcp/test_utils/test_interactive.py +205 -0
  78. datarobot_genai/drmcp/test_utils/tool_base_ete.py +220 -0
  79. datarobot_genai/drmcp/test_utils/utils.py +91 -0
  80. datarobot_genai/drmcp/tools/__init__.py +14 -0
  81. datarobot_genai/drmcp/tools/clients/__init__.py +14 -0
  82. datarobot_genai/drmcp/tools/clients/atlassian.py +188 -0
  83. datarobot_genai/drmcp/tools/clients/confluence.py +584 -0
  84. datarobot_genai/drmcp/tools/clients/gdrive.py +832 -0
  85. datarobot_genai/drmcp/tools/clients/jira.py +334 -0
  86. datarobot_genai/drmcp/tools/clients/microsoft_graph.py +479 -0
  87. datarobot_genai/drmcp/tools/clients/s3.py +28 -0
  88. datarobot_genai/drmcp/tools/confluence/__init__.py +14 -0
  89. datarobot_genai/drmcp/tools/confluence/tools.py +321 -0
  90. datarobot_genai/drmcp/tools/gdrive/__init__.py +0 -0
  91. datarobot_genai/drmcp/tools/gdrive/tools.py +347 -0
  92. datarobot_genai/drmcp/tools/jira/__init__.py +14 -0
  93. datarobot_genai/drmcp/tools/jira/tools.py +243 -0
  94. datarobot_genai/drmcp/tools/microsoft_graph/__init__.py +13 -0
  95. datarobot_genai/drmcp/tools/microsoft_graph/tools.py +198 -0
  96. datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
  97. datarobot_genai/drmcp/tools/predictive/data.py +133 -0
  98. datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
  99. datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
  100. datarobot_genai/drmcp/tools/predictive/model.py +148 -0
  101. datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
  102. datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
  103. datarobot_genai/drmcp/tools/predictive/project.py +90 -0
  104. datarobot_genai/drmcp/tools/predictive/training.py +661 -0
  105. datarobot_genai/langgraph/__init__.py +0 -0
  106. datarobot_genai/langgraph/agent.py +341 -0
  107. datarobot_genai/langgraph/mcp.py +73 -0
  108. datarobot_genai/llama_index/__init__.py +16 -0
  109. datarobot_genai/llama_index/agent.py +50 -0
  110. datarobot_genai/llama_index/base.py +299 -0
  111. datarobot_genai/llama_index/mcp.py +79 -0
  112. datarobot_genai/nat/__init__.py +0 -0
  113. datarobot_genai/nat/agent.py +275 -0
  114. datarobot_genai/nat/datarobot_auth_provider.py +110 -0
  115. datarobot_genai/nat/datarobot_llm_clients.py +318 -0
  116. datarobot_genai/nat/datarobot_llm_providers.py +130 -0
  117. datarobot_genai/nat/datarobot_mcp_client.py +266 -0
  118. datarobot_genai/nat/helpers.py +87 -0
  119. datarobot_genai/py.typed +0 -0
  120. datarobot_genai-0.2.31.dist-info/METADATA +145 -0
  121. datarobot_genai-0.2.31.dist-info/RECORD +125 -0
  122. datarobot_genai-0.2.31.dist-info/WHEEL +4 -0
  123. datarobot_genai-0.2.31.dist-info/entry_points.txt +5 -0
  124. datarobot_genai-0.2.31.dist-info/licenses/AUTHORS +2 -0
  125. datarobot_genai-0.2.31.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,234 @@
1
+ # Copyright 2025 DataRobot, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import warnings
16
+ from typing import Any
17
+
18
+ import jwt
19
+ from datarobot.auth.datarobot.oauth import AsyncOAuth as DatarobotAsyncOAuthClient
20
+ from datarobot.auth.identity import Identity
21
+ from datarobot.auth.oauth import AsyncOAuthComponent
22
+ from datarobot.auth.session import AuthCtx
23
+ from datarobot.core.config import DataRobotAppFrameworkBaseSettings
24
+ from datarobot.models.genai.agent.auth import ToolAuth
25
+ from datarobot.models.genai.agent.auth import get_authorization_context
26
+ from pydantic import BaseModel
27
+
28
# Module-level logger, named after this module (``__name__``).
logger = logging.getLogger(__name__)
29
+
30
+
31
class AuthContextConfig(DataRobotAppFrameworkBaseSettings):
    """Settings holder for the secret used to sign/verify authorization-context JWTs.

    NOTE(review): the value is presumably populated from the SESSION_SECRET_KEY
    environment variable by the base settings class — confirm against
    ``DataRobotAppFrameworkBaseSettings``.
    """

    # Secret key read by AuthContextHeaderHandler when no explicit key is given;
    # empty string when not configured.
    session_secret_key: str = ""
33
+
34
+
35
class DRAppCtx(BaseModel):
    """DataRobot application context from authorization metadata."""

    # Email address from the authorization metadata, if present.
    email: str | None = None
    # API key from the authorization metadata, if present.
    api_key: str | None = None
40
+
41
+
42
class AuthContextHeaderHandler:
    """Manages encoding and decoding of authorization context into JWT tokens.

    This class provides a consistent interface for encoding auth context into JWT tokens
    and exchanging them via HTTP headers across multiple applications.
    """

    HEADER_NAME = "X-DataRobot-Authorization-Context"
    DEFAULT_ALGORITHM = "HS256"

    def __init__(
        self,
        secret_key: str | None = None,
        algorithm: str = DEFAULT_ALGORITHM,
        validate_signature: bool = True,
    ) -> None:
        """Initialize the handler.

        Parameters
        ----------
        secret_key : Optional[str]
            Secret key for JWT encoding/decoding. When None or empty, the key is read
            from ``AuthContextConfig`` and, if that fails, from the ``SESSION_SECRET_KEY``
            environment variable. If no key is found, tokens are signed with an empty
            key (insecure; ``encode`` warns about it).
        algorithm : str
            JWT algorithm. Default is "HS256".
        validate_signature : bool
            Whether to validate JWT signatures. Default is True.

        Raises
        ------
        ValueError
            If algorithm is None or 'none' (any case), i.e. unsigned tokens.
        """
        # Reject the unsigned "none" JWT algorithm: it would let anyone forge a context.
        if algorithm is None or str(algorithm).strip().lower() == "none":
            raise ValueError("Algorithm None is not allowed. Use a secure algorithm like HS256.")

        self.secret_key = secret_key if secret_key else self._load_default_secret_key()
        self.algorithm = algorithm
        self.validate_signature = validate_signature

    @staticmethod
    def _load_default_secret_key() -> str:
        """Return the session secret key from settings, falling back to the environment.

        ``AuthContextConfig()`` can fail inside the datarobot package when
        SESSION_SECRET_KEY is not set (it checks ``"apiToken" in payload`` while
        ``payload`` is None); in that case the environment variable is read directly.
        """
        import os

        try:
            return AuthContextConfig().session_secret_key or ""
        except Exception:
            logger.debug(
                "AuthContextConfig initialization failed; reading SESSION_SECRET_KEY directly.",
                exc_info=True,
            )
            return os.environ.get("SESSION_SECRET_KEY", "")

    @property
    def header(self) -> str:
        """Get the header name for authorization context."""
        return self.HEADER_NAME

    def get_header(self, authorization_context: dict[str, Any] | None = None) -> dict[str, str]:
        """Get the authorization context header with encoded JWT token.

        Returns an empty dict when there is no context to encode.
        """
        token = self.encode(authorization_context)
        if not token:
            return {}

        return {self.header: token}

    def encode(self, authorization_context: dict[str, Any] | None = None) -> str | None:
        """Encode an authorization context into a JWT token.

        Parameters
        ----------
        authorization_context : Optional[dict[str, Any]]
            Context to encode. When falsy, ``get_authorization_context()`` is used.

        Returns
        -------
        Optional[str]
            The encoded token, or None when there is no context to encode.
        """
        auth_context = authorization_context or get_authorization_context()
        if not auth_context:
            return None

        if not self.secret_key:
            warnings.warn(
                "No secret key provided. Please make sure SESSION_SECRET_KEY is set. "
                "JWT tokens will be signed with an empty key. This is insecure and should "
                "only be used for testing.",
                stacklevel=2,
            )

        return jwt.encode(auth_context, self.secret_key, algorithm=self.algorithm)

    def decode(self, token: str) -> dict[str, Any] | None:
        """Decode a JWT token into an authorization-context dict.

        Returns None (after logging) when the token is empty, when signature
        validation is requested but no secret key is configured, when the token is
        expired, invalid or malformed, or when its payload is not a dict.
        """
        if not token:
            return None

        if not self.secret_key and self.validate_signature:
            logger.error(
                "No secret key provided. Cannot validate signature. "
                "Provide a secret key or set validate_signature to False."
            )
            return None

        try:
            decoded = jwt.decode(
                jwt=token,
                key=self.secret_key,
                algorithms=[self.algorithm],
                options={"verify_signature": self.validate_signature},
            )
        except jwt.ExpiredSignatureError:
            logger.info("JWT token has expired.")
            return None
        except jwt.InvalidTokenError:
            logger.warning("JWT token is invalid or malformed.")
            return None

        if not isinstance(decoded, dict):
            logger.warning("Decoded JWT token is not a dictionary.")
            return None

        return decoded

    def _lookup_header(self, headers: dict[str, str]) -> str | None:
        """Return the authorization-context header value, matching the name case-insensitively."""
        token = headers.get(self.header) or headers.get(self.header.lower())
        if token:
            return token
        # HTTP header names are case-insensitive but plain dicts are not, so fall back
        # to scanning for other spellings (e.g. "X-Datarobot-Authorization-Context").
        items = getattr(headers, "items", None)
        if not callable(items):
            return None
        wanted = self.header.lower()
        for name, value in items():
            if isinstance(name, str) and name.lower() == wanted and value:
                return value
        return None

    def get_context(self, headers: dict[str, str]) -> AuthCtx | None:
        """Extract and validate authorization context from headers.

        Parameters
        ----------
        headers : Dict[str, str]
            HTTP headers containing the authorization context. The header name is
            matched case-insensitively.

        Returns
        -------
        Optional[AuthCtx]
            Validated authorization context or None if validation fails.
        """
        token = self._lookup_header(headers)
        if not token:
            logger.debug("No authorization context header found")
            return None

        auth_ctx_dict = self.decode(token)
        if not auth_ctx_dict:
            logger.debug("Failed to decode auth context from token")
            return None

        try:
            return AuthCtx(**auth_ctx_dict)
        except Exception as e:
            logger.error("Failed to create AuthCtx from decoded token: %s", e, exc_info=True)
            return None
184
+
185
+
186
class AsyncOAuthTokenProvider:
    """Manages OAuth access tokens using generic OAuth client.

    Access tokens are obtained for an identity from the given ``AuthCtx`` by
    refreshing it through the OAuth client created in ``_create_oauth_client``.
    """

    def __init__(self, auth_ctx: AuthCtx) -> None:
        self.auth_ctx = auth_ctx
        self.oauth_client = self._create_oauth_client()

    def _get_identity(self, provider_type: str | None) -> Identity:
        """Retrieve the appropriate identity from the authentication context.

        Only identities with a non-None ``provider_identity_id`` are considered.

        Raises
        ------
        ValueError
            If no usable identity exists, if ``provider_type`` is None while several
            identities exist, or if no identity matches ``provider_type``.
        """
        identities = [
            candidate
            for candidate in (self.auth_ctx.identities or [])
            if candidate.provider_identity_id is not None
        ]

        if not identities:
            raise ValueError("No identities found in authorization context.")

        if provider_type is None:
            if len(identities) > 1:
                raise ValueError(
                    "Multiple identities found. Please specify 'provider_type' parameter."
                )
            return identities[0]

        identity = next(
            (candidate for candidate in identities if candidate.provider_type == provider_type),
            None,
        )

        if identity is None:
            raise ValueError(f"No identity found for provider '{provider_type}'.")

        return identity

    async def get_token(self, auth_type: ToolAuth, provider_type: str | None = None) -> str:
        """Get an OAuth access token for the selected identity.

        Raises
        ------
        ValueError
            If ``auth_type`` is not ``ToolAuth.OBO`` or no matching identity exists.
        """
        if auth_type != ToolAuth.OBO:
            raise ValueError(
                f"Unsupported auth type: {auth_type}. Only {ToolAuth.OBO} is supported."
            )

        identity = self._get_identity(provider_type)
        token_data = await self.oauth_client.refresh_access_token(
            identity_id=identity.provider_identity_id
        )
        return token_data.access_token

    def _create_oauth_client(self) -> AsyncOAuthComponent:
        """Create either DataRobot or Authlib OAuth client based on
        authorization context.

        Note: at the moment, only DataRobot OAuth client is supported.
        """
        logger.debug("Using DataRobot OAuth client")
        return DatarobotAsyncOAuthClient()
@@ -0,0 +1,64 @@
1
+ # Copyright 2025 DataRobot, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from urllib.parse import urlparse
17
+ from urllib.parse import urlunparse
18
+
19
+
20
+ def get_api_base(api_base: str, deployment_id: str | None) -> str:
21
+ """
22
+ Construct the LiteLLM API base URL for a deployment.
23
+
24
+ Parameters
25
+ ----------
26
+ api_base : str
27
+ Base URL for the LiteLLM API.
28
+ deployment_id : str | None
29
+ Deployment identifier. When provided, a chat/completions URL is produced.
30
+
31
+ Returns
32
+ -------
33
+ str
34
+ Normalized URL for the given deployment. Ensures a trailing slash unless the path
35
+ ends with "chat/completions" or already has a meaningful path component.
36
+ """
37
+ # Normalize the URL and drop a trailing /api/v2 if present
38
+ parsed = urlparse(api_base)
39
+ path = re.sub(r"/api/v2/?$", "", parsed.path)
40
+ base_url = urlunparse(
41
+ (
42
+ parsed.scheme,
43
+ parsed.netloc,
44
+ path,
45
+ parsed.params,
46
+ parsed.query,
47
+ parsed.fragment,
48
+ )
49
+ )
50
+ base_url = base_url.rstrip("/")
51
+
52
+ # If the base_url already ends with chat/completions, return it.
53
+ if base_url.endswith("chat/completions"):
54
+ return base_url
55
+
56
+ # If the path contains deployments or genai, it's already a complete API path preserve it.
57
+ if path and ("deployments" in path or "genai" in path):
58
+ return f"{base_url}/" if not base_url.endswith("/") else base_url
59
+
60
+ # For all other cases (including custom base paths), apply deployment logic if needed.
61
+ if deployment_id:
62
+ return f"{base_url}/api/v2/deployments/{deployment_id}/chat/completions"
63
+ # Otherwise, just return the base URL with a trailing slash for normalization.
64
+ return f"{base_url}/"
@@ -0,0 +1,24 @@
1
+ """CrewAI utilities and helpers.
2
+
3
+ Public API:
4
+ - mcp_tools_context: Context manager returning available MCP tools for CrewAI.
5
+ - build_llm: Construct a CrewAI LLM configured for DataRobot endpoints.
6
+ - create_pipeline_interactions_from_messages: Convert messages to MultiTurnSample.
7
+ """
8
+
9
+ from datarobot_genai.core.mcp.common import MCPConfig
10
+
11
+ from .agent import build_llm
12
+ from .agent import create_pipeline_interactions_from_messages
13
+ from .base import CrewAIAgent
14
+ from .events import CrewAIEventListener
15
+ from .mcp import mcp_tools_context
16
+
17
# Names exported by ``from datarobot_genai.crewai import *``; keep in sync with the imports above.
__all__ = [
    "mcp_tools_context",
    "CrewAIAgent",
    "build_llm",
    "create_pipeline_interactions_from_messages",
    "CrewAIEventListener",
    "MCPConfig",
]
@@ -0,0 +1,42 @@
1
+ # Copyright 2025 DataRobot, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from crewai import LLM
16
+ from ragas import MultiTurnSample
17
+ from ragas.messages import AIMessage
18
+ from ragas.messages import HumanMessage
19
+ from ragas.messages import ToolMessage
20
+
21
+ from datarobot_genai.core.utils.urls import get_api_base
22
+
23
+
24
def build_llm(
    *,
    api_base: str,
    api_key: str | None,
    model: str,
    deployment_id: str | None,
    timeout: int,
) -> LLM:
    """Build a CrewAI ``LLM`` pointed at a DataRobot LLM Gateway or deployment.

    ``api_base`` is normalized with ``get_api_base`` (which targets the deployment's
    chat/completions route when ``deployment_id`` is given and ``api_base`` is not
    already a complete API path); the remaining arguments are passed through.
    """
    llm_options = {
        "model": model,
        "api_base": get_api_base(api_base, deployment_id),
        "api_key": api_key,
        "timeout": timeout,
    }
    return LLM(**llm_options)
35
+
36
+
37
def create_pipeline_interactions_from_messages(
    messages: list[HumanMessage | AIMessage | ToolMessage] | None,
) -> MultiTurnSample | None:
    """Wrap collected Ragas messages in a ``MultiTurnSample``.

    Returns None when ``messages`` is None or empty.
    """
    return MultiTurnSample(user_input=messages) if messages else None
@@ -0,0 +1,159 @@
1
+ # Copyright 2025 DataRobot, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Base class for CrewAI-based agents.
17
+
18
+ Manages MCP tool lifecycle and standardizes kickoff flow.
19
+
20
+ Note: This base does not capture pipeline interactions; it returns None by
21
+ default. Subclasses may implement message capture if they need interactions.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import abc
27
+ import asyncio
28
+ from collections.abc import AsyncGenerator
29
+ from typing import Any
30
+
31
+ from crewai import Crew
32
+ from crewai.events.event_bus import CrewAIEventsBus
33
+ from crewai.tools import BaseTool
34
+ from openai.types.chat import CompletionCreateParams
35
+ from ragas import MultiTurnSample
36
+
37
+ from datarobot_genai.core.agents.base import BaseAgent
38
+ from datarobot_genai.core.agents.base import InvokeReturn
39
+ from datarobot_genai.core.agents.base import UsageMetrics
40
+ from datarobot_genai.core.agents.base import default_usage_metrics
41
+ from datarobot_genai.core.agents.base import extract_user_prompt_content
42
+ from datarobot_genai.core.agents.base import is_streaming
43
+
44
+ from .agent import create_pipeline_interactions_from_messages
45
+ from .mcp import mcp_tools_context
46
+
47
+
48
class CrewAIAgent(BaseAgent[BaseTool], abc.ABC):
    """Abstract base agent for CrewAI workflows.

    Subclasses should define the ``agents`` and ``tasks`` properties
    and may override ``build_crewai_workflow`` to customize the workflow
    construction. A subclass may also expose an ``event_listener`` attribute with a
    ``setup_listeners(bus)`` method and a ``messages`` list; it is registered before
    kickoff and its messages become the pipeline interactions.
    """

    @property
    @abc.abstractmethod
    def agents(self) -> list[Any]:  # CrewAI Agent list
        """CrewAI agents that make up the crew."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def tasks(self) -> list[Any]:  # CrewAI Task list
        """CrewAI tasks executed by the crew."""
        raise NotImplementedError

    def build_crewai_workflow(self) -> Any:
        """Create a CrewAI workflow instance.

        Default implementation constructs a Crew with provided agents and tasks.
        Subclasses can override to customize Crew options.
        """
        return Crew(agents=self.agents, tasks=self.tasks, verbose=self.verbose)

    @abc.abstractmethod
    def make_kickoff_inputs(self, user_prompt_content: str) -> dict[str, Any]:
        """Build the inputs dict for ``Crew.kickoff``.

        Subclasses must implement this to provide the exact inputs required
        by their CrewAI tasks.
        """
        raise NotImplementedError

    def _extract_pipeline_interactions(self) -> MultiTurnSample | None:
        """Extract pipeline interactions from event listener if available."""
        if not hasattr(self, "event_listener"):
            return None
        try:
            listener = getattr(self, "event_listener", None)
            messages = getattr(listener, "messages", None) if listener is not None else None
            return create_pipeline_interactions_from_messages(messages)
        except Exception:
            # Interactions are optional; never fail the response because of them.
            return None

    def _extract_usage_metrics(self, crew_output: Any) -> UsageMetrics:
        """Extract token usage from ``crew_output.token_usage``; missing counts become 0."""
        token_usage = getattr(crew_output, "token_usage", None)
        if token_usage is None:
            return default_usage_metrics()
        # ``or 0`` also covers counters that are present but None, which int() rejects.
        return {
            "completion_tokens": int(getattr(token_usage, "completion_tokens", 0) or 0),
            "prompt_tokens": int(getattr(token_usage, "prompt_tokens", 0) or 0),
            "total_tokens": int(getattr(token_usage, "total_tokens", 0) or 0),
        }

    def _process_crew_output(
        self, crew_output: Any
    ) -> tuple[str, MultiTurnSample | None, UsageMetrics]:
        """Process crew output into response tuple."""
        response_text = str(crew_output.raw)
        pipeline_interactions = self._extract_pipeline_interactions()
        usage_metrics = self._extract_usage_metrics(crew_output)
        return response_text, pipeline_interactions, usage_metrics

    def _mcp_tools_context(self) -> Any:
        """Return the MCP tools context manager for this agent's auth context and headers."""
        return mcp_tools_context(
            authorization_context=self._authorization_context,
            forwarded_headers=self.forwarded_headers,
        )

    def _register_event_listener(self) -> None:
        """Register ``self.event_listener`` (if any) on the CrewAI event bus, best-effort."""
        import logging

        if not hasattr(self, "event_listener") or CrewAIEventsBus is None:
            return
        try:
            listener = getattr(self, "event_listener")
            setup_fn = getattr(listener, "setup_listeners", None)
            if callable(setup_fn):
                # NOTE(review): this passes the CrewAIEventsBus class, not an instance;
                # if ``on`` is an instance method in the installed crewai version,
                # registration fails here and only shows up in the debug log below.
                setup_fn(CrewAIEventsBus)
        except Exception:
            # Listener is optional best-effort; proceed without failing invoke.
            logging.getLogger(__name__).debug(
                "Failed to register CrewAI event listener", exc_info=True
            )

    def _prepare_crew_with_tools(self, mcp_tools: Any) -> Any:
        """Attach MCP tools, register the optional event listener, and build the crew."""
        # If MCP is not configured this is effectively a no-op.
        self.set_mcp_tools(mcp_tools)
        self._register_event_listener()
        return self.build_crewai_workflow()

    async def invoke(self, completion_create_params: CompletionCreateParams) -> InvokeReturn:
        """Run the CrewAI workflow with the provided completion parameters.

        Streaming requests return an async generator yielding one response tuple;
        other requests return the response tuple directly.
        """
        user_prompt_content = extract_user_prompt_content(completion_create_params)
        # Preserve prior template startup print for CLI parity
        try:
            print("Running agent with user prompt:", user_prompt_content, flush=True)
        except Exception:
            # Printing is best-effort; proceed regardless
            pass

        if is_streaming(completion_create_params):

            async def _gen() -> AsyncGenerator[
                tuple[str, MultiTurnSample | None, UsageMetrics]
            ]:
                # The generator body only runs once the caller iterates it, i.e. after
                # ``invoke`` has returned, so the MCP connection must be opened here:
                # a context opened in ``invoke`` would already be closed by kickoff.
                with self._mcp_tools_context() as mcp_tools:
                    crew = self._prepare_crew_with_tools(mcp_tools)
                    crew_output = await asyncio.to_thread(
                        crew.kickoff,
                        inputs=self.make_kickoff_inputs(user_prompt_content),
                    )
                yield self._process_crew_output(crew_output)

            return _gen()

        # Use MCP context manager to handle connection lifecycle
        with self._mcp_tools_context() as mcp_tools:
            crew = self._prepare_crew_with_tools(mcp_tools)
            crew_output = crew.kickoff(inputs=self.make_kickoff_inputs(user_prompt_content))
            return self._process_crew_output(crew_output)
@@ -0,0 +1,117 @@
1
+ # Copyright 2025 DataRobot, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from __future__ import annotations
15
+
16
+ import importlib
17
+ import json
18
+ import logging
19
+ from typing import Any
20
+
21
+ from ragas.messages import AIMessage
22
+ from ragas.messages import HumanMessage
23
+ from ragas.messages import ToolCall
24
+ from ragas.messages import ToolMessage
25
+
26
# Resolve crewai symbols at runtime to avoid mypy issues with untyped packages.
# Newer crewai exposes them under ``crewai.events``; older releases under
# ``crewai.utilities.events``.
try:
    _events_mod = importlib.import_module("crewai.events.event_types")
    AgentExecutionCompletedEvent = getattr(_events_mod, "AgentExecutionCompletedEvent")
    AgentExecutionStartedEvent = getattr(_events_mod, "AgentExecutionStartedEvent")
    CrewKickoffStartedEvent = getattr(_events_mod, "CrewKickoffStartedEvent")
    ToolUsageFinishedEvent = getattr(_events_mod, "ToolUsageFinishedEvent")
    ToolUsageStartedEvent = getattr(_events_mod, "ToolUsageStartedEvent")

    _bus_mod = importlib.import_module("crewai.events.event_bus")
    CrewAIEventsBus = getattr(_bus_mod, "CrewAIEventsBus")

    _base_mod = importlib.import_module("crewai.events.base_event_listener")
    _RuntimeBaseEventListener = getattr(_base_mod, "BaseEventListener")
except Exception:
    try:  # pragma: no cover - compatibility for older crewai
        _events_mod = importlib.import_module("crewai.utilities.events")
        AgentExecutionCompletedEvent = getattr(_events_mod, "AgentExecutionCompletedEvent")
        AgentExecutionStartedEvent = getattr(_events_mod, "AgentExecutionStartedEvent")
        CrewKickoffStartedEvent = getattr(_events_mod, "CrewKickoffStartedEvent")
        ToolUsageFinishedEvent = getattr(_events_mod, "ToolUsageFinishedEvent")
        ToolUsageStartedEvent = getattr(_events_mod, "ToolUsageStartedEvent")

        _bus_mod = importlib.import_module("crewai.utilities.events")
        CrewAIEventsBus = getattr(_bus_mod, "CrewAIEventsBus")
        _base_mod = importlib.import_module("crewai.utilities.events.base_event_listener")
        _RuntimeBaseEventListener = getattr(_base_mod, "BaseEventListener")
    except Exception as exc:
        # Chain the underlying error so the real cause (missing package vs. API change)
        # stays visible in the traceback.
        raise ImportError(
            "CrewAI is required for datarobot_genai.crewai.* modules. "
            "Install with the CrewAI extra:\n"
            "  pip install 'datarobot-genai[crewai]'"
        ) from exc
59
+
60
+
61
class CrewAIEventListener:
    """Collects CrewAI events into Ragas messages for pipeline interactions.

    ``setup_listeners`` registers handlers on a CrewAI event bus; each handler appends
    to (or annotates the last entry of) ``self.messages``.
    """

    # Module-named logger instead of the root logger, so these warnings can be filtered.
    _logger = logging.getLogger(__name__)

    def __init__(self) -> None:
        # Conversation reconstructed from events, in arrival order.
        self.messages: list[HumanMessage | AIMessage | ToolMessage] = []

    def _last_ai_message(self) -> AIMessage | None:
        """Return the latest message if it is an ``AIMessage``; otherwise warn and return None."""
        if not self.messages:
            self._logger.warning("Direct tool usage without agent invocation")
            return None
        last_message = self.messages[-1]
        if not isinstance(last_message, AIMessage):
            self._logger.warning(
                "Tool call must be preceded by an AIMessage somewhere in the conversation."
            )
            return None
        return last_message

    def setup_listeners(self, crewai_event_bus: Any) -> None:
        """Register the message-collecting handlers on ``crewai_event_bus``."""

        @crewai_event_bus.on(CrewKickoffStartedEvent)
        def on_crew_execution_started(_: Any, event: Any) -> None:
            self.messages.append(
                HumanMessage(content=f"Working on input '{json.dumps(event.inputs)}'")
            )

        @crewai_event_bus.on(AgentExecutionStartedEvent)
        def on_agent_execution_started(_: Any, event: Any) -> None:
            self.messages.append(AIMessage(content=event.task_prompt, tool_calls=[]))

        @crewai_event_bus.on(AgentExecutionCompletedEvent)
        def on_agent_execution_completed(_: Any, event: Any) -> None:
            self.messages.append(AIMessage(content=event.output, tool_calls=[]))

        @crewai_event_bus.on(ToolUsageStartedEvent)
        def on_tool_usage_started(_: Any, event: Any) -> None:
            # It's a tool call - add tool call to the immediately preceding AIMessage.
            last_message = self._last_ai_message()
            if last_message is None:
                return
            parsed_args: Any = event.tool_args
            if isinstance(parsed_args, (str, bytes, bytearray)):
                try:
                    parsed_args = json.loads(parsed_args)
                except ValueError:
                    # Non-JSON arguments: skip this call rather than raising inside
                    # the event bus handler.
                    self._logger.warning(
                        "Could not parse arguments of tool '%s' as JSON; tool call not recorded.",
                        event.tool_name,
                    )
                    return
            tool_call = ToolCall(name=event.tool_name, args=parsed_args)
            if last_message.tool_calls is None:
                last_message.tool_calls = []
            last_message.tool_calls.append(tool_call)

        @crewai_event_bus.on(ToolUsageFinishedEvent)
        def on_tool_usage_finished(_: Any, event: Any) -> None:
            last_message = self._last_ai_message()
            if last_message is None:
                return
            if not last_message.tool_calls:
                self._logger.warning("No previous tool calls found")
                return
            output = event.output
            # Ragas messages carry string content; tools may return other objects.
            content = output if isinstance(output, str) else str(output)
            self.messages.append(ToolMessage(content=content))