datarobot-genai 0.2.39__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
- datarobot_genai/core/agents/__init__.py +1 -1
- datarobot_genai/core/agents/base.py +5 -2
- datarobot_genai/core/chat/responses.py +6 -1
- datarobot_genai/core/utils/auth.py +188 -31
- datarobot_genai/crewai/__init__.py +1 -4
- datarobot_genai/crewai/agent.py +150 -17
- datarobot_genai/crewai/events.py +11 -4
- datarobot_genai/drmcp/__init__.py +4 -2
- datarobot_genai/drmcp/core/config.py +21 -1
- datarobot_genai/drmcp/core/mcp_instance.py +5 -49
- datarobot_genai/drmcp/core/routes.py +108 -13
- datarobot_genai/drmcp/core/tool_config.py +16 -0
- datarobot_genai/drmcp/core/utils.py +110 -0
- datarobot_genai/drmcp/test_utils/tool_base_ete.py +41 -26
- datarobot_genai/drmcp/tools/clients/gdrive.py +2 -0
- datarobot_genai/drmcp/tools/clients/microsoft_graph.py +96 -0
- datarobot_genai/drmcp/tools/clients/perplexity.py +173 -0
- datarobot_genai/drmcp/tools/clients/tavily.py +199 -0
- datarobot_genai/drmcp/tools/confluence/tools.py +0 -5
- datarobot_genai/drmcp/tools/gdrive/tools.py +12 -59
- datarobot_genai/drmcp/tools/jira/tools.py +4 -8
- datarobot_genai/drmcp/tools/microsoft_graph/tools.py +135 -19
- datarobot_genai/drmcp/tools/perplexity/__init__.py +0 -0
- datarobot_genai/drmcp/tools/perplexity/tools.py +117 -0
- datarobot_genai/drmcp/tools/predictive/data.py +1 -9
- datarobot_genai/drmcp/tools/predictive/deployment.py +0 -8
- datarobot_genai/drmcp/tools/predictive/deployment_info.py +0 -19
- datarobot_genai/drmcp/tools/predictive/model.py +0 -21
- datarobot_genai/drmcp/tools/predictive/predict_realtime.py +3 -0
- datarobot_genai/drmcp/tools/predictive/project.py +3 -19
- datarobot_genai/drmcp/tools/predictive/training.py +1 -19
- datarobot_genai/drmcp/tools/tavily/__init__.py +13 -0
- datarobot_genai/drmcp/tools/tavily/tools.py +141 -0
- datarobot_genai/langgraph/agent.py +10 -2
- datarobot_genai/llama_index/__init__.py +1 -1
- datarobot_genai/llama_index/agent.py +284 -5
- datarobot_genai/nat/agent.py +17 -6
- {datarobot_genai-0.2.39.dist-info → datarobot_genai-0.3.1.dist-info}/METADATA +3 -1
- {datarobot_genai-0.2.39.dist-info → datarobot_genai-0.3.1.dist-info}/RECORD +43 -40
- datarobot_genai/crewai/base.py +0 -159
- datarobot_genai/drmcp/core/tool_filter.py +0 -117
- datarobot_genai/llama_index/base.py +0 -299
- {datarobot_genai-0.2.39.dist-info → datarobot_genai-0.3.1.dist-info}/WHEEL +0 -0
- {datarobot_genai-0.2.39.dist-info → datarobot_genai-0.3.1.dist-info}/entry_points.txt +0 -0
- {datarobot_genai-0.2.39.dist-info → datarobot_genai-0.3.1.dist-info}/licenses/AUTHORS +0 -0
- {datarobot_genai-0.2.39.dist-info → datarobot_genai-0.3.1.dist-info}/licenses/LICENSE +0 -0
datarobot_genai/drmcp/tools/microsoft_graph/tools.py (+135 -19)

@@ -16,6 +16,7 @@

 import logging
 from typing import Annotated
+from typing import Any
 from typing import Literal

 from fastmcp.exceptions import ToolError

@@ -124,10 +125,6 @@ async def microsoft_graph_search_content(
     - Documentation: https://learn.microsoft.com/en-us/graph/api/search-query
     - Search concepts: https://learn.microsoft.com/en-us/graph/search-concept-files

-    Permissions:
-    - Requires Sites.Read.All or Sites.Search.All permission
-    - include_hidden_content only works with delegated permissions
-    - region parameter is required for application permissions in multi-region environments
     """
     if not search_query:
         raise ToolError("Argument validation error: 'search_query' cannot be empty.")

@@ -171,12 +168,7 @@ async def microsoft_graph_search_content(
             }
             results.append(result_dict)

-    n = len(results)
     return ToolResult(
-        content=(
-            f"Successfully searched Microsoft Graph and retrieved {n} result(s) for "
-            f"'{search_query}' (from={from_offset}, size={size})."
-        ),
         structured_content={
             "query": search_query,
             "siteUrl": site_url,

@@ -184,7 +176,7 @@
             "from": from_offset,
             "size": size,
             "results": results,
-            "count": n,
+            "count": len(results),
         },
     )

@@ -234,18 +226,12 @@ async def microsoft_graph_share_item(
         send_invitation=send_invitation,
     )

-    n = len(recipient_emails)
     return ToolResult(
-        content=(
-            f"Successfully shared file {file_id} "
-            f"from document library {document_library_id} "
-            f"with {n} recipients with '{role}' role."
-        ),
         structured_content={
             "fileId": file_id,
             "documentLibraryId": document_library_id,
             "recipientEmails": recipient_emails,
-            "n": n,
+            "n": len(recipient_emails),
             "role": role,
         },
     )

@@ -261,7 +247,8 @@ async def microsoft_graph_share_item(
         "create",
         "file",
         "write",
-    }
+    },
+    enabled=False,
 )
 async def microsoft_create_file(
     *,

@@ -318,7 +305,6 @@ async def microsoft_create_file(
     )

     return ToolResult(
-        content=f"File '{created_file.name}' created successfully.",
         structured_content={
             "file_name": created_file.name,
             "destination": "onedrive" if is_personal_onedrive else "sharepoint",

@@ -328,3 +314,133 @@ async def microsoft_create_file(
             "parentFolderId": created_file.parent_folder_id,
         },
     )
+
+
+@dr_mcp_tool(
+    tags={
+        "microsoft",
+        "graph api",
+        "sharepoint",
+        "onedrive",
+        "metadata",
+        "update",
+        "fields",
+        "compliance",
+    },
+    enabled=False,
+)
+async def microsoft_update_metadata(
+    *,
+    item_id: Annotated[str, "The ID of the file or list item to update."],
+    fields_to_update: Annotated[
+        dict[str, Any],
+        "Key-value pairs of metadata fields to modify. "
+        "For SharePoint list items: any custom column values. "
+        "For drive items: 'name' and/or 'description'.",
+    ],
+    site_id: Annotated[
+        str | None,
+        "The site ID (required for SharePoint list items, along with list_id).",
+    ] = None,
+    list_id: Annotated[
+        str | None,
+        "The list ID (required for SharePoint list items, along with site_id).",
+    ] = None,
+    document_library_id: Annotated[
+        str | None,
+        "The drive ID (required for OneDrive/drive item updates). "
+        "Cannot be used together with site_id and list_id.",
+    ] = None,
+) -> ToolResult | ToolError:
+    """
+    Update metadata on a SharePoint list item or OneDrive/SharePoint drive item.
+
+    **SharePoint List Items:** Provide site_id and list_id to update custom
+    column values on a list item. All custom columns can be updated.
+
+    **OneDrive/Drive Items:** Provide document_library_id to update drive item
+    properties. Only 'name' and 'description' fields can be updated.
+
+    **Context Requirements:**
+    - For SharePoint list items: Both site_id AND list_id are required
+    - For OneDrive/drive items: document_library_id is required
+    - Cannot specify both contexts simultaneously
+
+    **Examples:**
+    - SharePoint list item: Update a 'Status' column to 'Approved'
+    - Drive item: Rename a file or update its description
+
+    """
+    if not item_id or not item_id.strip():
+        raise ToolError("Error: item_id is required.")
+    if not fields_to_update:
+        raise ToolError("Error: fields_to_update is required and cannot be empty.")
+
+    # Validate context parameters
+    has_sharepoint_context = site_id is not None and list_id is not None
+    has_partial_sharepoint_context = (site_id is not None) != (list_id is not None)
+    has_drive_context = document_library_id is not None
+
+    if has_partial_sharepoint_context:
+        raise ToolError(
+            "Error: For SharePoint list items, both site_id and list_id must be provided."
+        )
+
+    if has_sharepoint_context and has_drive_context:
+        raise ToolError(
+            "Error: Cannot specify both SharePoint (site_id + list_id) and OneDrive "
+            "(document_library_id) context. Choose one."
+        )
+
+    if not has_sharepoint_context and not has_drive_context:
+        raise ToolError(
+            "Error: Must specify either SharePoint context (site_id + list_id) or "
+            "OneDrive context (document_library_id)."
+        )
+
+    access_token = await get_microsoft_graph_access_token()
+    if isinstance(access_token, ToolError):
+        raise access_token
+
+    async with MicrosoftGraphClient(access_token=access_token) as client:
+        result = await client.update_item_metadata(
+            item_id=item_id.strip(),
+            fields_to_update=fields_to_update,
+            site_id=site_id,
+            list_id=list_id,
+            drive_id=document_library_id,
+        )
+
+    context_type = "sharepoint_list_item" if has_sharepoint_context else "drive_item"
+    structured: dict[str, Any] = {
+        "item_id": item_id,
+        "context_type": context_type,
+        "fields_updated": list(fields_to_update.keys()),
+    }
+
+    # Add context-specific IDs for traceability
+    if has_sharepoint_context:
+        structured["site_id"] = site_id
+        structured["list_id"] = list_id
+    else:
+        structured["document_library_id"] = document_library_id
+
+    # Include relevant response data
+    if isinstance(result, dict):
+        # For drive items, include key properties if present
+        if has_drive_context:
+            if "id" in result:
+                structured["id"] = result["id"]
+            if "name" in result:
+                structured["name"] = result["name"]
+            if "webUrl" in result:
+                structured["webUrl"] = result["webUrl"]
+            if "description" in result:
+                structured["description"] = result.get("description")
+        # For list items, the response is the fields object itself
+        else:
+            structured["updated_fields"] = result
+
+    return ToolResult(
+        structured_content=structured,
+    )
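A pattern worth calling out, since it repeats in every hunk above and below: `ToolResult` drops its free-text `content=` argument, responses carry machine-readable `structured_content` only, and counts are computed inline (`"count": len(results)`) rather than through throwaway locals like `n`. A minimal sketch of a tool written to the new convention, assuming a plain FastMCP 2.x server; the tool name and payload are illustrative, only the `ToolResult` usage mirrors the diff:

# Sketch only: a standalone FastMCP server following the structured_content-only
# convention this release converges on. "list_widgets" is a made-up example tool.
from fastmcp import FastMCP
from fastmcp.tools.tool import ToolResult

mcp = FastMCP("sketch")


@mcp.tool
async def list_widgets() -> ToolResult:
    widgets = {"w1": "first widget", "w2": "second widget"}
    # No content= string: clients read structured_content, and the count is
    # computed inline, matching the "count": len(...) replacements above.
    return ToolResult(structured_content={"widgets": widgets, "count": len(widgets)})


if __name__ == "__main__":
    mcp.run()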
datarobot_genai/drmcp/tools/perplexity/__init__.py (+0 -0)

File without changes.
datarobot_genai/drmcp/tools/perplexity/tools.py (+117 -0, new file)

@@ -0,0 +1,117 @@
+# Copyright 2026 DataRobot, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Perplexity MCP tools."""
+
+import logging
+from typing import Annotated
+from typing import Literal
+
+from fastmcp.exceptions import ToolError
+from fastmcp.tools.tool import ToolResult
+
+from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool
+from datarobot_genai.drmcp.tools.clients.perplexity import MAX_QUERIES
+from datarobot_genai.drmcp.tools.clients.perplexity import MAX_RESULTS
+from datarobot_genai.drmcp.tools.clients.perplexity import MAX_RESULTS_DEFAULT
+from datarobot_genai.drmcp.tools.clients.perplexity import MAX_SEARCH_DOMAIN_FILTER
+from datarobot_genai.drmcp.tools.clients.perplexity import MAX_TOKENS_PER_PAGE
+from datarobot_genai.drmcp.tools.clients.perplexity import MAX_TOKENS_PER_PAGE_DEFAULT
+from datarobot_genai.drmcp.tools.clients.perplexity import PerplexityClient
+from datarobot_genai.drmcp.tools.clients.perplexity import get_perplexity_access_token
+
+logger = logging.getLogger(__name__)
+
+
+@dr_mcp_tool(tags={"perplexity", "web", "search", "websearch"})
+async def perplexity_search(
+    *,
+    query: Annotated[
+        str,
+        list[str],
+        f"The search query string OR "
+        f"a list of up to {MAX_QUERIES} sub-queries for multi-query research.",
+    ],
+    search_domain_filter: Annotated[
+        list[str] | None,
+        f"Up to {MAX_SEARCH_DOMAIN_FILTER} domains/URLs "
+        f"to allowlist or denylist (prefix with '-').",
+    ] = None,
+    recency: Annotated[
+        Literal["day", "week", "month", "year"] | None, "Filter results by time period."
+    ] = None,
+    max_results: Annotated[
+        int, f"Number of ranked results to return (1-{MAX_RESULTS})."
+    ] = MAX_RESULTS_DEFAULT,
+    max_tokens_per_page: Annotated[
+        int,
+        f"Content extraction cap per page (1-{MAX_TOKENS_PER_PAGE}) "
+        f"(default {MAX_TOKENS_PER_PAGE_DEFAULT}).",
+    ] = MAX_TOKENS_PER_PAGE_DEFAULT,
+) -> ToolResult:
+    """Perplexity web search tool combining multi-query research and content extraction control."""
+    if not query:
+        raise ToolError("Argument validation error: query cannot be empty.")
+    if query and isinstance(query, str) and not query.strip():
+        raise ToolError("Argument validation error: query cannot be empty.")
+    if query and isinstance(query, list) and len(query) > MAX_QUERIES:
+        raise ToolError(
+            f"Argument validation error: query list cannot be bigger than {MAX_QUERIES}."
+        )
+    if query and isinstance(query, list) and not all(q.strip() for q in query):
+        raise ToolError("Argument validation error: query cannot contain empty str.")
+    if search_domain_filter and len(search_domain_filter) > MAX_SEARCH_DOMAIN_FILTER:
+        raise ToolError(
+            f"Argument validation error: "
+            f"maximum number of search domain filters is {MAX_SEARCH_DOMAIN_FILTER}."
+        )
+    if max_results <= 0:
+        raise ToolError("Argument validation error: max_results must be greater than 0.")
+    if max_results > MAX_RESULTS:
+        raise ToolError(
+            f"Argument validation error: "
+            f"max_results must be smaller than or equal to {MAX_RESULTS}."
+        )
+    if max_tokens_per_page <= 0:
+        raise ToolError("Argument validation error: max_tokens_per_page must be greater than 0.")
+    if max_tokens_per_page > MAX_TOKENS_PER_PAGE:
+        raise ToolError(
+            f"Argument validation error: "
+            f"max_tokens_per_page must be smaller than or equal to {MAX_TOKENS_PER_PAGE}."
+        )
+
+    access_token = await get_perplexity_access_token()
+    if isinstance(access_token, ToolError):
+        raise access_token
+
+    async with PerplexityClient(access_token=access_token) as perplexity_client:
+        results = await perplexity_client.search(
+            query=query,
+            search_domain_filter=search_domain_filter,
+            recency=recency,
+            max_results=max_results,
+            max_tokens_per_page=max_tokens_per_page,
+        )
+
+    return ToolResult(
+        structured_content={
+            "results": results,
+            "count": len(results),
+            "metadata": {
+                "queriesExecuted": len(query) if isinstance(query, list) else 1,
+                "filtersApplied": {"domains": search_domain_filter, "recency": recency},
+                "extractionLimit": max_tokens_per_page,
+            },
+        },
+    )
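The new `perplexity_search` tool accepts a single query string or a list of sub-queries, and validates everything (query emptiness, the MAX_QUERIES, MAX_SEARCH_DOMAIN_FILTER, MAX_RESULTS, and MAX_TOKENS_PER_PAGE bounds) before the client is opened. A hedged invocation sketch from the consumer side, assuming a FastMCP 2.x client and a locally reachable drmcp server; the URL is hypothetical, the tool and argument names come from the diff:

# Sketch only: the endpoint URL is an assumption, as is the structured_content
# attribute on the client-side call result (present in recent fastmcp releases).
import asyncio

from fastmcp import Client


async def demo() -> None:
    async with Client("http://localhost:8080/mcp") as client:
        result = await client.call_tool(
            "perplexity_search",
            {
                # Multi-query research form: a list of sub-queries, capped
                # server-side at MAX_QUERIES.
                "query": ["datarobot genai 0.3.1 changes", "fastmcp ToolResult"],
                # Allowlist arxiv.org; denylist pinterest.com via the '-' prefix.
                "search_domain_filter": ["arxiv.org", "-pinterest.com"],
                "recency": "month",
                "max_results": 5,
            },
        )
        print(result.structured_content["count"])


asyncio.run(demo())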
datarobot_genai/drmcp/tools/predictive/data.py (+1 -9)

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import json
 import logging
 import os
 from typing import Annotated

@@ -60,7 +59,6 @@ async def upload_dataset_to_ai_catalog(
         raise ToolError("Failed to upload dataset.")

     return ToolResult(
-        content=f"Successfully uploaded dataset: {catalog_item.id}",
         structured_content={
             "dataset_id": catalog_item.id,
             "dataset_version_id": catalog_item.version_id,

@@ -78,21 +76,15 @@ async def list_ai_catalog_items() -> ToolResult:
     if not datasets:
         logger.info("No AI Catalog items found")
         return ToolResult(
-            content="No AI Catalog items found.",
             structured_content={"datasets": []},
         )

     datasets_dict = {ds.id: ds.name for ds in datasets}
-    datasets_count = len(datasets)

     return ToolResult(
-        content=(
-            f"Found {datasets_count} AI Catalog items, here are the details:\n"
-            f"{json.dumps(datasets_dict, indent=2)}"
-        ),
         structured_content={
             "datasets": datasets_dict,
-            "count": datasets_count,
+            "count": len(datasets),
         },
     )

datarobot_genai/drmcp/tools/predictive/deployment.py (+0 -8)

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import json
 import logging
 from typing import Annotated

@@ -32,12 +31,10 @@ async def list_deployments() -> ToolResult:
     deployments = client.Deployment.list()
     if not deployments:
         return ToolResult(
-            content="No deployments found.",
             structured_content={"deployments": []},
         )
     deployments_dict = {d.id: d.label for d in deployments}
     return ToolResult(
-        content="\n".join(f"{d.id}: {d.label}" for d in deployments),
         structured_content={"deployments": deployments_dict},
     )

@@ -54,10 +51,6 @@ async def get_model_info_from_deployment(
     client = get_sdk_client()
     deployment = client.Deployment.get(deployment_id)
     return ToolResult(
-        content=(
-            f"Retrieved model info for deployment {deployment_id}, here are the details:\n"
-            f"{json.dumps(deployment.model, indent=2)}"
-        ),
         structured_content=deployment.model,
     )

@@ -87,7 +80,6 @@ async def deploy_model(
         default_prediction_server_id=prediction_servers[0].id,
     )
     return ToolResult(
-        content=f"Created deployment {deployment.id} with label {label}",
         structured_content={
             "deployment_id": deployment.id,
             "label": label,
datarobot_genai/drmcp/tools/predictive/deployment_info.py (+0 -19)

@@ -95,7 +95,6 @@ async def get_deployment_info(
     }

     return ToolResult(
-        content=json.dumps(result, indent=2),
         structured_content=result,
     )

@@ -179,21 +178,6 @@ async def generate_prediction_data_template(
     # Create DataFrame
     df = pd.DataFrame(template_data)

-    # Add metadata comments
-    result = f"# Prediction Data Template for Deployment: {deployment_id}\n"
-    result += f"# Model Type: {features_info['model_type']}\n"
-    result += f"# Target: {features_info['target']} (Type: {features_info['target_type']})\n"
-
-    if "time_series_config" in features_info:
-        ts = features_info["time_series_config"]
-        result += f"# Time Series: datetime_column={ts['datetime_column']}, "
-        result += f"forecast_window=[{ts['forecast_window_start']}, {ts['forecast_window_end']}]\n"
-        if ts["series_id_columns"]:
-            result += f"# Multiseries ID Columns: {', '.join(ts['series_id_columns'])}\n"
-
-    result += f"# Total Features: {features_info['total_features']}\n"
-    result += df.to_csv(index=False)
-
     # Build structured content with template data and metadata
     structured_content = {
         "deployment_id": deployment_id,

@@ -208,7 +192,6 @@ async def generate_prediction_data_template(
         structured_content["time_series_config"] = features_info["time_series_config"]

     return ToolResult(
-        content=str(result),
         structured_content=structured_content,
     )

@@ -342,7 +325,6 @@ async def validate_prediction_data(
     }

     return ToolResult(
-        content=json.dumps(validation_report, indent=2),
         structured_content=validation_report,
     )

@@ -380,6 +362,5 @@ async def get_deployment_features(
         result["target_type"] = info["target_type"]

     return ToolResult(
-        content=json.dumps(result, indent=2),
         structured_content=result,
     )
datarobot_genai/drmcp/tools/predictive/model.py (+0 -21)

@@ -93,33 +93,17 @@ async def get_best_model(
     best_model = leaderboard[0]
     logger.info(f"Found best model {best_model.id} for project {project_id}")

-    metric_info = ""
     metric_value = None

     if metric and best_model.metrics and metric in best_model.metrics:
         metric_value = best_model.metrics[metric].get("validation")
-        if metric_value is not None:
-            metric_info = f" with {metric}: {metric_value:.2f}"

     # Include full metrics in the response
     best_model_dict = model_to_dict(best_model)
     best_model_dict["metric"] = metric
     best_model_dict["metric_value"] = metric_value

-    # Format metrics for human-readable content
-    metrics_text = ""
-    if best_model.metrics:
-        metrics_list = []
-        for metric_name, metric_data in best_model.metrics.items():
-            if isinstance(metric_data, dict) and "validation" in metric_data:
-                val = metric_data["validation"]
-                if val is not None:
-                    metrics_list.append(f"{metric_name}: {val:.4f}")
-        if metrics_list:
-            metrics_text = "\nPerformance metrics:\n" + "\n".join(f" - {m}" for m in metrics_list)
-
     return ToolResult(
-        content=f"Best model: {best_model.model_type}{metric_info}{metrics_text}",
         structured_content={
             "project_id": project_id,
             "best_model": best_model_dict,

@@ -148,7 +132,6 @@ async def score_dataset_with_model(
     job = model.score(dataset_url)

     return ToolResult(
-        content=f"Scoring job started: {job.id}",
         structured_content={
             "scoring_job_id": job.id,
             "project_id": project_id,

@@ -172,10 +155,6 @@ async def list_models(
     models = project.get_models()

     return ToolResult(
-        content=(
-            f"Found {len(models)} models in project {project_id}, here are the details:\n"
-            f"{json.dumps(models, indent=2, cls=ModelEncoder)}"
-        ),
         structured_content={
             "project_id": project_id,
             "models": [model_to_dict(model) for model in models],
datarobot_genai/drmcp/tools/predictive/predict_realtime.py (+3 -0)

@@ -240,6 +240,9 @@ async def predict_realtime(
     else:
         raise ValueError("Either file_path or dataset must be provided.")

+    # Normalize column names: strip leading/trailing whitespace
+    df.columns = df.columns.str.strip()
+
     if series_id_column and series_id_column not in df.columns:
         raise ValueError(f"series_id_column '{series_id_column}' not found in input data.")

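Those three added lines are the entire behavioral change: column headers are whitespace-stripped before the `series_id_column` lookup, so a CSV with padded headers no longer fails validation spuriously. A standalone pandas illustration with invented sample data:

# Sketch only: demonstrates the normalization step added above on made-up data.
import pandas as pd

df = pd.DataFrame({" store_id ": [1, 2], "sales\t": [10.0, 12.5]})
df.columns = df.columns.str.strip()  # the new normalization step

assert list(df.columns) == ["store_id", "sales"]
# A series_id_column of "store_id" now resolves even though the original
# header carried leading/trailing whitespace.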
datarobot_genai/drmcp/tools/predictive/project.py (+3 -19)

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import json
 import logging
 from typing import Annotated

@@ -33,11 +32,6 @@ async def list_projects() -> ToolResult:
     projects = {p.id: p.project_name for p in projects}

     return ToolResult(
-        content=(
-            json.dumps(projects, indent=2)
-            if projects
-            else json.dumps({"message": "No projects found."}, indent=2)
-        ),
         structured_content=projects,
     )

@@ -48,7 +42,7 @@ async def get_project_dataset_by_name(
     project_id: Annotated[str, "The ID of the DataRobot project."] | None = None,
     dataset_name: Annotated[str, "The name of the dataset to find (e.g., 'training', 'holdout')."]
     | None = None,
-) ->
+) -> ToolResult:
     """Get a dataset ID by name for a given project.

     The dataset ID and the dataset type (source or prediction) as a string, or an error message.

@@ -70,21 +64,11 @@ async def get_project_dataset_by_name(
     for ds in all_datasets:
         if dataset_name.lower() in ds["dataset"].name.lower():
             return ToolResult(
-                content=(
-                    json.dumps(
-                        {
-                            "dataset_id": ds["dataset"].id,
-                            "dataset_type": ds["type"],
-                        },
-                        indent=2,
-                    )
-                ),
                 structured_content={
                     "dataset_id": ds["dataset"].id,
                     "dataset_type": ds["type"],
                 },
             )
-
-
-        structured_content={},
+    raise ToolError(
+        f"Dataset with name containing '{dataset_name}' not found in project {project_id}."
     )
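`get_project_dataset_by_name` also changes its miss behavior: instead of returning an empty `structured_content={}`, it now raises `ToolError`, which surfaces to MCP clients as a tool error. A hedged client-side sketch, assuming fastmcp's client re-raises tool errors as `ToolError` (its default behavior in recent releases) and a hypothetical endpoint and IDs:

# Sketch only: URL and IDs are invented; only the tool name and its new
# raise-on-miss behavior come from the diff.
import asyncio

from fastmcp import Client
from fastmcp.exceptions import ToolError


async def find_dataset() -> None:
    async with Client("http://localhost:8080/mcp") as client:
        try:
            result = await client.call_tool(
                "get_project_dataset_by_name",
                {"project_id": "1234567890abcdef", "dataset_name": "holdout"},
            )
            print(result.structured_content)  # {"dataset_id": ..., "dataset_type": ...}
        except ToolError as err:
            # Raised when no dataset name contains the search string.
            print(f"lookup failed: {err}")


asyncio.run(find_dataset())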
datarobot_genai/drmcp/tools/predictive/training.py (+1 -19)

@@ -14,7 +14,6 @@

 """Tools for analyzing datasets and suggesting ML use cases."""

-import json
 import logging
 from dataclasses import asdict
 from dataclasses import dataclass

@@ -134,7 +133,6 @@ async def analyze_dataset(
     insights_dict = asdict(insights)

     return ToolResult(
-        content=json.dumps(insights_dict, indent=2),
         structured_content=insights_dict,
     )

@@ -164,7 +162,6 @@ async def suggest_use_cases(
     suggestions.sort(key=lambda x: x["confidence"], reverse=True)

     return ToolResult(
-        content=json.dumps(suggestions, indent=2),
         structured_content={"use_case_suggestions": suggestions},
     )

@@ -255,7 +252,6 @@ async def get_exploratory_insights(
     )

     return ToolResult(
-        content=json.dumps(eda_insights, indent=2),
         structured_content=eda_insights,
     )

@@ -540,22 +536,11 @@ async def start_autopilot(
         }

         return ToolResult(
-            content=json.dumps(result, indent=2),
             structured_content=result,
         )

     except Exception as e:
-        raise ToolError(
-            content=json.dumps(
-                {
-                    "error": f"Failed to start Autopilot: {str(e)}",
-                    "project_id": project.id if project else None,
-                    "target": target,
-                    "mode": mode,
-                },
-                indent=2,
-            )
-        )
+        raise ToolError(f"Failed to start Autopilot: {str(e)}")


 @dr_mcp_tool(tags={"prediction", "training", "read", "model", "evaluation"})

@@ -611,7 +596,6 @@ async def get_model_roc_curve(
         }

         return ToolResult(
-            content=json.dumps({"data": roc_data}, indent=2),
             structured_content={"data": roc_data},
         )
     except Exception as e:

@@ -638,7 +622,6 @@ async def get_model_feature_impact(
     feature_impact = model.get_or_request_feature_impact()

     return ToolResult(
-        content=json.dumps({"data": feature_impact}, indent=2),
         structured_content={"data": feature_impact},
     )

@@ -684,6 +667,5 @@ async def get_model_lift_chart(
         }

         return ToolResult(
-            content=json.dumps({"data": lift_chart_data}, indent=2),
             structured_content={"data": lift_chart_data},
         )
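One hunk above fixes a latent bug rather than just trimming output: the removed `start_autopilot` handler passed a `content=` keyword to `ToolError`, which takes a plain positional message like any exception, so the old error path would presumably have died with a `TypeError` before the intended message ever surfaced. The corrected pattern, in isolation:

# Sketch only: shows the corrected raise; the wrapper function is invented.
from fastmcp.exceptions import ToolError


def wrap_autopilot_failure(exc: Exception) -> None:
    # ToolError is raised with a single message string, as in the new hunk.
    raise ToolError(f"Failed to start Autopilot: {str(exc)}")


try:
    wrap_autopilot_failure(RuntimeError("quota exceeded"))
except ToolError as err:
    print(err)  # Failed to start Autopilot: quota exceeded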
datarobot_genai/drmcp/tools/tavily/__init__.py (+13 -0, new file)

@@ -0,0 +1,13 @@
+# Copyright 2025 DataRobot, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.