datarobot-genai 0.2.26__py3-none-any.whl → 0.2.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_genai/core/cli/agent_kernel.py +4 -1
- datarobot_genai/drmcp/__init__.py +2 -2
- datarobot_genai/drmcp/core/config.py +121 -83
- datarobot_genai/drmcp/core/exceptions.py +0 -4
- datarobot_genai/drmcp/core/logging.py +2 -2
- datarobot_genai/drmcp/core/tool_config.py +17 -9
- datarobot_genai/drmcp/test_utils/clients/__init__.py +0 -0
- datarobot_genai/drmcp/test_utils/clients/anthropic.py +68 -0
- datarobot_genai/drmcp/test_utils/{openai_llm_mcp_client.py → clients/base.py} +38 -40
- datarobot_genai/drmcp/test_utils/clients/dr_gateway.py +58 -0
- datarobot_genai/drmcp/test_utils/clients/openai.py +68 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +20 -0
- datarobot_genai/drmcp/test_utils/test_interactive.py +16 -16
- datarobot_genai/drmcp/test_utils/tool_base_ete.py +69 -2
- datarobot_genai/drmcp/test_utils/utils.py +1 -1
- datarobot_genai/drmcp/tools/clients/gdrive.py +314 -1
- datarobot_genai/drmcp/tools/clients/microsoft_graph.py +479 -0
- datarobot_genai/drmcp/tools/gdrive/tools.py +273 -4
- datarobot_genai/drmcp/tools/microsoft_graph/__init__.py +13 -0
- datarobot_genai/drmcp/tools/microsoft_graph/tools.py +198 -0
- datarobot_genai/drmcp/tools/predictive/data.py +16 -8
- datarobot_genai/drmcp/tools/predictive/model.py +87 -52
- datarobot_genai/drmcp/tools/predictive/project.py +2 -2
- datarobot_genai/drmcp/tools/predictive/training.py +15 -14
- datarobot_genai/nat/datarobot_llm_clients.py +90 -54
- datarobot_genai/nat/datarobot_mcp_client.py +47 -15
- {datarobot_genai-0.2.26.dist-info → datarobot_genai-0.2.34.dist-info}/METADATA +1 -1
- {datarobot_genai-0.2.26.dist-info → datarobot_genai-0.2.34.dist-info}/RECORD +32 -25
- {datarobot_genai-0.2.26.dist-info → datarobot_genai-0.2.34.dist-info}/WHEEL +0 -0
- {datarobot_genai-0.2.26.dist-info → datarobot_genai-0.2.34.dist-info}/entry_points.txt +0 -0
- {datarobot_genai-0.2.26.dist-info → datarobot_genai-0.2.34.dist-info}/licenses/AUTHORS +0 -0
- {datarobot_genai-0.2.26.dist-info → datarobot_genai-0.2.34.dist-info}/licenses/LICENSE +0 -0
|
@@ -14,9 +14,12 @@
|
|
|
14
14
|
|
|
15
15
|
import json
|
|
16
16
|
import logging
|
|
17
|
+
from typing import Annotated
|
|
17
18
|
from typing import Any
|
|
18
19
|
|
|
19
20
|
from datarobot.models.model import Model
|
|
21
|
+
from fastmcp.exceptions import ToolError
|
|
22
|
+
from fastmcp.tools.tool import ToolResult
|
|
20
23
|
|
|
21
24
|
from datarobot_genai.drmcp.core.clients import get_sdk_client
|
|
22
25
|
from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool
|
|
@@ -50,33 +53,25 @@ class ModelEncoder(json.JSONEncoder):
|
|
|
50
53
|
return super().default(obj)
|
|
51
54
|
|
|
52
55
|
|
|
53
|
-
@dr_mcp_tool(tags={"model", "management", "info"})
|
|
54
|
-
async def get_best_model(
|
|
55
|
-
|
|
56
|
-
|
|
56
|
+
@dr_mcp_tool(tags={"predictive", "model", "read", "management", "info"})
|
|
57
|
+
async def get_best_model(
|
|
58
|
+
*,
|
|
59
|
+
project_id: Annotated[str, "The DataRobot project ID"] | None = None,
|
|
60
|
+
metric: Annotated[str, "The metric to use for best model selection (e.g., 'AUC', 'LogLoss')"]
|
|
61
|
+
| None = None,
|
|
62
|
+
) -> ToolError | ToolResult:
|
|
63
|
+
"""Get the best model for a DataRobot project, optionally by a specific metric."""
|
|
64
|
+
if not project_id:
|
|
65
|
+
raise ToolError("Project ID must be provided")
|
|
57
66
|
|
|
58
|
-
Args:
|
|
59
|
-
project_id: The ID of the DataRobot project.
|
|
60
|
-
metric: (Optional) The metric to use for best model selection (e.g., 'AUC', 'LogLoss').
|
|
61
|
-
|
|
62
|
-
Returns
|
|
63
|
-
-------
|
|
64
|
-
A formatted string describing the best model.
|
|
65
|
-
|
|
66
|
-
Raises
|
|
67
|
-
------
|
|
68
|
-
Exception: If project not found or no models exist in the project.
|
|
69
|
-
"""
|
|
70
67
|
client = get_sdk_client()
|
|
71
68
|
project = client.Project.get(project_id)
|
|
72
69
|
if not project:
|
|
73
|
-
|
|
74
|
-
raise Exception(f"Project with ID {project_id} not found.")
|
|
70
|
+
raise ToolError(f"Project with ID {project_id} not found.")
|
|
75
71
|
|
|
76
72
|
leaderboard = project.get_models()
|
|
77
73
|
if not leaderboard:
|
|
78
|
-
|
|
79
|
-
raise Exception("No models found for this project.")
|
|
74
|
+
raise ToolError("No models found for this project.")
|
|
80
75
|
|
|
81
76
|
if metric:
|
|
82
77
|
reverse_sort = metric.upper() in [
|
|
@@ -98,51 +93,91 @@ async def get_best_model(project_id: str, metric: str | None = None) -> str:
|
|
|
98
93
|
best_model = leaderboard[0]
|
|
99
94
|
logger.info(f"Found best model {best_model.id} for project {project_id}")
|
|
100
95
|
|
|
101
|
-
# Format the response as a human-readable string
|
|
102
96
|
metric_info = ""
|
|
97
|
+
metric_value = None
|
|
98
|
+
|
|
103
99
|
if metric and best_model.metrics and metric in best_model.metrics:
|
|
104
100
|
metric_value = best_model.metrics[metric].get("validation")
|
|
105
101
|
if metric_value is not None:
|
|
106
102
|
metric_info = f" with {metric}: {metric_value:.2f}"
|
|
107
103
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
104
|
+
# Include full metrics in the response
|
|
105
|
+
best_model_dict = model_to_dict(best_model)
|
|
106
|
+
best_model_dict["metric"] = metric
|
|
107
|
+
best_model_dict["metric_value"] = metric_value
|
|
108
|
+
|
|
109
|
+
# Format metrics for human-readable content
|
|
110
|
+
metrics_text = ""
|
|
111
|
+
if best_model.metrics:
|
|
112
|
+
metrics_list = []
|
|
113
|
+
for metric_name, metric_data in best_model.metrics.items():
|
|
114
|
+
if isinstance(metric_data, dict) and "validation" in metric_data:
|
|
115
|
+
val = metric_data["validation"]
|
|
116
|
+
if val is not None:
|
|
117
|
+
metrics_list.append(f"{metric_name}: {val:.4f}")
|
|
118
|
+
if metrics_list:
|
|
119
|
+
metrics_text = "\nPerformance metrics:\n" + "\n".join(f" - {m}" for m in metrics_list)
|
|
120
|
+
|
|
121
|
+
return ToolResult(
|
|
122
|
+
content=f"Best model: {best_model.model_type}{metric_info}{metrics_text}",
|
|
123
|
+
structured_content={
|
|
124
|
+
"project_id": project_id,
|
|
125
|
+
"best_model": best_model_dict,
|
|
126
|
+
},
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@dr_mcp_tool(tags={"predictive", "model", "read", "scoring", "dataset"})
|
|
131
|
+
async def score_dataset_with_model(
|
|
132
|
+
*,
|
|
133
|
+
project_id: Annotated[str, "The DataRobot project ID"] | None = None,
|
|
134
|
+
model_id: Annotated[str, "The DataRobot model ID"] | None = None,
|
|
135
|
+
dataset_url: Annotated[str, "The dataset URL"] | None = None,
|
|
136
|
+
) -> ToolError | ToolResult:
|
|
137
|
+
"""Score a dataset using a specific DataRobot model."""
|
|
138
|
+
if not project_id:
|
|
139
|
+
raise ToolError("Project ID must be provided")
|
|
140
|
+
if not model_id:
|
|
141
|
+
raise ToolError("Model ID must be provided")
|
|
142
|
+
if not dataset_url:
|
|
143
|
+
raise ToolError("Dataset URL must be provided")
|
|
115
144
|
|
|
116
|
-
Args:
|
|
117
|
-
project_id: The ID of the DataRobot project.
|
|
118
|
-
model_id: The ID of the DataRobot model to use for scoring.
|
|
119
|
-
dataset_url: The URL to the dataset to score (must be accessible to DataRobot).
|
|
120
|
-
|
|
121
|
-
Returns
|
|
122
|
-
-------
|
|
123
|
-
A string summary of the scoring job or a meaningful error message.
|
|
124
|
-
"""
|
|
125
145
|
client = get_sdk_client()
|
|
126
146
|
project = client.Project.get(project_id)
|
|
127
147
|
model = client.Model.get(project, model_id)
|
|
128
148
|
job = model.score(dataset_url)
|
|
129
|
-
logger.info(f"Started scoring job {job.id} for model {model_id}")
|
|
130
|
-
return f"Scoring job started: {job.id}"
|
|
131
|
-
|
|
132
149
|
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
150
|
+
return ToolResult(
|
|
151
|
+
content=f"Scoring job started: {job.id}",
|
|
152
|
+
structured_content={
|
|
153
|
+
"scoring_job_id": job.id,
|
|
154
|
+
"project_id": project_id,
|
|
155
|
+
"model_id": model_id,
|
|
156
|
+
"dataset_url": dataset_url,
|
|
157
|
+
},
|
|
158
|
+
)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
@dr_mcp_tool(tags={"predictive", "model", "read", "management", "list"})
|
|
162
|
+
async def list_models(
|
|
163
|
+
*,
|
|
164
|
+
project_id: Annotated[str, "The DataRobot project ID"] | None = None,
|
|
165
|
+
) -> ToolError | ToolResult:
|
|
166
|
+
"""List all models in a project."""
|
|
167
|
+
if not project_id:
|
|
168
|
+
raise ToolError("Project ID must be provided")
|
|
137
169
|
|
|
138
|
-
Args:
|
|
139
|
-
project_id: The ID of the DataRobot project.
|
|
140
|
-
|
|
141
|
-
Returns
|
|
142
|
-
-------
|
|
143
|
-
A string summary of the models in the project.
|
|
144
|
-
"""
|
|
145
170
|
client = get_sdk_client()
|
|
146
171
|
project = client.Project.get(project_id)
|
|
147
172
|
models = project.get_models()
|
|
148
|
-
|
|
173
|
+
|
|
174
|
+
return ToolResult(
|
|
175
|
+
content=(
|
|
176
|
+
f"Found {len(models)} models in project {project_id}, here are the details:\n"
|
|
177
|
+
f"{json.dumps(models, indent=2, cls=ModelEncoder)}"
|
|
178
|
+
),
|
|
179
|
+
structured_content={
|
|
180
|
+
"project_id": project_id,
|
|
181
|
+
"models": [model_to_dict(model) for model in models],
|
|
182
|
+
},
|
|
183
|
+
)
|
|
@@ -54,9 +54,9 @@ async def get_project_dataset_by_name(
|
|
|
54
54
|
The dataset ID and the dataset type (source or prediction) as a string, or an error message.
|
|
55
55
|
"""
|
|
56
56
|
if not project_id:
|
|
57
|
-
|
|
57
|
+
raise ToolError("Project ID is required.")
|
|
58
58
|
if not dataset_name:
|
|
59
|
-
|
|
59
|
+
raise ToolError("Dataset name is required.")
|
|
60
60
|
|
|
61
61
|
client = get_sdk_client()
|
|
62
62
|
project = client.Project.get(project_id)
|
|
@@ -63,7 +63,7 @@ async def analyze_dataset(
|
|
|
63
63
|
) -> ToolError | ToolResult:
|
|
64
64
|
"""Analyze a dataset to understand its structure and potential use cases."""
|
|
65
65
|
if not dataset_id:
|
|
66
|
-
|
|
66
|
+
raise ToolError("Dataset ID must be provided")
|
|
67
67
|
|
|
68
68
|
client = get_sdk_client()
|
|
69
69
|
dataset = client.Dataset.get(dataset_id)
|
|
@@ -116,7 +116,7 @@ async def suggest_use_cases(
|
|
|
116
116
|
) -> ToolError | ToolResult:
|
|
117
117
|
"""Analyze a dataset and suggest potential machine learning use cases."""
|
|
118
118
|
if not dataset_id:
|
|
119
|
-
|
|
119
|
+
raise ToolError("Dataset ID must be provided")
|
|
120
120
|
|
|
121
121
|
client = get_sdk_client()
|
|
122
122
|
dataset = client.Dataset.get(dataset_id)
|
|
@@ -148,7 +148,7 @@ async def get_exploratory_insights(
|
|
|
148
148
|
) -> ToolError | ToolResult:
|
|
149
149
|
"""Generate exploratory data insights for a dataset."""
|
|
150
150
|
if not dataset_id:
|
|
151
|
-
|
|
151
|
+
raise ToolError("Dataset ID must be provided")
|
|
152
152
|
|
|
153
153
|
client = get_sdk_client()
|
|
154
154
|
dataset = client.Dataset.get(dataset_id)
|
|
@@ -481,9 +481,9 @@ async def start_autopilot(
|
|
|
481
481
|
|
|
482
482
|
if not project_id:
|
|
483
483
|
if not dataset_url and not dataset_id:
|
|
484
|
-
|
|
484
|
+
raise ToolError("Either dataset_url or dataset_id must be provided")
|
|
485
485
|
if dataset_url and dataset_id:
|
|
486
|
-
|
|
486
|
+
raise ToolError("Please provide either dataset_url or dataset_id, not both")
|
|
487
487
|
|
|
488
488
|
if dataset_url:
|
|
489
489
|
dataset = client.Dataset.create_from_url(dataset_url)
|
|
@@ -497,7 +497,7 @@ async def start_autopilot(
|
|
|
497
497
|
project = client.Project.get(project_id)
|
|
498
498
|
|
|
499
499
|
if not target:
|
|
500
|
-
|
|
500
|
+
raise ToolError("Target variable must be specified")
|
|
501
501
|
|
|
502
502
|
try:
|
|
503
503
|
# Start modeling
|
|
@@ -517,7 +517,7 @@ async def start_autopilot(
|
|
|
517
517
|
)
|
|
518
518
|
|
|
519
519
|
except Exception as e:
|
|
520
|
-
|
|
520
|
+
raise ToolError(
|
|
521
521
|
content=json.dumps(
|
|
522
522
|
{
|
|
523
523
|
"error": f"Failed to start Autopilot: {str(e)}",
|
|
@@ -546,9 +546,9 @@ async def get_model_roc_curve(
|
|
|
546
546
|
) -> ToolError | ToolResult:
|
|
547
547
|
"""Get detailed ROC curve for a specific model."""
|
|
548
548
|
if not project_id:
|
|
549
|
-
|
|
549
|
+
raise ToolError("Project ID must be provided")
|
|
550
550
|
if not model_id:
|
|
551
|
-
|
|
551
|
+
raise ToolError("Model ID must be provided")
|
|
552
552
|
|
|
553
553
|
client = get_sdk_client()
|
|
554
554
|
project = client.Project.get(project_id)
|
|
@@ -587,7 +587,7 @@ async def get_model_roc_curve(
|
|
|
587
587
|
structured_content={"data": roc_data},
|
|
588
588
|
)
|
|
589
589
|
except Exception as e:
|
|
590
|
-
|
|
590
|
+
raise ToolError(f"Failed to get ROC curve: {str(e)}")
|
|
591
591
|
|
|
592
592
|
|
|
593
593
|
@dr_mcp_tool(tags={"predictive", "training", "read", "model", "evaluation"})
|
|
@@ -598,9 +598,9 @@ async def get_model_feature_impact(
|
|
|
598
598
|
) -> ToolError | ToolResult:
|
|
599
599
|
"""Get detailed feature impact for a specific model."""
|
|
600
600
|
if not project_id:
|
|
601
|
-
|
|
601
|
+
raise ToolError("Project ID must be provided")
|
|
602
602
|
if not model_id:
|
|
603
|
-
|
|
603
|
+
raise ToolError("Model ID must be provided")
|
|
604
604
|
|
|
605
605
|
client = get_sdk_client()
|
|
606
606
|
project = client.Project.get(project_id)
|
|
@@ -617,6 +617,7 @@ async def get_model_feature_impact(
|
|
|
617
617
|
|
|
618
618
|
@dr_mcp_tool(tags={"predictive", "training", "read", "model", "evaluation"})
|
|
619
619
|
async def get_model_lift_chart(
|
|
620
|
+
*,
|
|
620
621
|
project_id: Annotated[str, "The ID of the DataRobot project"] | None = None,
|
|
621
622
|
model_id: Annotated[str, "The ID of the model to analyze"] | None = None,
|
|
622
623
|
source: Annotated[
|
|
@@ -630,9 +631,9 @@ async def get_model_lift_chart(
|
|
|
630
631
|
) -> ToolError | ToolResult:
|
|
631
632
|
"""Get detailed lift chart for a specific model."""
|
|
632
633
|
if not project_id:
|
|
633
|
-
|
|
634
|
+
raise ToolError("Project ID must be provided")
|
|
634
635
|
if not model_id:
|
|
635
|
-
|
|
636
|
+
raise ToolError("Model ID must be provided")
|
|
636
637
|
|
|
637
638
|
client = get_sdk_client()
|
|
638
639
|
project = client.Project.get(project_id)
|
|
@@ -12,22 +12,18 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
15
17
|
from collections.abc import AsyncGenerator
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
16
19
|
from typing import Any
|
|
17
20
|
from typing import TypeVar
|
|
18
21
|
|
|
19
|
-
from crewai import LLM
|
|
20
|
-
from langchain_openai import ChatOpenAI
|
|
21
|
-
from llama_index.core.base.llms.types import LLMMetadata
|
|
22
|
-
from llama_index.llms.litellm import LiteLLM
|
|
23
22
|
from nat.builder.builder import Builder
|
|
24
23
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
25
24
|
from nat.cli.register_workflow import register_llm_client
|
|
26
25
|
from nat.data_models.llm import LLMBaseConfig
|
|
27
26
|
from nat.data_models.retry_mixin import RetryMixin
|
|
28
|
-
from nat.plugins.langchain.llm import (
|
|
29
|
-
_patch_llm_based_on_config as langchain_patch_llm_based_on_config,
|
|
30
|
-
)
|
|
31
27
|
from nat.utils.exception_handlers.automatic_retries import patch_with_retry
|
|
32
28
|
|
|
33
29
|
from ..nat.datarobot_llm_providers import DataRobotLLMComponentModelConfig
|
|
@@ -35,6 +31,11 @@ from ..nat.datarobot_llm_providers import DataRobotLLMDeploymentModelConfig
|
|
|
35
31
|
from ..nat.datarobot_llm_providers import DataRobotLLMGatewayModelConfig
|
|
36
32
|
from ..nat.datarobot_llm_providers import DataRobotNIMModelConfig
|
|
37
33
|
|
|
34
|
+
if TYPE_CHECKING:
|
|
35
|
+
from crewai import LLM
|
|
36
|
+
from langchain_openai import ChatOpenAI
|
|
37
|
+
from llama_index.llms.litellm import LiteLLM
|
|
38
|
+
|
|
38
39
|
ModelType = TypeVar("ModelType")
|
|
39
40
|
|
|
40
41
|
|
|
@@ -50,42 +51,53 @@ def _patch_llm_based_on_config(client: ModelType, llm_config: LLMBaseConfig) ->
|
|
|
50
51
|
return client
|
|
51
52
|
|
|
52
53
|
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
endpoints with LlamaIndex agents, you must override this method to return the appropriate
|
|
80
|
-
metadata.
|
|
54
|
+
def _create_datarobot_chat_openai(config: dict[str, Any]) -> Any:
|
|
55
|
+
from langchain_openai import ChatOpenAI # noqa: PLC0415
|
|
56
|
+
|
|
57
|
+
class DataRobotChatOpenAI(ChatOpenAI):
|
|
58
|
+
def _get_request_payload( # type: ignore[override]
|
|
59
|
+
self,
|
|
60
|
+
*args: Any,
|
|
61
|
+
**kwargs: Any,
|
|
62
|
+
) -> dict:
|
|
63
|
+
# We need to default to include_usage=True for streaming but we get 400 response
|
|
64
|
+
# if stream_options is present for a non-streaming call.
|
|
65
|
+
payload = super()._get_request_payload(*args, **kwargs)
|
|
66
|
+
if not payload.get("stream"):
|
|
67
|
+
payload.pop("stream_options", None)
|
|
68
|
+
return payload
|
|
69
|
+
|
|
70
|
+
return DataRobotChatOpenAI(**config)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _create_datarobot_litellm(config: dict[str, Any]) -> Any:
|
|
74
|
+
from llama_index.core.base.llms.types import LLMMetadata # noqa: PLC0415
|
|
75
|
+
from llama_index.llms.litellm import LiteLLM # noqa: PLC0415
|
|
76
|
+
|
|
77
|
+
class DataRobotLiteLLM(LiteLLM): # type: ignore[misc]
|
|
78
|
+
"""DataRobotLiteLLM is a small LiteLLM wrapper class that makes all LiteLLM endpoints
|
|
79
|
+
compatible with the LlamaIndex library.
|
|
81
80
|
"""
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
81
|
+
|
|
82
|
+
@property
|
|
83
|
+
def metadata(self) -> LLMMetadata:
|
|
84
|
+
"""Returns the metadata for the LLM.
|
|
85
|
+
|
|
86
|
+
This is required to enable the is_chat_model and is_function_calling_model, which are
|
|
87
|
+
mandatory for LlamaIndex agents. By default, LlamaIndex assumes these are false unless
|
|
88
|
+
each individual model config in LiteLLM explicitly sets them to true. To use custom LLM
|
|
89
|
+
endpoints with LlamaIndex agents, you must override this method to return the
|
|
90
|
+
appropriate metadata.
|
|
91
|
+
"""
|
|
92
|
+
return LLMMetadata(
|
|
93
|
+
context_window=128000,
|
|
94
|
+
num_output=self.max_tokens or -1,
|
|
95
|
+
is_chat_model=True,
|
|
96
|
+
is_function_calling_model=True,
|
|
97
|
+
model_name=self.model,
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
return DataRobotLiteLLM(**config)
|
|
89
101
|
|
|
90
102
|
|
|
91
103
|
@register_llm_client(
|
|
@@ -94,11 +106,15 @@ class DataRobotLiteLLM(LiteLLM): # type: ignore[misc]
|
|
|
94
106
|
async def datarobot_llm_gateway_langchain(
|
|
95
107
|
llm_config: DataRobotLLMGatewayModelConfig, builder: Builder
|
|
96
108
|
) -> AsyncGenerator[ChatOpenAI]:
|
|
109
|
+
from nat.plugins.langchain.llm import ( # noqa: PLC0415
|
|
110
|
+
_patch_llm_based_on_config as langchain_patch_llm_based_on_config,
|
|
111
|
+
)
|
|
112
|
+
|
|
97
113
|
config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
|
|
98
114
|
config["base_url"] = config["base_url"] + "/genai/llmgw"
|
|
99
115
|
config["stream_options"] = {"include_usage": True}
|
|
100
116
|
config["model"] = config["model"].removeprefix("datarobot/")
|
|
101
|
-
client =
|
|
117
|
+
client = _create_datarobot_chat_openai(config)
|
|
102
118
|
yield langchain_patch_llm_based_on_config(client, config)
|
|
103
119
|
|
|
104
120
|
|
|
@@ -108,6 +124,8 @@ async def datarobot_llm_gateway_langchain(
|
|
|
108
124
|
async def datarobot_llm_gateway_crewai(
|
|
109
125
|
llm_config: DataRobotLLMGatewayModelConfig, builder: Builder
|
|
110
126
|
) -> AsyncGenerator[LLM]:
|
|
127
|
+
from crewai import LLM # noqa: PLC0415
|
|
128
|
+
|
|
111
129
|
config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
|
|
112
130
|
if not config["model"].startswith("datarobot/"):
|
|
113
131
|
config["model"] = "datarobot/" + config["model"]
|
|
@@ -121,12 +139,12 @@ async def datarobot_llm_gateway_crewai(
|
|
|
121
139
|
)
|
|
122
140
|
async def datarobot_llm_gateway_llamaindex(
|
|
123
141
|
llm_config: DataRobotLLMGatewayModelConfig, builder: Builder
|
|
124
|
-
) -> AsyncGenerator[
|
|
142
|
+
) -> AsyncGenerator[LiteLLM]:
|
|
125
143
|
config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
|
|
126
144
|
if not config["model"].startswith("datarobot/"):
|
|
127
145
|
config["model"] = "datarobot/" + config["model"]
|
|
128
146
|
config["api_base"] = config.pop("base_url").removesuffix("/api/v2")
|
|
129
|
-
client =
|
|
147
|
+
client = _create_datarobot_litellm(config)
|
|
130
148
|
yield _patch_llm_based_on_config(client, config)
|
|
131
149
|
|
|
132
150
|
|
|
@@ -136,6 +154,10 @@ async def datarobot_llm_gateway_llamaindex(
|
|
|
136
154
|
async def datarobot_llm_deployment_langchain(
|
|
137
155
|
llm_config: DataRobotLLMDeploymentModelConfig, builder: Builder
|
|
138
156
|
) -> AsyncGenerator[ChatOpenAI]:
|
|
157
|
+
from nat.plugins.langchain.llm import ( # noqa: PLC0415
|
|
158
|
+
_patch_llm_based_on_config as langchain_patch_llm_based_on_config,
|
|
159
|
+
)
|
|
160
|
+
|
|
139
161
|
config = llm_config.model_dump(
|
|
140
162
|
exclude={"type", "thinking"},
|
|
141
163
|
by_alias=True,
|
|
@@ -143,7 +165,7 @@ async def datarobot_llm_deployment_langchain(
|
|
|
143
165
|
)
|
|
144
166
|
config["stream_options"] = {"include_usage": True}
|
|
145
167
|
config["model"] = config["model"].removeprefix("datarobot/")
|
|
146
|
-
client =
|
|
168
|
+
client = _create_datarobot_chat_openai(config)
|
|
147
169
|
yield langchain_patch_llm_based_on_config(client, config)
|
|
148
170
|
|
|
149
171
|
|
|
@@ -153,6 +175,8 @@ async def datarobot_llm_deployment_langchain(
|
|
|
153
175
|
async def datarobot_llm_deployment_crewai(
|
|
154
176
|
llm_config: DataRobotLLMDeploymentModelConfig, builder: Builder
|
|
155
177
|
) -> AsyncGenerator[LLM]:
|
|
178
|
+
from crewai import LLM # noqa: PLC0415
|
|
179
|
+
|
|
156
180
|
config = llm_config.model_dump(
|
|
157
181
|
exclude={"type", "thinking"},
|
|
158
182
|
by_alias=True,
|
|
@@ -170,7 +194,7 @@ async def datarobot_llm_deployment_crewai(
|
|
|
170
194
|
)
|
|
171
195
|
async def datarobot_llm_deployment_llamaindex(
|
|
172
196
|
llm_config: DataRobotLLMDeploymentModelConfig, builder: Builder
|
|
173
|
-
) -> AsyncGenerator[
|
|
197
|
+
) -> AsyncGenerator[LiteLLM]:
|
|
174
198
|
config = llm_config.model_dump(
|
|
175
199
|
exclude={"type", "thinking"},
|
|
176
200
|
by_alias=True,
|
|
@@ -179,7 +203,7 @@ async def datarobot_llm_deployment_llamaindex(
|
|
|
179
203
|
if not config["model"].startswith("datarobot/"):
|
|
180
204
|
config["model"] = "datarobot/" + config["model"]
|
|
181
205
|
config["api_base"] = config.pop("base_url") + "/chat/completions"
|
|
182
|
-
client =
|
|
206
|
+
client = _create_datarobot_litellm(config)
|
|
183
207
|
yield _patch_llm_based_on_config(client, config)
|
|
184
208
|
|
|
185
209
|
|
|
@@ -187,6 +211,10 @@ async def datarobot_llm_deployment_llamaindex(
|
|
|
187
211
|
async def datarobot_nim_langchain(
|
|
188
212
|
llm_config: DataRobotNIMModelConfig, builder: Builder
|
|
189
213
|
) -> AsyncGenerator[ChatOpenAI]:
|
|
214
|
+
from nat.plugins.langchain.llm import ( # noqa: PLC0415
|
|
215
|
+
_patch_llm_based_on_config as langchain_patch_llm_based_on_config,
|
|
216
|
+
)
|
|
217
|
+
|
|
190
218
|
config = llm_config.model_dump(
|
|
191
219
|
exclude={"type", "thinking"},
|
|
192
220
|
by_alias=True,
|
|
@@ -194,7 +222,7 @@ async def datarobot_nim_langchain(
|
|
|
194
222
|
)
|
|
195
223
|
config["stream_options"] = {"include_usage": True}
|
|
196
224
|
config["model"] = config["model"].removeprefix("datarobot/")
|
|
197
|
-
client =
|
|
225
|
+
client = _create_datarobot_chat_openai(config)
|
|
198
226
|
yield langchain_patch_llm_based_on_config(client, config)
|
|
199
227
|
|
|
200
228
|
|
|
@@ -202,6 +230,8 @@ async def datarobot_nim_langchain(
|
|
|
202
230
|
async def datarobot_nim_crewai(
|
|
203
231
|
llm_config: DataRobotNIMModelConfig, builder: Builder
|
|
204
232
|
) -> AsyncGenerator[LLM]:
|
|
233
|
+
from crewai import LLM # noqa: PLC0415
|
|
234
|
+
|
|
205
235
|
config = llm_config.model_dump(
|
|
206
236
|
exclude={"type", "thinking", "max_retries"},
|
|
207
237
|
by_alias=True,
|
|
@@ -217,7 +247,7 @@ async def datarobot_nim_crewai(
|
|
|
217
247
|
@register_llm_client(config_type=DataRobotNIMModelConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX)
|
|
218
248
|
async def datarobot_nim_llamaindex(
|
|
219
249
|
llm_config: DataRobotNIMModelConfig, builder: Builder
|
|
220
|
-
) -> AsyncGenerator[
|
|
250
|
+
) -> AsyncGenerator[LiteLLM]:
|
|
221
251
|
config = llm_config.model_dump(
|
|
222
252
|
exclude={"type", "thinking"},
|
|
223
253
|
by_alias=True,
|
|
@@ -226,7 +256,7 @@ async def datarobot_nim_llamaindex(
|
|
|
226
256
|
if not config["model"].startswith("datarobot/"):
|
|
227
257
|
config["model"] = "datarobot/" + config["model"]
|
|
228
258
|
config["api_base"] = config.pop("base_url") + "/chat/completions"
|
|
229
|
-
client =
|
|
259
|
+
client = _create_datarobot_litellm(config)
|
|
230
260
|
yield _patch_llm_based_on_config(client, config)
|
|
231
261
|
|
|
232
262
|
|
|
@@ -236,13 +266,17 @@ async def datarobot_nim_llamaindex(
|
|
|
236
266
|
async def datarobot_llm_component_langchain(
|
|
237
267
|
llm_config: DataRobotLLMComponentModelConfig, builder: Builder
|
|
238
268
|
) -> AsyncGenerator[ChatOpenAI]:
|
|
269
|
+
from nat.plugins.langchain.llm import ( # noqa: PLC0415
|
|
270
|
+
_patch_llm_based_on_config as langchain_patch_llm_based_on_config,
|
|
271
|
+
)
|
|
272
|
+
|
|
239
273
|
config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
|
|
240
274
|
if config["use_datarobot_llm_gateway"]:
|
|
241
275
|
config["base_url"] = config["base_url"] + "/genai/llmgw"
|
|
242
276
|
config["stream_options"] = {"include_usage": True}
|
|
243
277
|
config["model"] = config["model"].removeprefix("datarobot/")
|
|
244
278
|
config.pop("use_datarobot_llm_gateway")
|
|
245
|
-
client =
|
|
279
|
+
client = _create_datarobot_chat_openai(config)
|
|
246
280
|
yield langchain_patch_llm_based_on_config(client, config)
|
|
247
281
|
|
|
248
282
|
|
|
@@ -252,6 +286,8 @@ async def datarobot_llm_component_langchain(
|
|
|
252
286
|
async def datarobot_llm_component_crewai(
|
|
253
287
|
llm_config: DataRobotLLMComponentModelConfig, builder: Builder
|
|
254
288
|
) -> AsyncGenerator[LLM]:
|
|
289
|
+
from crewai import LLM # noqa: PLC0415
|
|
290
|
+
|
|
255
291
|
config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
|
|
256
292
|
if not config["model"].startswith("datarobot/"):
|
|
257
293
|
config["model"] = "datarobot/" + config["model"]
|
|
@@ -269,7 +305,7 @@ async def datarobot_llm_component_crewai(
|
|
|
269
305
|
)
|
|
270
306
|
async def datarobot_llm_component_llamaindex(
|
|
271
307
|
llm_config: DataRobotLLMComponentModelConfig, builder: Builder
|
|
272
|
-
) -> AsyncGenerator[
|
|
308
|
+
) -> AsyncGenerator[LiteLLM]:
|
|
273
309
|
config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
|
|
274
310
|
if not config["model"].startswith("datarobot/"):
|
|
275
311
|
config["model"] = "datarobot/" + config["model"]
|
|
@@ -278,5 +314,5 @@ async def datarobot_llm_component_llamaindex(
|
|
|
278
314
|
else:
|
|
279
315
|
config["api_base"] = config.pop("base_url") + "/chat/completions"
|
|
280
316
|
config.pop("use_datarobot_llm_gateway")
|
|
281
|
-
client =
|
|
317
|
+
client = _create_datarobot_litellm(config)
|
|
282
318
|
yield _patch_llm_based_on_config(client, config)
|