datarobot-genai 0.2.11__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- datarobot_genai/drmcp/core/utils.py +7 -0
- datarobot_genai/drmcp/test_utils/elicitation_test_tool.py +89 -0
- datarobot_genai/drmcp/test_utils/integration_mcp_server.py +7 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +9 -1
- datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +17 -4
- datarobot_genai/drmcp/test_utils/openai_llm_mcp_client.py +71 -8
- datarobot_genai/drmcp/test_utils/test_interactive.py +205 -0
- datarobot_genai/drmcp/test_utils/tool_base_ete.py +22 -20
- datarobot_genai/drmcp/tools/clients/confluence.py +201 -4
- datarobot_genai/drmcp/tools/clients/jira.py +193 -6
- datarobot_genai/drmcp/tools/confluence/tools.py +109 -2
- datarobot_genai/drmcp/tools/jira/tools.py +192 -1
- datarobot_genai/drmcp/tools/predictive/data.py +60 -32
- datarobot_genai/nat/agent.py +20 -7
- datarobot_genai/nat/datarobot_llm_clients.py +45 -12
- datarobot_genai/nat/helpers.py +87 -0
- {datarobot_genai-0.2.11.dist-info → datarobot_genai-0.2.19.dist-info}/METADATA +1 -1
- {datarobot_genai-0.2.11.dist-info → datarobot_genai-0.2.19.dist-info}/RECORD +22 -19
- {datarobot_genai-0.2.11.dist-info → datarobot_genai-0.2.19.dist-info}/WHEEL +0 -0
- {datarobot_genai-0.2.11.dist-info → datarobot_genai-0.2.19.dist-info}/entry_points.txt +0 -0
- {datarobot_genai-0.2.11.dist-info → datarobot_genai-0.2.19.dist-info}/licenses/AUTHORS +0 -0
- {datarobot_genai-0.2.11.dist-info → datarobot_genai-0.2.19.dist-info}/licenses/LICENSE +0 -0
datarobot_genai/drmcp/tools/jira/tools.py CHANGED

@@ -14,6 +14,7 @@
 
 import logging
 from typing import Annotated
+from typing import Any
 
 from fastmcp.exceptions import ToolError
 from fastmcp.tools.tool import ToolResult

@@ -25,6 +26,40 @@ from datarobot_genai.drmcp.tools.clients.jira import JiraClient
 logger = logging.getLogger(__name__)
 
 
+@dr_mcp_tool(tags={"jira", "search", "issues"})
+async def jira_search_issues(
+    *,
+    jql_query: Annotated[
+        str, "The JQL (Jira Query Language) string used to filter and search for issues."
+    ],
+    max_results: Annotated[int, "Maximum number of issues to return. Default is 50."] = 50,
+) -> ToolResult:
+    """
+    Search for Jira issues using a powerful JQL query string.
+
+    Refer to JQL documentation for advanced query construction:
+    JQL functions: https://support.atlassian.com/jira-service-management-cloud/docs/jql-functions/
+    JQL fields: https://support.atlassian.com/jira-service-management-cloud/docs/jql-fields/
+    JQL keywords: https://support.atlassian.com/jira-service-management-cloud/docs/use-advanced-search-with-jira-query-language-jql/
+    JQL operators: https://support.atlassian.com/jira-service-management-cloud/docs/jql-operators/
+    """
+    if not jql_query:
+        raise ToolError("Argument validation error: 'jql_query' cannot be empty.")
+
+    access_token = await get_atlassian_access_token()
+    if isinstance(access_token, ToolError):
+        raise access_token
+
+    async with JiraClient(access_token) as client:
+        issues = await client.search_jira_issues(jql_query=jql_query, max_results=max_results)
+
+    n = len(issues)
+    return ToolResult(
+        content=f"Successfully executed JQL query and retrieved {n} issue(s).",
+        structured_content={"data": [issue.as_flat_dict() for issue in issues], "count": n},
+    )
+
+
 @dr_mcp_tool(tags={"jira", "read", "get", "issue"})
 async def jira_get_issue(
     *, issue_key: Annotated[str, "The key (ID) of the Jira issue to retrieve, e.g., 'PROJ-123'."]

@@ -41,7 +76,7 @@ async def jira_get_issue(
         async with JiraClient(access_token) as client:
             issue = await client.get_jira_issue(issue_key)
     except Exception as e:
-        logger.error(f"Unexpected error getting Jira issue: {e}")
+        logger.error(f"Unexpected error while getting Jira issue: {e}")
         raise ToolError(
             f"An unexpected error occurred while getting Jira issue '{issue_key}': {str(e)}"
         )

@@ -50,3 +85,159 @@ async def jira_get_issue(
         content=f"Successfully retrieved details for issue '{issue_key}'.",
         structured_content=issue.as_flat_dict(),
     )
+
+
+@dr_mcp_tool(tags={"jira", "create", "add", "issue"})
+async def jira_create_issue(
+    *,
+    project_key: Annotated[str, "The key of the project where the issue should be created."],
+    summary: Annotated[str, "A brief summary or title for the new issue."],
+    issue_type: Annotated[str, "The type of issue to create (e.g., 'Task', 'Bug', 'Story')."],
+    description: Annotated[str | None, "Detailed description of the issue."] = None,
+) -> ToolResult:
+    """Create a new Jira issue with mandatory project, summary, and type information."""
+    if not all([project_key, summary, issue_type]):
+        raise ToolError(
+            "Argument validation error: project_key, summary, and issue_type are required fields."
+        )
+
+    access_token = await get_atlassian_access_token()
+    if isinstance(access_token, ToolError):
+        raise access_token
+
+    async with JiraClient(access_token) as client:
+        # Maybe we should cache it somehow?
+        # It'll be probably constant through whole mcp server lifecycle...
+        issue_types = await client.get_jira_issue_types(project_key=project_key)
+
+    try:
+        issue_type_id = issue_types[issue_type]
+    except KeyError:
+        possible_issue_types = ",".join(issue_types)
+        raise ToolError(
+            f"Unexpected issue type `{issue_type}`. Possible values are {possible_issue_types}."
+        )
+
+    try:
+        async with JiraClient(access_token) as client:
+            issue_key = await client.create_jira_issue(
+                project_key=project_key,
+                summary=summary,
+                issue_type_id=issue_type_id,
+                description=description,
+            )
+    except Exception as e:
+        logger.error(f"Unexpected error while creating Jira issue: {e}")
+        raise ToolError(f"An unexpected error occurred while creating Jira issue: {str(e)}")
+
+    return ToolResult(
+        content=f"Successfully created issue '{issue_key}'.",
+        structured_content={"newIssueKey": issue_key, "projectKey": project_key},
+    )
+
+
+@dr_mcp_tool(tags={"jira", "update", "edit", "issue"})
+async def jira_update_issue(
+    *,
+    issue_key: Annotated[str, "The key (ID) of the Jira issue to retrieve, e.g., 'PROJ-123'."],
+    fields_to_update: Annotated[
+        dict[str, Any],
+        "A dictionary of field names and their new values (e.g., {'summary': 'New content'}).",
+    ],
+) -> ToolResult:
+    """
+    Modify descriptive fields or custom fields on an existing Jira issue using its key.
+    If you want to update issue status you should use `jira_transition_issue` tool instead.
+
+    Some fields needs very specific schema to allow update.
+    You should follow jira rest api guidance.
+    Good example is description field:
+    "description": {
+        "type": "text",
+        "version": 1,
+        "text": [
+            {
+                "type": "paragraph",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "[HERE YOU PUT REAL DESCRIPTION]"
+                    }
+                ]
+            }
+        ]
+    }
+    """
+    if not issue_key:
+        raise ToolError("Argument validation error: 'issue_key' cannot be empty.")
+    if not fields_to_update or not isinstance(fields_to_update, dict):
+        raise ToolError(
+            "Argument validation error: 'fields_to_update' must be a non-empty dictionary."
+        )
+
+    access_token = await get_atlassian_access_token()
+    if isinstance(access_token, ToolError):
+        raise access_token
+
+    try:
+        async with JiraClient(access_token) as client:
+            updated_fields = await client.update_jira_issue(
+                issue_key=issue_key, fields=fields_to_update
+            )
+    except Exception as e:
+        logger.error(f"Unexpected error while updating Jira issue: {e}")
+        raise ToolError(f"An unexpected error occurred while updating Jira issue: {str(e)}")
+
+    updated_fields_str = ",".join(updated_fields)
+    return ToolResult(
+        content=f"Successfully updated issue '{issue_key}'. Fields modified: {updated_fields_str}.",
+        structured_content={"updatedIssueKey": issue_key, "fields": updated_fields},
+    )
+
+
+@dr_mcp_tool(tags={"jira", "update", "transition", "issue"})
+async def jira_transition_issue(
+    *,
+    issue_key: Annotated[str, "The key (ID) of the Jira issue to transition, e.g. 'PROJ-123'."],
+    transition_name: Annotated[
+        str, "The exact name of the target status/transition (e.g., 'In Progress')."
+    ],
+) -> ToolResult:
+    """
+    Move a Jira issue through its defined workflow to a new status.
+    This leverages Jira's workflow engine directly.
+    """
+    if not all([issue_key, transition_name]):
+        raise ToolError("Argument validation error: issue_key and transition name/ID are required.")
+
+    access_token = await get_atlassian_access_token()
+    if isinstance(access_token, ToolError):
+        raise access_token
+
+    async with JiraClient(access_token) as client:
+        available_transitions = await client.get_available_jira_transitions(issue_key=issue_key)
+
+    try:
+        transition_id = available_transitions[transition_name]
+    except KeyError:
+        available_transitions_str = ",".join(available_transitions)
+        raise ToolError(
+            f"Unexpected transition name `{transition_name}`. "
+            f"Possible values are {available_transitions_str}."
+        )
+
+    try:
+        async with JiraClient(access_token) as client:
+            await client.transition_jira_issue(issue_key=issue_key, transition_id=transition_id)
+    except Exception as e:
+        logger.error(f"Unexpected error while transitioning Jira issue: {e}")
+        raise ToolError(f"An unexpected error occurred while transitioning Jira issue: {str(e)}")
+
+    return ToolResult(
+        content=f"Successfully transitioned issue '{issue_key}' to status '{transition_name}'.",
+        structured_content={
+            "transitionedIssueKey": issue_key,
+            "newStatusName": transition_name,
+            "newStatusId": transition_id,
+        },
+    )
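For orientation, a minimal sketch (not part of the diff) of driving the new tools over MCP with fastmcp's in-memory client. The server export name `dr_mcp` is an assumption about what `datarobot_genai.drmcp.core.mcp_instance` actually exports, and the project/issue keys are placeholders; Atlassian OAuth must already be configured so that `get_atlassian_access_token()` succeeds inside the tools.

import asyncio

from fastmcp import Client

from datarobot_genai.drmcp.core.mcp_instance import dr_mcp  # hypothetical export name


async def main() -> None:
    # In-memory transport: the client talks to the server object directly.
    async with Client(dr_mcp) as client:
        # The jql_query string is forwarded verbatim to Jira's search API.
        found = await client.call_tool(
            "jira_search_issues",
            {"jql_query": "project = PROJ AND status = 'In Progress'", "max_results": 10},
        )
        print(found)

        # Transition names must match the issue's workflow exactly.
        moved = await client.call_tool(
            "jira_transition_issue",
            {"issue_key": "PROJ-123", "transition_name": "Done"},
        )
        print(moved)


asyncio.run(main())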
datarobot_genai/drmcp/tools/predictive/data.py CHANGED

@@ -14,51 +14,79 @@
 
 import logging
 import os
+from typing import Annotated
+
+from fastmcp.exceptions import ToolError
+from fastmcp.tools.tool import ToolResult
 
 from datarobot_genai.drmcp.core.clients import get_sdk_client
 from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool
+from datarobot_genai.drmcp.core.utils import is_valid_url
 
 logger = logging.getLogger(__name__)
 
 
-@dr_mcp_tool(tags={"data", "
-async def upload_dataset_to_ai_catalog(
-    ""
-
-
-
+@dr_mcp_tool(tags={"predictive", "data", "write", "upload", "catalog"})
+async def upload_dataset_to_ai_catalog(
+    file_path: Annotated[str, "The path to the dataset file to upload."] | None = None,
+    file_url: Annotated[str, "The URL to the dataset file to upload."] | None = None,
+) -> ToolError | ToolResult:
+    """Upload a dataset to the DataRobot AI Catalog / Data Registry."""
+    if not file_path and not file_url:
+        return ToolError("Either file_path or file_url must be provided.")
+    if file_path and file_url:
+        return ToolError("Please provide either file_path or file_url, not both.")
 
-
-    -------
-    A string summary of the upload result.
-    """
+    # Get client
     client = get_sdk_client()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    catalog_item = None
+    # If file path is provided, create dataset from file.
+    if file_path:
+        # Does file exist?
+        if not os.path.exists(file_path):
+            logger.error("File not found: %s", file_path)
+            return ToolError(f"File not found: {file_path}")
+        catalog_item = client.Dataset.create_from_file(file_path)
+    else:
+        # Does URL exist?
+        if file_url is None or not is_valid_url(file_url):
+            logger.error("Invalid file URL: %s", file_url)
+            return ToolError(f"Invalid file URL: {file_url}")
+        catalog_item = client.Dataset.create_from_url(file_url)
+
+    if not catalog_item:
+        return ToolError("Failed to upload dataset.")
+
+    return ToolResult(
+        content=f"Successfully uploaded dataset: {catalog_item.id}",
+        structured_content={
+            "dataset_id": catalog_item.id,
+            "dataset_version_id": catalog_item.version_id,
+            "dataset_name": catalog_item.name,
+        },
+    )
+
+
+@dr_mcp_tool(tags={"predictive", "data", "read", "list", "catalog"})
+async def list_ai_catalog_items() -> ToolResult:
+    """List all AI Catalog items (datasets) for the authenticated user."""
     client = get_sdk_client()
     datasets = client.Dataset.list()
+
     if not datasets:
         logger.info("No AI Catalog items found")
-        return
-
-
-
+        return ToolResult(
+            content="No AI Catalog items found.",
+            structured_content={"datasets": []},
+        )
+
+    return ToolResult(
+        content=f"Found {len(datasets)} AI Catalog items.",
+        structured_content={
+            "datasets": [{"id": ds.id, "name": ds.name} for ds in datasets],
+            "count": len(datasets),
+        },
+    )
 
 
 # from fastmcp import Context
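As with the Jira tools, a hedged sketch of calling the new dataset tools through an MCP client; the `dr_mcp` export name and the CSV path are assumptions. Note that `upload_dataset_to_ai_catalog` returns `ToolError` values for validation failures rather than raising them, hence the `ToolError | ToolResult` return annotation.

import asyncio

from fastmcp import Client

from datarobot_genai.drmcp.core.mcp_instance import dr_mcp  # hypothetical export name


async def main() -> None:
    async with Client(dr_mcp) as client:
        # Exactly one of file_path / file_url may be set; both or neither is rejected.
        uploaded = await client.call_tool(
            "upload_dataset_to_ai_catalog", {"file_path": "data/train.csv"}
        )
        print(uploaded)

        # The listing returns {"datasets": [...], "count": N} as structured content.
        listing = await client.call_tool("list_ai_catalog_items", {})
        print(listing)


asyncio.run(main())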
datarobot_genai/nat/agent.py CHANGED
@@ -21,7 +21,6 @@ from nat.data_models.api_server import ChatRequest
 from nat.data_models.api_server import ChatResponse
 from nat.data_models.intermediate_step import IntermediateStep
 from nat.data_models.intermediate_step import IntermediateStepType
-from nat.runtime.loader import load_workflow
 from nat.utils.type_utils import StrPath
 from openai.types.chat import CompletionCreateParams
 from ragas import MultiTurnSample

@@ -34,6 +33,8 @@ from datarobot_genai.core.agents.base import InvokeReturn
 from datarobot_genai.core.agents.base import UsageMetrics
 from datarobot_genai.core.agents.base import extract_user_prompt_content
 from datarobot_genai.core.agents.base import is_streaming
+from datarobot_genai.core.mcp.common import MCPConfig
+from datarobot_genai.nat.helpers import load_workflow
 
 logger = logging.getLogger(__name__)
 

@@ -166,17 +167,24 @@ class NatAgent(BaseAgent[None]):
         # Print commands may need flush=True to ensure they are displayed in real-time.
         print("Running agent with user prompt:", chat_request.messages[0].content, flush=True)
 
+        mcp_config = MCPConfig(
+            authorization_context=self.authorization_context,
+            forwarded_headers=self.forwarded_headers,
+        )
+        server_config = mcp_config.server_config
+        headers = server_config["headers"] if server_config else None
+
         if is_streaming(completion_create_params):
 
             async def stream_generator() -> AsyncGenerator[
                 tuple[str, MultiTurnSample | None, UsageMetrics], None
             ]:
-
+                default_usage_metrics: UsageMetrics = {
                     "completion_tokens": 0,
                     "prompt_tokens": 0,
                     "total_tokens": 0,
                 }
-                async with load_workflow(self.workflow_path) as workflow:
+                async with load_workflow(self.workflow_path, headers=headers) as workflow:
                     async with workflow.run(chat_request) as runner:
                         intermediate_future = pull_intermediate_structured()
                         async for result in runner.result_stream():

@@ -188,7 +196,7 @@ class NatAgent(BaseAgent[None]):
                             yield (
                                 result_text,
                                 None,
-
+                                default_usage_metrics,
                             )
 
                         steps = await intermediate_future

@@ -197,6 +205,11 @@ class NatAgent(BaseAgent[None]):
                             for step in steps
                             if step.event_type == IntermediateStepType.LLM_END
                         ]
+                        usage_metrics: UsageMetrics = {
+                            "completion_tokens": 0,
+                            "prompt_tokens": 0,
+                            "total_tokens": 0,
+                        }
                         for step in llm_end_steps:
                             if step.usage_info:
                                 token_usage = step.usage_info.token_usage

@@ -210,7 +223,7 @@ class NatAgent(BaseAgent[None]):
             return stream_generator()
 
         # Create and invoke the NAT (Nemo Agent Toolkit) Agentic Workflow with the inputs
-        result, steps = await self.run_nat_workflow(self.workflow_path, chat_request)
+        result, steps = await self.run_nat_workflow(self.workflow_path, chat_request, headers)
 
         llm_end_steps = [step for step in steps if step.event_type == IntermediateStepType.LLM_END]
         usage_metrics: UsageMetrics = {

@@ -234,7 +247,7 @@ class NatAgent(BaseAgent[None]):
         return result_text, pipeline_interactions, usage_metrics
 
     async def run_nat_workflow(
-        self, workflow_path: StrPath, chat_request: ChatRequest
+        self, workflow_path: StrPath, chat_request: ChatRequest, headers: dict[str, str] | None
    ) -> tuple[ChatResponse | str, list[IntermediateStep]]:
        """Run the NAT workflow with the provided config file and input string.
 

@@ -247,7 +260,7 @@ class NatAgent(BaseAgent[None]):
             ChatResponse | str: The result from the NAT workflow
             list[IntermediateStep]: The list of intermediate steps
         """
-        async with load_workflow(workflow_path) as workflow:
+        async with load_workflow(workflow_path, headers=headers) as workflow:
             async with workflow.run(chat_request) as runner:
                 intermediate_future = pull_intermediate_structured()
                 runner_outputs = await runner.result()
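Taken together, the agent.py changes thread per-request MCP headers (resolved via MCPConfig from the authorization context and forwarded headers) into workflow loading, and fill the previously empty usage-metrics slot of each streamed tuple with a zeroed default. A minimal sketch of the new call path, not part of the diff:

from nat.data_models.api_server import ChatRequest
from nat.utils.type_utils import StrPath

from datarobot_genai.nat.helpers import load_workflow


async def run_once(
    workflow_path: StrPath, chat_request: ChatRequest, headers: dict[str, str] | None
):
    # In NatAgent, headers comes from MCPConfig(...).server_config["headers"];
    # load_workflow injects them into the workflow's MCP authentication config.
    async with load_workflow(workflow_path, headers=headers) as workflow:
        async with workflow.run(chat_request) as runner:
            return await runner.result()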
datarobot_genai/nat/datarobot_llm_clients.py CHANGED

@@ -14,6 +14,7 @@
 
 from collections.abc import AsyncGenerator
 from typing import Any
+from typing import TypeVar
 
 from crewai import LLM
 from langchain_openai import ChatOpenAI

@@ -22,12 +23,32 @@ from llama_index.llms.litellm import LiteLLM
 from nat.builder.builder import Builder
 from nat.builder.framework_enum import LLMFrameworkEnum
 from nat.cli.register_workflow import register_llm_client
+from nat.data_models.llm import LLMBaseConfig
+from nat.data_models.retry_mixin import RetryMixin
+from nat.plugins.langchain.llm import (
+    _patch_llm_based_on_config as langchain_patch_llm_based_on_config,
+)
+from nat.utils.exception_handlers.automatic_retries import patch_with_retry
 
 from ..nat.datarobot_llm_providers import DataRobotLLMComponentModelConfig
 from ..nat.datarobot_llm_providers import DataRobotLLMDeploymentModelConfig
 from ..nat.datarobot_llm_providers import DataRobotLLMGatewayModelConfig
 from ..nat.datarobot_llm_providers import DataRobotNIMModelConfig
 
+ModelType = TypeVar("ModelType")
+
+
+def _patch_llm_based_on_config(client: ModelType, llm_config: LLMBaseConfig) -> ModelType:
+    if isinstance(llm_config, RetryMixin):
+        client = patch_with_retry(
+            client,
+            retries=llm_config.num_retries,
+            retry_codes=llm_config.retry_on_status_codes,
+            retry_on_messages=llm_config.retry_on_errors,
+        )
+
+    return client
+
 
 class DataRobotChatOpenAI(ChatOpenAI):
     def _get_request_payload(

@@ -77,7 +98,8 @@ async def datarobot_llm_gateway_langchain(
     config["base_url"] = config["base_url"] + "/genai/llmgw"
     config["stream_options"] = {"include_usage": True}
     config["model"] = config["model"].removeprefix("datarobot/")
-
+    client = DataRobotChatOpenAI(**config)
+    yield langchain_patch_llm_based_on_config(client, config)
 
 
 @register_llm_client(

@@ -90,7 +112,8 @@ async def datarobot_llm_gateway_crewai(
     if not config["model"].startswith("datarobot/"):
         config["model"] = "datarobot/" + config["model"]
     config["base_url"] = config["base_url"].removesuffix("/api/v2")
-
+    client = LLM(**config)
+    yield _patch_llm_based_on_config(client, config)
 
 
 @register_llm_client(

@@ -103,7 +126,8 @@ async def datarobot_llm_gateway_llamaindex(
     if not config["model"].startswith("datarobot/"):
         config["model"] = "datarobot/" + config["model"]
     config["api_base"] = config.pop("base_url").removesuffix("/api/v2")
-
+    client = DataRobotLiteLLM(**config)
+    yield _patch_llm_based_on_config(client, config)
 
 
 @register_llm_client(

@@ -119,7 +143,8 @@ async def datarobot_llm_deployment_langchain(
     )
     config["stream_options"] = {"include_usage": True}
     config["model"] = config["model"].removeprefix("datarobot/")
-
+    client = DataRobotChatOpenAI(**config)
+    yield langchain_patch_llm_based_on_config(client, config)
 
 
 @register_llm_client(

@@ -136,7 +161,8 @@ async def datarobot_llm_deployment_crewai(
     if not config["model"].startswith("datarobot/"):
         config["model"] = "datarobot/" + config["model"]
     config["api_base"] = config.pop("base_url") + "/chat/completions"
-
+    client = LLM(**config)
+    yield _patch_llm_based_on_config(client, config)
 
 
 @register_llm_client(

@@ -153,7 +179,8 @@ async def datarobot_llm_deployment_llamaindex(
     if not config["model"].startswith("datarobot/"):
         config["model"] = "datarobot/" + config["model"]
     config["api_base"] = config.pop("base_url") + "/chat/completions"
-
+    client = DataRobotLiteLLM(**config)
+    yield _patch_llm_based_on_config(client, config)
 
 
 @register_llm_client(config_type=DataRobotNIMModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

@@ -167,7 +194,8 @@ async def datarobot_nim_langchain(
     )
     config["stream_options"] = {"include_usage": True}
     config["model"] = config["model"].removeprefix("datarobot/")
-
+    client = DataRobotChatOpenAI(**config)
+    yield langchain_patch_llm_based_on_config(client, config)
 
 
 @register_llm_client(config_type=DataRobotNIMModelConfig, wrapper_type=LLMFrameworkEnum.CREWAI)

@@ -182,7 +210,8 @@ async def datarobot_nim_crewai(
     if not config["model"].startswith("datarobot/"):
         config["model"] = "datarobot/" + config["model"]
     config["api_base"] = config.pop("base_url") + "/chat/completions"
-
+    client = LLM(**config)
+    yield _patch_llm_based_on_config(client, config)
 
 
 @register_llm_client(config_type=DataRobotNIMModelConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX)

@@ -197,7 +226,8 @@ async def datarobot_nim_llamaindex(
     if not config["model"].startswith("datarobot/"):
         config["model"] = "datarobot/" + config["model"]
     config["api_base"] = config.pop("base_url") + "/chat/completions"
-
+    client = DataRobotLiteLLM(**config)
+    yield _patch_llm_based_on_config(client, config)
 
 
 @register_llm_client(

@@ -212,7 +242,8 @@ async def datarobot_llm_component_langchain(
     config["stream_options"] = {"include_usage": True}
     config["model"] = config["model"].removeprefix("datarobot/")
     config.pop("use_datarobot_llm_gateway")
-
+    client = DataRobotChatOpenAI(**config)
+    yield langchain_patch_llm_based_on_config(client, config)
 
 
 @register_llm_client(

@@ -229,7 +260,8 @@ async def datarobot_llm_component_crewai(
     else:
         config["api_base"] = config.pop("base_url") + "/chat/completions"
     config.pop("use_datarobot_llm_gateway")
-
+    client = LLM(**config)
+    yield _patch_llm_based_on_config(client, config)
 
 
 @register_llm_client(

@@ -246,4 +278,5 @@ async def datarobot_llm_component_llamaindex(
     else:
         config["api_base"] = config.pop("base_url") + "/chat/completions"
     config.pop("use_datarobot_llm_gateway")
-
+    client = DataRobotLiteLLM(**config)
+    yield _patch_llm_based_on_config(client, config)
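All twelve registered client factories now build the client explicitly and yield it through a retry-patching helper (the LangChain paths reuse NAT's `_patch_llm_based_on_config`; the CrewAI and LlamaIndex paths use the local copy above). A hedged sketch of what that patching amounts to; the wrapper name and concrete retry values are illustrative, only the `patch_with_retry` call mirrors the diff:

from nat.utils.exception_handlers.automatic_retries import patch_with_retry


def make_resilient(client):  # illustrative wrapper, not from the diff
    # Mirrors _patch_llm_based_on_config when the config mixes in RetryMixin:
    # retries / retry_codes / retry_on_messages are taken from the model config
    # fields num_retries, retry_on_status_codes, and retry_on_errors.
    return patch_with_retry(
        client,
        retries=3,
        retry_codes=[429, 500, 502, 503],
        retry_on_messages=["rate limit"],
    )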
datarobot_genai/nat/helpers.py ADDED

@@ -0,0 +1,87 @@
+# Copyright 2025 DataRobot, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
+
+from nat.builder.workflow import Workflow
+from nat.builder.workflow_builder import WorkflowBuilder
+from nat.data_models.config import Config
+from nat.runtime.loader import PluginTypes
+from nat.runtime.loader import discover_and_register_plugins
+from nat.runtime.session import SessionManager
+from nat.utils.data_models.schema_validator import validate_schema
+from nat.utils.io.yaml_tools import yaml_load
+from nat.utils.type_utils import StrPath
+
+
+def load_config(config_file: StrPath, headers: dict[str, str] | None = None) -> Config:
+    """
+    Load a NAT configuration file with injected headers. It ensures that all plugins are
+    loaded and then validates the configuration file against the Config schema.
+
+    Parameters
+    ----------
+    config_file : StrPath
+        The path to the configuration file
+
+    Returns
+    -------
+    Config
+        The validated Config object
+    """
+    # Ensure all of the plugins are loaded
+    discover_and_register_plugins(PluginTypes.CONFIG_OBJECT)
+
+    config_yaml = yaml_load(config_file)
+
+    add_headers_to_datarobot_mcp_auth(config_yaml, headers)
+
+    # Validate configuration adheres to NAT schemas
+    validated_nat_config = validate_schema(config_yaml, Config)
+
+    return validated_nat_config
+
+
+def add_headers_to_datarobot_mcp_auth(config_yaml: dict, headers: dict[str, str] | None) -> None:
+    if headers:
+        if authentication := config_yaml.get("authentication"):
+            for auth_name in authentication:
+                auth_config = authentication[auth_name]
+                if auth_config.get("_type") == "datarobot_mcp_auth":
+                    auth_config["headers"] = headers
+
+
+@asynccontextmanager
+async def load_workflow(
+    config_file: StrPath, max_concurrency: int = -1, headers: dict[str, str] | None = None
+) -> AsyncGenerator[Workflow, None]:
+    """
+    Load the NAT configuration file and create a Runner object. This is the primary entry point for
+    running NAT workflows with injected headers.
+
+    Parameters
+    ----------
+    config_file : StrPath
+        The path to the configuration file
+    max_concurrency : int, optional
+        The maximum number of parallel workflow invocations to support. Specifying 0 or -1 will
+        allow an unlimited count, by default -1
+    """
+    # Load the config object
+    config = load_config(config_file, headers=headers)
+
+    # Must yield the workflow function otherwise it cleans up
+    async with WorkflowBuilder.from_config(config=config) as workflow:
+        yield SessionManager(await workflow.build(), max_concurrency=max_concurrency)
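The header-injection helper only touches authentication entries whose `_type` is `datarobot_mcp_auth`. A small self-check, with the config shape taken from the function body; the entry names and token are illustrative:

from datarobot_genai.nat.helpers import add_headers_to_datarobot_mcp_auth

config_yaml = {
    "authentication": {
        "datarobot_mcp": {"_type": "datarobot_mcp_auth"},  # gets headers injected
        "other_auth": {"_type": "api_key"},                # left untouched
    }
}
add_headers_to_datarobot_mcp_auth(config_yaml, {"Authorization": "Bearer <token>"})

assert config_yaml["authentication"]["datarobot_mcp"]["headers"] == {
    "Authorization": "Bearer <token>"
}
assert "headers" not in config_yaml["authentication"]["other_auth"]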