datarobot-genai 0.2.26__py3-none-any.whl → 0.2.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. datarobot_genai/core/cli/agent_kernel.py +4 -1
  2. datarobot_genai/drmcp/__init__.py +2 -2
  3. datarobot_genai/drmcp/core/config.py +121 -83
  4. datarobot_genai/drmcp/core/exceptions.py +0 -4
  5. datarobot_genai/drmcp/core/logging.py +2 -2
  6. datarobot_genai/drmcp/core/tool_config.py +17 -9
  7. datarobot_genai/drmcp/test_utils/clients/__init__.py +0 -0
  8. datarobot_genai/drmcp/test_utils/clients/anthropic.py +68 -0
  9. datarobot_genai/drmcp/test_utils/{openai_llm_mcp_client.py → clients/base.py} +38 -40
  10. datarobot_genai/drmcp/test_utils/clients/dr_gateway.py +58 -0
  11. datarobot_genai/drmcp/test_utils/clients/openai.py +68 -0
  12. datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +20 -0
  13. datarobot_genai/drmcp/test_utils/test_interactive.py +16 -16
  14. datarobot_genai/drmcp/test_utils/tool_base_ete.py +69 -2
  15. datarobot_genai/drmcp/test_utils/utils.py +1 -1
  16. datarobot_genai/drmcp/tools/clients/gdrive.py +314 -1
  17. datarobot_genai/drmcp/tools/clients/microsoft_graph.py +479 -0
  18. datarobot_genai/drmcp/tools/gdrive/tools.py +273 -4
  19. datarobot_genai/drmcp/tools/microsoft_graph/__init__.py +13 -0
  20. datarobot_genai/drmcp/tools/microsoft_graph/tools.py +198 -0
  21. datarobot_genai/drmcp/tools/predictive/data.py +16 -8
  22. datarobot_genai/drmcp/tools/predictive/model.py +87 -52
  23. datarobot_genai/drmcp/tools/predictive/project.py +2 -2
  24. datarobot_genai/drmcp/tools/predictive/training.py +15 -14
  25. datarobot_genai/nat/datarobot_llm_clients.py +90 -54
  26. datarobot_genai/nat/datarobot_mcp_client.py +47 -15
  27. {datarobot_genai-0.2.26.dist-info → datarobot_genai-0.2.34.dist-info}/METADATA +1 -1
  28. {datarobot_genai-0.2.26.dist-info → datarobot_genai-0.2.34.dist-info}/RECORD +32 -25
  29. {datarobot_genai-0.2.26.dist-info → datarobot_genai-0.2.34.dist-info}/WHEEL +0 -0
  30. {datarobot_genai-0.2.26.dist-info → datarobot_genai-0.2.34.dist-info}/entry_points.txt +0 -0
  31. {datarobot_genai-0.2.26.dist-info → datarobot_genai-0.2.34.dist-info}/licenses/AUTHORS +0 -0
  32. {datarobot_genai-0.2.26.dist-info → datarobot_genai-0.2.34.dist-info}/licenses/LICENSE +0 -0
@@ -14,9 +14,12 @@
 
 import json
 import logging
+from typing import Annotated
 from typing import Any
 
 from datarobot.models.model import Model
+from fastmcp.exceptions import ToolError
+from fastmcp.tools.tool import ToolResult
 
 from datarobot_genai.drmcp.core.clients import get_sdk_client
 from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool
@@ -50,33 +53,25 @@ class ModelEncoder(json.JSONEncoder):
         return super().default(obj)
 
 
-@dr_mcp_tool(tags={"model", "management", "info"})
-async def get_best_model(project_id: str, metric: str | None = None) -> str:
-    """
-    Get the best model for a DataRobot project, optionally by a specific metric.
+@dr_mcp_tool(tags={"predictive", "model", "read", "management", "info"})
+async def get_best_model(
+    *,
+    project_id: Annotated[str, "The DataRobot project ID"] | None = None,
+    metric: Annotated[str, "The metric to use for best model selection (e.g., 'AUC', 'LogLoss')"]
+    | None = None,
+) -> ToolError | ToolResult:
+    """Get the best model for a DataRobot project, optionally by a specific metric."""
+    if not project_id:
+        raise ToolError("Project ID must be provided")
 
-    Args:
-        project_id: The ID of the DataRobot project.
-        metric: (Optional) The metric to use for best model selection (e.g., 'AUC', 'LogLoss').
-
-    Returns
-    -------
-    A formatted string describing the best model.
-
-    Raises
-    ------
-    Exception: If project not found or no models exist in the project.
-    """
     client = get_sdk_client()
     project = client.Project.get(project_id)
     if not project:
-        logger.error(f"Project with ID {project_id} not found")
-        raise Exception(f"Project with ID {project_id} not found.")
+        raise ToolError(f"Project with ID {project_id} not found.")
 
     leaderboard = project.get_models()
     if not leaderboard:
-        logger.info(f"No models found for project {project_id}")
-        raise Exception("No models found for this project.")
+        raise ToolError("No models found for this project.")
 
     if metric:
         reverse_sort = metric.upper() in [
@@ -98,51 +93,91 @@ async def get_best_model(project_id: str, metric: str | None = None) -> str:
     best_model = leaderboard[0]
     logger.info(f"Found best model {best_model.id} for project {project_id}")
 
-    # Format the response as a human-readable string
     metric_info = ""
+    metric_value = None
+
     if metric and best_model.metrics and metric in best_model.metrics:
         metric_value = best_model.metrics[metric].get("validation")
         if metric_value is not None:
             metric_info = f" with {metric}: {metric_value:.2f}"
 
-    return f"Best model: {best_model.model_type}{metric_info}"
-
-
-@dr_mcp_tool(tags={"model", "prediction", "scoring"})
-async def score_dataset_with_model(project_id: str, model_id: str, dataset_url: str) -> str:
-    """
-    Score a dataset using a specific DataRobot model.
+    # Include full metrics in the response
+    best_model_dict = model_to_dict(best_model)
+    best_model_dict["metric"] = metric
+    best_model_dict["metric_value"] = metric_value
+
+    # Format metrics for human-readable content
+    metrics_text = ""
+    if best_model.metrics:
+        metrics_list = []
+        for metric_name, metric_data in best_model.metrics.items():
+            if isinstance(metric_data, dict) and "validation" in metric_data:
+                val = metric_data["validation"]
+                if val is not None:
+                    metrics_list.append(f"{metric_name}: {val:.4f}")
+        if metrics_list:
+            metrics_text = "\nPerformance metrics:\n" + "\n".join(f" - {m}" for m in metrics_list)
+
+    return ToolResult(
+        content=f"Best model: {best_model.model_type}{metric_info}{metrics_text}",
+        structured_content={
+            "project_id": project_id,
+            "best_model": best_model_dict,
+        },
+    )
+
+
+@dr_mcp_tool(tags={"predictive", "model", "read", "scoring", "dataset"})
+async def score_dataset_with_model(
+    *,
+    project_id: Annotated[str, "The DataRobot project ID"] | None = None,
+    model_id: Annotated[str, "The DataRobot model ID"] | None = None,
+    dataset_url: Annotated[str, "The dataset URL"] | None = None,
+) -> ToolError | ToolResult:
+    """Score a dataset using a specific DataRobot model."""
+    if not project_id:
+        raise ToolError("Project ID must be provided")
+    if not model_id:
+        raise ToolError("Model ID must be provided")
+    if not dataset_url:
+        raise ToolError("Dataset URL must be provided")
 
-    Args:
-        project_id: The ID of the DataRobot project.
-        model_id: The ID of the DataRobot model to use for scoring.
-        dataset_url: The URL to the dataset to score (must be accessible to DataRobot).
-
-    Returns
-    -------
-    A string summary of the scoring job or a meaningful error message.
-    """
     client = get_sdk_client()
     project = client.Project.get(project_id)
     model = client.Model.get(project, model_id)
     job = model.score(dataset_url)
-    logger.info(f"Started scoring job {job.id} for model {model_id}")
-    return f"Scoring job started: {job.id}"
-
 
-@dr_mcp_tool(tags={"model", "management", "list"})
-async def list_models(project_id: str) -> str:
-    """
-    List all models in a project.
+    return ToolResult(
+        content=f"Scoring job started: {job.id}",
+        structured_content={
+            "scoring_job_id": job.id,
+            "project_id": project_id,
+            "model_id": model_id,
+            "dataset_url": dataset_url,
+        },
+    )
+
+
+@dr_mcp_tool(tags={"predictive", "model", "read", "management", "list"})
+async def list_models(
+    *,
+    project_id: Annotated[str, "The DataRobot project ID"] | None = None,
+) -> ToolError | ToolResult:
+    """List all models in a project."""
+    if not project_id:
+        raise ToolError("Project ID must be provided")
 
-    Args:
-        project_id: The ID of the DataRobot project.
-
-    Returns
-    -------
-    A string summary of the models in the project.
-    """
     client = get_sdk_client()
     project = client.Project.get(project_id)
     models = project.get_models()
-    return json.dumps(models, indent=2, cls=ModelEncoder)
+
+    return ToolResult(
+        content=(
+            f"Found {len(models)} models in project {project_id}, here are the details:\n"
+            f"{json.dumps(models, indent=2, cls=ModelEncoder)}"
+        ),
+        structured_content={
+            "project_id": project_id,
+            "models": [model_to_dict(model) for model in models],
+        },
+    )
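
The thread running through these model.py hunks: each tool now returns a ToolResult that pairs a human-readable content string with machine-readable structured_content, and failure paths raise ToolError instead of returning strings or generic Exceptions. A minimal sketch of what a caller sees, using the FastMCP client (the endpoint URL and project ID are placeholders, and the result attribute names follow recent FastMCP releases):

import asyncio

from fastmcp import Client


async def main() -> None:
    # Placeholder endpoint; point this at a running drmcp server.
    async with Client("http://localhost:8000/mcp") as client:
        result = await client.call_tool("get_best_model", {"project_id": "<project-id>"})
        print(result.content)             # human-readable text block(s)
        print(result.structured_content)  # {"project_id": ..., "best_model": {...}}


asyncio.run(main())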
@@ -54,9 +54,9 @@ async def get_project_dataset_by_name(
     The dataset ID and the dataset type (source or prediction) as a string, or an error message.
     """
     if not project_id:
-        return ToolError("Project ID is required.")
+        raise ToolError("Project ID is required.")
     if not dataset_name:
-        return ToolError("Dataset name is required.")
+        raise ToolError("Dataset name is required.")
 
     client = get_sdk_client()
     project = client.Project.get(project_id)
@@ -63,7 +63,7 @@ async def analyze_dataset(
 ) -> ToolError | ToolResult:
     """Analyze a dataset to understand its structure and potential use cases."""
     if not dataset_id:
-        return ToolError("Dataset ID must be provided")
+        raise ToolError("Dataset ID must be provided")
 
     client = get_sdk_client()
     dataset = client.Dataset.get(dataset_id)
@@ -116,7 +116,7 @@ async def suggest_use_cases(
 ) -> ToolError | ToolResult:
     """Analyze a dataset and suggest potential machine learning use cases."""
     if not dataset_id:
-        return ToolError("Dataset ID must be provided")
+        raise ToolError("Dataset ID must be provided")
 
     client = get_sdk_client()
     dataset = client.Dataset.get(dataset_id)
@@ -148,7 +148,7 @@ async def get_exploratory_insights(
 ) -> ToolError | ToolResult:
     """Generate exploratory data insights for a dataset."""
     if not dataset_id:
-        return ToolError("Dataset ID must be provided")
+        raise ToolError("Dataset ID must be provided")
 
     client = get_sdk_client()
     dataset = client.Dataset.get(dataset_id)
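
These data.py hunks all make the same one-line fix: ToolError is an exception type, so it must be raised, not returned. Returning it would serialize the exception object as if it were a successful result; raising it lets FastMCP turn the failure into a proper MCP error response. A minimal sketch of the corrected control flow (the tool below is hypothetical, with the registration decorator omitted):

from fastmcp.exceptions import ToolError
from fastmcp.tools.tool import ToolResult


async def analyze_example(*, dataset_id: str | None = None) -> ToolResult:
    if not dataset_id:
        # Raised, not returned: FastMCP converts this into an MCP error response.
        raise ToolError("Dataset ID must be provided")
    return ToolResult(content=f"Analyzed dataset {dataset_id}")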
@@ -481,9 +481,9 @@ async def start_autopilot(
 
     if not project_id:
         if not dataset_url and not dataset_id:
-            return ToolError("Either dataset_url or dataset_id must be provided")
+            raise ToolError("Either dataset_url or dataset_id must be provided")
         if dataset_url and dataset_id:
-            return ToolError("Please provide either dataset_url or dataset_id, not both")
+            raise ToolError("Please provide either dataset_url or dataset_id, not both")
 
         if dataset_url:
             dataset = client.Dataset.create_from_url(dataset_url)
@@ -497,7 +497,7 @@ async def start_autopilot(
     project = client.Project.get(project_id)
 
     if not target:
-        return ToolError("Target variable must be specified")
+        raise ToolError("Target variable must be specified")
 
     try:
         # Start modeling
@@ -517,7 +517,7 @@ async def start_autopilot(
         )
 
     except Exception as e:
-        return ToolError(
+        raise ToolError(
             content=json.dumps(
                 {
                     "error": f"Failed to start Autopilot: {str(e)}",
@@ -546,9 +546,9 @@ async def get_model_roc_curve(
 ) -> ToolError | ToolResult:
     """Get detailed ROC curve for a specific model."""
     if not project_id:
-        return ToolError("Project ID must be provided")
+        raise ToolError("Project ID must be provided")
     if not model_id:
-        return ToolError("Model ID must be provided")
+        raise ToolError("Model ID must be provided")
 
     client = get_sdk_client()
     project = client.Project.get(project_id)
@@ -587,7 +587,7 @@ async def get_model_roc_curve(
             structured_content={"data": roc_data},
         )
     except Exception as e:
-        return ToolError(f"Failed to get ROC curve: {str(e)}")
+        raise ToolError(f"Failed to get ROC curve: {str(e)}")
 
 
 @dr_mcp_tool(tags={"predictive", "training", "read", "model", "evaluation"})
@@ -598,9 +598,9 @@ async def get_model_feature_impact(
 ) -> ToolError | ToolResult:
     """Get detailed feature impact for a specific model."""
     if not project_id:
-        return ToolError("Project ID must be provided")
+        raise ToolError("Project ID must be provided")
     if not model_id:
-        return ToolError("Model ID must be provided")
+        raise ToolError("Model ID must be provided")
 
     client = get_sdk_client()
     project = client.Project.get(project_id)
@@ -617,6 +617,7 @@ async def get_model_feature_impact(
 
 @dr_mcp_tool(tags={"predictive", "training", "read", "model", "evaluation"})
 async def get_model_lift_chart(
+    *,
     project_id: Annotated[str, "The ID of the DataRobot project"] | None = None,
     model_id: Annotated[str, "The ID of the model to analyze"] | None = None,
     source: Annotated[
@@ -630,9 +631,9 @@ async def get_model_lift_chart(
 ) -> ToolError | ToolResult:
     """Get detailed lift chart for a specific model."""
     if not project_id:
-        return ToolError("Project ID must be provided")
+        raise ToolError("Project ID must be provided")
     if not model_id:
-        return ToolError("Model ID must be provided")
+        raise ToolError("Model ID must be provided")
 
     client = get_sdk_client()
     project = client.Project.get(project_id)
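
The training.py hunks repeat the raise-instead-of-return fix and add a bare star to get_model_lift_chart so its parameters become keyword-only, matching the other tools in this release. The Annotated[str, "..."] metadata on each parameter is what presumably surfaces as argument descriptions in the generated tool schema. The signature convention, reduced to a sketch (hypothetical tool, decorator omitted):

from typing import Annotated

from fastmcp.exceptions import ToolError
from fastmcp.tools.tool import ToolResult


async def example_tool(
    *,  # bare star: every argument must be passed by keyword
    project_id: Annotated[str, "The DataRobot project ID"] | None = None,
) -> ToolResult:
    if not project_id:
        # Optional-with-check yields an explicit ToolError message rather
        # than a schema-level validation failure for a missing argument.
        raise ToolError("Project ID must be provided")
    return ToolResult(content=f"OK: {project_id}")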
@@ -12,22 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 from collections.abc import AsyncGenerator
+from typing import TYPE_CHECKING
 from typing import Any
 from typing import TypeVar
 
-from crewai import LLM
-from langchain_openai import ChatOpenAI
-from llama_index.core.base.llms.types import LLMMetadata
-from llama_index.llms.litellm import LiteLLM
 from nat.builder.builder import Builder
 from nat.builder.framework_enum import LLMFrameworkEnum
 from nat.cli.register_workflow import register_llm_client
 from nat.data_models.llm import LLMBaseConfig
 from nat.data_models.retry_mixin import RetryMixin
-from nat.plugins.langchain.llm import (
-    _patch_llm_based_on_config as langchain_patch_llm_based_on_config,
-)
 from nat.utils.exception_handlers.automatic_retries import patch_with_retry
 
 from ..nat.datarobot_llm_providers import DataRobotLLMComponentModelConfig
@@ -35,6 +31,11 @@ from ..nat.datarobot_llm_providers import DataRobotLLMDeploymentModelConfig
 from ..nat.datarobot_llm_providers import DataRobotLLMGatewayModelConfig
 from ..nat.datarobot_llm_providers import DataRobotNIMModelConfig
 
+if TYPE_CHECKING:
+    from crewai import LLM
+    from langchain_openai import ChatOpenAI
+    from llama_index.llms.litellm import LiteLLM
+
 ModelType = TypeVar("ModelType")
 
 
@@ -50,42 +51,53 @@ def _patch_llm_based_on_config(client: ModelType, llm_config: LLMBaseConfig) ->
     return client
 
 
-class DataRobotChatOpenAI(ChatOpenAI):
-    def _get_request_payload(
-        self,
-        *args: Any,
-        **kwargs: Any,
-    ) -> dict:
-        # We need to default to include_usage=True for streaming but we get 400 response
-        # if stream_options is present for a non-streaming call.
-        payload = super()._get_request_payload(*args, **kwargs)
-        if not payload.get("stream"):
-            payload.pop("stream_options", None)
-        return payload
-
-
-class DataRobotLiteLLM(LiteLLM):  # type: ignore[misc]
-    """DataRobotLiteLLM is a small LiteLLM wrapper class that makes all LiteLLM endpoints
-    compatible with the LlamaIndex library.
-    """
-
-    @property
-    def metadata(self) -> LLMMetadata:
-        """Returns the metadata for the LLM.
-
-        This is required to enable the is_chat_model and is_function_calling_model, which are
-        mandatory for LlamaIndex agents. By default, LlamaIndex assumes these are false unless each
-        individual model config in LiteLLM explicitly sets them to true. To use custom LLM
-        endpoints with LlamaIndex agents, you must override this method to return the appropriate
-        metadata.
+def _create_datarobot_chat_openai(config: dict[str, Any]) -> Any:
+    from langchain_openai import ChatOpenAI  # noqa: PLC0415
+
+    class DataRobotChatOpenAI(ChatOpenAI):
+        def _get_request_payload(  # type: ignore[override]
+            self,
+            *args: Any,
+            **kwargs: Any,
+        ) -> dict:
+            # We need to default to include_usage=True for streaming but we get 400 response
+            # if stream_options is present for a non-streaming call.
+            payload = super()._get_request_payload(*args, **kwargs)
+            if not payload.get("stream"):
+                payload.pop("stream_options", None)
+            return payload
+
+    return DataRobotChatOpenAI(**config)
+
+
+def _create_datarobot_litellm(config: dict[str, Any]) -> Any:
+    from llama_index.core.base.llms.types import LLMMetadata  # noqa: PLC0415
+    from llama_index.llms.litellm import LiteLLM  # noqa: PLC0415
+
+    class DataRobotLiteLLM(LiteLLM):  # type: ignore[misc]
+        """DataRobotLiteLLM is a small LiteLLM wrapper class that makes all LiteLLM endpoints
+        compatible with the LlamaIndex library.
         """
-        return LLMMetadata(
-            context_window=128000,
-            num_output=self.max_tokens or -1,
-            is_chat_model=True,
-            is_function_calling_model=True,
-            model_name=self.model,
-        )
+
+        @property
+        def metadata(self) -> LLMMetadata:
+            """Returns the metadata for the LLM.
+
+            This is required to enable the is_chat_model and is_function_calling_model, which are
+            mandatory for LlamaIndex agents. By default, LlamaIndex assumes these are false unless
+            each individual model config in LiteLLM explicitly sets them to true. To use custom LLM
+            endpoints with LlamaIndex agents, you must override this method to return the
+            appropriate metadata.
+            """
+            return LLMMetadata(
+                context_window=128000,
+                num_output=self.max_tokens or -1,
+                is_chat_model=True,
+                is_function_calling_model=True,
+                model_name=self.model,
+            )
+
+    return DataRobotLiteLLM(**config)
 
 
 @register_llm_client(
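
This hunk converts the two module-level subclasses into factory functions whose framework imports happen at call time, so importing datarobot_llm_clients.py no longer requires langchain_openai or llama_index to be installed. The shape of the pattern, stripped to its essentials (decimal stands in here for a heavy optional dependency; all names are illustrative):

from typing import Any


def _create_wrapped_client(config: dict[str, Any]) -> Any:
    # Deferred import: the dependency is resolved when the factory is
    # called, never at module import time.
    from decimal import Decimal  # noqa: PLC0415

    class WrappedValue(Decimal):
        """Subclass defined inside the factory, so the base class name does
        not need to exist when the defining module is imported."""

    return WrappedValue(**config)


print(_create_wrapped_client({"value": "1.5"}))  # module import stays cheap

The trade-off is that the wrapper class is rebuilt on every call and the factory's return type widens to Any; the TYPE_CHECKING block added above restores precise types for static analysis.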
@@ -94,11 +106,15 @@ class DataRobotLiteLLM(LiteLLM): # type: ignore[misc]
 async def datarobot_llm_gateway_langchain(
     llm_config: DataRobotLLMGatewayModelConfig, builder: Builder
 ) -> AsyncGenerator[ChatOpenAI]:
+    from nat.plugins.langchain.llm import (  # noqa: PLC0415
+        _patch_llm_based_on_config as langchain_patch_llm_based_on_config,
+    )
+
     config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
     config["base_url"] = config["base_url"] + "/genai/llmgw"
     config["stream_options"] = {"include_usage": True}
     config["model"] = config["model"].removeprefix("datarobot/")
-    client = DataRobotChatOpenAI(**config)
+    client = _create_datarobot_chat_openai(config)
     yield langchain_patch_llm_based_on_config(client, config)
 
 
@@ -108,6 +124,8 @@ async def datarobot_llm_gateway_langchain(
 async def datarobot_llm_gateway_crewai(
     llm_config: DataRobotLLMGatewayModelConfig, builder: Builder
 ) -> AsyncGenerator[LLM]:
+    from crewai import LLM  # noqa: PLC0415
+
     config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
     if not config["model"].startswith("datarobot/"):
         config["model"] = "datarobot/" + config["model"]
@@ -121,12 +139,12 @@ async def datarobot_llm_gateway_crewai(
 )
 async def datarobot_llm_gateway_llamaindex(
     llm_config: DataRobotLLMGatewayModelConfig, builder: Builder
-) -> AsyncGenerator[LLM]:
+) -> AsyncGenerator[LiteLLM]:
     config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
     if not config["model"].startswith("datarobot/"):
         config["model"] = "datarobot/" + config["model"]
     config["api_base"] = config.pop("base_url").removesuffix("/api/v2")
-    client = DataRobotLiteLLM(**config)
+    client = _create_datarobot_litellm(config)
     yield _patch_llm_based_on_config(client, config)
 
 
@@ -136,6 +154,10 @@ async def datarobot_llm_gateway_llamaindex(
 async def datarobot_llm_deployment_langchain(
     llm_config: DataRobotLLMDeploymentModelConfig, builder: Builder
 ) -> AsyncGenerator[ChatOpenAI]:
+    from nat.plugins.langchain.llm import (  # noqa: PLC0415
+        _patch_llm_based_on_config as langchain_patch_llm_based_on_config,
+    )
+
     config = llm_config.model_dump(
         exclude={"type", "thinking"},
         by_alias=True,
@@ -143,7 +165,7 @@ async def datarobot_llm_deployment_langchain(
     )
     config["stream_options"] = {"include_usage": True}
     config["model"] = config["model"].removeprefix("datarobot/")
-    client = DataRobotChatOpenAI(**config)
+    client = _create_datarobot_chat_openai(config)
     yield langchain_patch_llm_based_on_config(client, config)
 
 
@@ -153,6 +175,8 @@ async def datarobot_llm_deployment_langchain(
 async def datarobot_llm_deployment_crewai(
     llm_config: DataRobotLLMDeploymentModelConfig, builder: Builder
 ) -> AsyncGenerator[LLM]:
+    from crewai import LLM  # noqa: PLC0415
+
     config = llm_config.model_dump(
         exclude={"type", "thinking"},
         by_alias=True,
@@ -170,7 +194,7 @@ async def datarobot_llm_deployment_crewai(
 )
 async def datarobot_llm_deployment_llamaindex(
     llm_config: DataRobotLLMDeploymentModelConfig, builder: Builder
-) -> AsyncGenerator[LLM]:
+) -> AsyncGenerator[LiteLLM]:
     config = llm_config.model_dump(
         exclude={"type", "thinking"},
         by_alias=True,
@@ -179,7 +203,7 @@ async def datarobot_llm_deployment_llamaindex(
     if not config["model"].startswith("datarobot/"):
         config["model"] = "datarobot/" + config["model"]
     config["api_base"] = config.pop("base_url") + "/chat/completions"
-    client = DataRobotLiteLLM(**config)
+    client = _create_datarobot_litellm(config)
     yield _patch_llm_based_on_config(client, config)
 
 
@@ -187,6 +211,10 @@ async def datarobot_llm_deployment_llamaindex(
 async def datarobot_nim_langchain(
     llm_config: DataRobotNIMModelConfig, builder: Builder
 ) -> AsyncGenerator[ChatOpenAI]:
+    from nat.plugins.langchain.llm import (  # noqa: PLC0415
+        _patch_llm_based_on_config as langchain_patch_llm_based_on_config,
+    )
+
     config = llm_config.model_dump(
         exclude={"type", "thinking"},
         by_alias=True,
@@ -194,7 +222,7 @@ async def datarobot_nim_langchain(
     )
     config["stream_options"] = {"include_usage": True}
     config["model"] = config["model"].removeprefix("datarobot/")
-    client = DataRobotChatOpenAI(**config)
+    client = _create_datarobot_chat_openai(config)
     yield langchain_patch_llm_based_on_config(client, config)
 
 
@@ -202,6 +230,8 @@ async def datarobot_nim_langchain(
 async def datarobot_nim_crewai(
     llm_config: DataRobotNIMModelConfig, builder: Builder
 ) -> AsyncGenerator[LLM]:
+    from crewai import LLM  # noqa: PLC0415
+
     config = llm_config.model_dump(
         exclude={"type", "thinking", "max_retries"},
         by_alias=True,
@@ -217,7 +247,7 @@ async def datarobot_nim_crewai(
 @register_llm_client(config_type=DataRobotNIMModelConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX)
 async def datarobot_nim_llamaindex(
     llm_config: DataRobotNIMModelConfig, builder: Builder
-) -> AsyncGenerator[LLM]:
+) -> AsyncGenerator[LiteLLM]:
     config = llm_config.model_dump(
         exclude={"type", "thinking"},
         by_alias=True,
@@ -226,7 +256,7 @@ async def datarobot_nim_llamaindex(
     if not config["model"].startswith("datarobot/"):
         config["model"] = "datarobot/" + config["model"]
     config["api_base"] = config.pop("base_url") + "/chat/completions"
-    client = DataRobotLiteLLM(**config)
+    client = _create_datarobot_litellm(config)
     yield _patch_llm_based_on_config(client, config)
 
 
@@ -236,13 +266,17 @@ async def datarobot_nim_llamaindex(
 async def datarobot_llm_component_langchain(
     llm_config: DataRobotLLMComponentModelConfig, builder: Builder
 ) -> AsyncGenerator[ChatOpenAI]:
+    from nat.plugins.langchain.llm import (  # noqa: PLC0415
+        _patch_llm_based_on_config as langchain_patch_llm_based_on_config,
+    )
+
     config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
     if config["use_datarobot_llm_gateway"]:
         config["base_url"] = config["base_url"] + "/genai/llmgw"
     config["stream_options"] = {"include_usage": True}
     config["model"] = config["model"].removeprefix("datarobot/")
     config.pop("use_datarobot_llm_gateway")
-    client = DataRobotChatOpenAI(**config)
+    client = _create_datarobot_chat_openai(config)
     yield langchain_patch_llm_based_on_config(client, config)
 
 
@@ -252,6 +286,8 @@ async def datarobot_llm_component_langchain(
 async def datarobot_llm_component_crewai(
     llm_config: DataRobotLLMComponentModelConfig, builder: Builder
 ) -> AsyncGenerator[LLM]:
+    from crewai import LLM  # noqa: PLC0415
+
     config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
     if not config["model"].startswith("datarobot/"):
         config["model"] = "datarobot/" + config["model"]
@@ -269,7 +305,7 @@ async def datarobot_llm_component_crewai(
 )
 async def datarobot_llm_component_llamaindex(
     llm_config: DataRobotLLMComponentModelConfig, builder: Builder
-) -> AsyncGenerator[LLM]:
+) -> AsyncGenerator[LiteLLM]:
     config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
     if not config["model"].startswith("datarobot/"):
         config["model"] = "datarobot/" + config["model"]
@@ -278,5 +314,5 @@ async def datarobot_llm_component_llamaindex(
     else:
         config["api_base"] = config.pop("base_url") + "/chat/completions"
     config.pop("use_datarobot_llm_gateway")
-    client = DataRobotLiteLLM(**config)
+    client = _create_datarobot_litellm(config)
     yield _patch_llm_based_on_config(client, config)
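
Finally, the combination of from __future__ import annotations and the TYPE_CHECKING block is what lets the registration coroutines keep precise signatures (AsyncGenerator[ChatOpenAI], AsyncGenerator[LiteLLM]) while no longer importing crewai, langchain_openai, or llama_index at module load: annotations stay unevaluated strings at runtime and only need to resolve for static type checkers. A compact illustration of the idiom (the function and model string are illustrative; the import path is the one used in the diff):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never imported at runtime.
    from llama_index.llms.litellm import LiteLLM


def make_client() -> LiteLLM:
    # The return annotation above is an unevaluated string, so this module
    # imports fine even when llama_index is absent; the dependency is only
    # needed if the function is actually called.
    from llama_index.llms.litellm import LiteLLM  # noqa: PLC0415

    return LiteLLM(model="datarobot/example-model")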