datarobot-genai 0.2.11__py3-none-any.whl → 0.2.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,6 +14,7 @@
14
14
 
15
15
  import logging
16
16
  from typing import Annotated
17
+ from typing import Any
17
18
 
18
19
  from fastmcp.exceptions import ToolError
19
20
  from fastmcp.tools.tool import ToolResult
@@ -25,6 +26,40 @@ from datarobot_genai.drmcp.tools.clients.jira import JiraClient
25
26
  logger = logging.getLogger(__name__)
26
27
 
27
28
 
29
+ @dr_mcp_tool(tags={"jira", "search", "issues"})
30
+ async def jira_search_issues(
31
+ *,
32
+ jql_query: Annotated[
33
+ str, "The JQL (Jira Query Language) string used to filter and search for issues."
34
+ ],
35
+ max_results: Annotated[int, "Maximum number of issues to return. Default is 50."] = 50,
36
+ ) -> ToolResult:
37
+ """
38
+ Search for Jira issues using a powerful JQL query string.
39
+
40
+ Refer to JQL documentation for advanced query construction:
41
+ JQL functions: https://support.atlassian.com/jira-service-management-cloud/docs/jql-functions/
42
+ JQL fields: https://support.atlassian.com/jira-service-management-cloud/docs/jql-fields/
43
+ JQL keywords: https://support.atlassian.com/jira-service-management-cloud/docs/use-advanced-search-with-jira-query-language-jql/
44
+ JQL operators: https://support.atlassian.com/jira-service-management-cloud/docs/jql-operators/
45
+ """
46
+ if not jql_query:
47
+ raise ToolError("Argument validation error: 'jql_query' cannot be empty.")
48
+
49
+ access_token = await get_atlassian_access_token()
50
+ if isinstance(access_token, ToolError):
51
+ raise access_token
52
+
53
+ async with JiraClient(access_token) as client:
54
+ issues = await client.search_jira_issues(jql_query=jql_query, max_results=max_results)
55
+
56
+ n = len(issues)
57
+ return ToolResult(
58
+ content=f"Successfully executed JQL query and retrieved {n} issue(s).",
59
+ structured_content={"data": [issue.as_flat_dict() for issue in issues], "count": n},
60
+ )
61
+
62
+
28
63
  @dr_mcp_tool(tags={"jira", "read", "get", "issue"})
29
64
  async def jira_get_issue(
30
65
  *, issue_key: Annotated[str, "The key (ID) of the Jira issue to retrieve, e.g., 'PROJ-123'."]
@@ -41,7 +76,7 @@ async def jira_get_issue(
41
76
  async with JiraClient(access_token) as client:
42
77
  issue = await client.get_jira_issue(issue_key)
43
78
  except Exception as e:
44
- logger.error(f"Unexpected error getting Jira issue: {e}")
79
+ logger.error(f"Unexpected error while getting Jira issue: {e}")
45
80
  raise ToolError(
46
81
  f"An unexpected error occurred while getting Jira issue '{issue_key}': {str(e)}"
47
82
  )
@@ -50,3 +85,159 @@ async def jira_get_issue(
50
85
  content=f"Successfully retrieved details for issue '{issue_key}'.",
51
86
  structured_content=issue.as_flat_dict(),
52
87
  )
88
+
89
+
90
+ @dr_mcp_tool(tags={"jira", "create", "add", "issue"})
91
+ async def jira_create_issue(
92
+ *,
93
+ project_key: Annotated[str, "The key of the project where the issue should be created."],
94
+ summary: Annotated[str, "A brief summary or title for the new issue."],
95
+ issue_type: Annotated[str, "The type of issue to create (e.g., 'Task', 'Bug', 'Story')."],
96
+ description: Annotated[str | None, "Detailed description of the issue."] = None,
97
+ ) -> ToolResult:
98
+ """Create a new Jira issue with mandatory project, summary, and type information."""
99
+ if not all([project_key, summary, issue_type]):
100
+ raise ToolError(
101
+ "Argument validation error: project_key, summary, and issue_type are required fields."
102
+ )
103
+
104
+ access_token = await get_atlassian_access_token()
105
+ if isinstance(access_token, ToolError):
106
+ raise access_token
107
+
108
+ async with JiraClient(access_token) as client:
109
+ # Maybe we should cache it somehow?
110
+ # It'll be probably constant through whole mcp server lifecycle...
111
+ issue_types = await client.get_jira_issue_types(project_key=project_key)
112
+
113
+ try:
114
+ issue_type_id = issue_types[issue_type]
115
+ except KeyError:
116
+ possible_issue_types = ",".join(issue_types)
117
+ raise ToolError(
118
+ f"Unexpected issue type `{issue_type}`. Possible values are {possible_issue_types}."
119
+ )
120
+
121
+ try:
122
+ async with JiraClient(access_token) as client:
123
+ issue_key = await client.create_jira_issue(
124
+ project_key=project_key,
125
+ summary=summary,
126
+ issue_type_id=issue_type_id,
127
+ description=description,
128
+ )
129
+ except Exception as e:
130
+ logger.error(f"Unexpected error while creating Jira issue: {e}")
131
+ raise ToolError(f"An unexpected error occurred while creating Jira issue: {str(e)}")
132
+
133
+ return ToolResult(
134
+ content=f"Successfully created issue '{issue_key}'.",
135
+ structured_content={"newIssueKey": issue_key, "projectKey": project_key},
136
+ )
137
+
138
+
139
+ @dr_mcp_tool(tags={"jira", "update", "edit", "issue"})
140
+ async def jira_update_issue(
141
+ *,
142
+ issue_key: Annotated[str, "The key (ID) of the Jira issue to retrieve, e.g., 'PROJ-123'."],
143
+ fields_to_update: Annotated[
144
+ dict[str, Any],
145
+ "A dictionary of field names and their new values (e.g., {'summary': 'New content'}).",
146
+ ],
147
+ ) -> ToolResult:
148
+ """
149
+ Modify descriptive fields or custom fields on an existing Jira issue using its key.
150
+ If you want to update issue status you should use `jira_transition_issue` tool instead.
151
+
152
+ Some fields needs very specific schema to allow update.
153
+ You should follow jira rest api guidance.
154
+ Good example is description field:
155
+ "description": {
156
+ "type": "text",
157
+ "version": 1,
158
+ "text": [
159
+ {
160
+ "type": "paragraph",
161
+ "content": [
162
+ {
163
+ "type": "text",
164
+ "text": "[HERE YOU PUT REAL DESCRIPTION]"
165
+ }
166
+ ]
167
+ }
168
+ ]
169
+ }
170
+ """
171
+ if not issue_key:
172
+ raise ToolError("Argument validation error: 'issue_key' cannot be empty.")
173
+ if not fields_to_update or not isinstance(fields_to_update, dict):
174
+ raise ToolError(
175
+ "Argument validation error: 'fields_to_update' must be a non-empty dictionary."
176
+ )
177
+
178
+ access_token = await get_atlassian_access_token()
179
+ if isinstance(access_token, ToolError):
180
+ raise access_token
181
+
182
+ try:
183
+ async with JiraClient(access_token) as client:
184
+ updated_fields = await client.update_jira_issue(
185
+ issue_key=issue_key, fields=fields_to_update
186
+ )
187
+ except Exception as e:
188
+ logger.error(f"Unexpected error while updating Jira issue: {e}")
189
+ raise ToolError(f"An unexpected error occurred while updating Jira issue: {str(e)}")
190
+
191
+ updated_fields_str = ",".join(updated_fields)
192
+ return ToolResult(
193
+ content=f"Successfully updated issue '{issue_key}'. Fields modified: {updated_fields_str}.",
194
+ structured_content={"updatedIssueKey": issue_key, "fields": updated_fields},
195
+ )
196
+
197
+
198
+ @dr_mcp_tool(tags={"jira", "update", "transition", "issue"})
199
+ async def jira_transition_issue(
200
+ *,
201
+ issue_key: Annotated[str, "The key (ID) of the Jira issue to transition, e.g. 'PROJ-123'."],
202
+ transition_name: Annotated[
203
+ str, "The exact name of the target status/transition (e.g., 'In Progress')."
204
+ ],
205
+ ) -> ToolResult:
206
+ """
207
+ Move a Jira issue through its defined workflow to a new status.
208
+ This leverages Jira's workflow engine directly.
209
+ """
210
+ if not all([issue_key, transition_name]):
211
+ raise ToolError("Argument validation error: issue_key and transition name/ID are required.")
212
+
213
+ access_token = await get_atlassian_access_token()
214
+ if isinstance(access_token, ToolError):
215
+ raise access_token
216
+
217
+ async with JiraClient(access_token) as client:
218
+ available_transitions = await client.get_available_jira_transitions(issue_key=issue_key)
219
+
220
+ try:
221
+ transition_id = available_transitions[transition_name]
222
+ except KeyError:
223
+ available_transitions_str = ",".join(available_transitions)
224
+ raise ToolError(
225
+ f"Unexpected transition name `{transition_name}`. "
226
+ f"Possible values are {available_transitions_str}."
227
+ )
228
+
229
+ try:
230
+ async with JiraClient(access_token) as client:
231
+ await client.transition_jira_issue(issue_key=issue_key, transition_id=transition_id)
232
+ except Exception as e:
233
+ logger.error(f"Unexpected error while transitioning Jira issue: {e}")
234
+ raise ToolError(f"An unexpected error occurred while transitioning Jira issue: {str(e)}")
235
+
236
+ return ToolResult(
237
+ content=f"Successfully transitioned issue '{issue_key}' to status '{transition_name}'.",
238
+ structured_content={
239
+ "transitionedIssueKey": issue_key,
240
+ "newStatusName": transition_name,
241
+ "newStatusId": transition_id,
242
+ },
243
+ )
@@ -14,51 +14,79 @@
14
14
 
15
15
  import logging
16
16
  import os
17
+ from typing import Annotated
18
+
19
+ from fastmcp.exceptions import ToolError
20
+ from fastmcp.tools.tool import ToolResult
17
21
 
18
22
  from datarobot_genai.drmcp.core.clients import get_sdk_client
19
23
  from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool
24
+ from datarobot_genai.drmcp.core.utils import is_valid_url
20
25
 
21
26
  logger = logging.getLogger(__name__)
22
27
 
23
28
 
24
- @dr_mcp_tool(tags={"data", "management", "upload"})
25
- async def upload_dataset_to_ai_catalog(file_path: str) -> str:
26
- """
27
- Upload a dataset to the DataRobot AI Catalog.
28
-
29
- Args:
30
- file_path: Path to the file to upload.
29
+ @dr_mcp_tool(tags={"predictive", "data", "write", "upload", "catalog"})
30
+ async def upload_dataset_to_ai_catalog(
31
+ file_path: Annotated[str, "The path to the dataset file to upload."] | None = None,
32
+ file_url: Annotated[str, "The URL to the dataset file to upload."] | None = None,
33
+ ) -> ToolError | ToolResult:
34
+ """Upload a dataset to the DataRobot AI Catalog / Data Registry."""
35
+ if not file_path and not file_url:
36
+ return ToolError("Either file_path or file_url must be provided.")
37
+ if file_path and file_url:
38
+ return ToolError("Please provide either file_path or file_url, not both.")
31
39
 
32
- Returns
33
- -------
34
- A string summary of the upload result.
35
- """
40
+ # Get client
36
41
  client = get_sdk_client()
37
- if not os.path.exists(file_path):
38
- logger.error(f"File not found: {file_path}")
39
- return f"File not found: {file_path}"
40
- catalog_item = client.Dataset.create_from_file(file_path)
41
- logger.info(f"Successfully uploaded dataset: {catalog_item.id}")
42
- return f"AI Catalog ID: {catalog_item.id}"
43
-
44
-
45
- @dr_mcp_tool(tags={"data", "management", "list"})
46
- async def list_ai_catalog_items() -> str:
47
- """
48
- List all AI Catalog items (datasets) for the authenticated user.
49
-
50
- Returns
51
- -------
52
- A string summary of the AI Catalog items with their IDs and names.
53
- """
42
+ catalog_item = None
43
+ # If file path is provided, create dataset from file.
44
+ if file_path:
45
+ # Does file exist?
46
+ if not os.path.exists(file_path):
47
+ logger.error("File not found: %s", file_path)
48
+ return ToolError(f"File not found: {file_path}")
49
+ catalog_item = client.Dataset.create_from_file(file_path)
50
+ else:
51
+ # Does URL exist?
52
+ if file_url is None or not is_valid_url(file_url):
53
+ logger.error("Invalid file URL: %s", file_url)
54
+ return ToolError(f"Invalid file URL: {file_url}")
55
+ catalog_item = client.Dataset.create_from_url(file_url)
56
+
57
+ if not catalog_item:
58
+ return ToolError("Failed to upload dataset.")
59
+
60
+ return ToolResult(
61
+ content=f"Successfully uploaded dataset: {catalog_item.id}",
62
+ structured_content={
63
+ "dataset_id": catalog_item.id,
64
+ "dataset_version_id": catalog_item.version_id,
65
+ "dataset_name": catalog_item.name,
66
+ },
67
+ )
68
+
69
+
70
+ @dr_mcp_tool(tags={"predictive", "data", "read", "list", "catalog"})
71
+ async def list_ai_catalog_items() -> ToolResult:
72
+ """List all AI Catalog items (datasets) for the authenticated user."""
54
73
  client = get_sdk_client()
55
74
  datasets = client.Dataset.list()
75
+
56
76
  if not datasets:
57
77
  logger.info("No AI Catalog items found")
58
- return "No AI Catalog items found."
59
- result = "\n".join(f"{ds.id}: {ds.name}" for ds in datasets)
60
- logger.info(f"Found {len(datasets)} AI Catalog items")
61
- return result
78
+ return ToolResult(
79
+ content="No AI Catalog items found.",
80
+ structured_content={"datasets": []},
81
+ )
82
+
83
+ return ToolResult(
84
+ content=f"Found {len(datasets)} AI Catalog items.",
85
+ structured_content={
86
+ "datasets": [{"id": ds.id, "name": ds.name} for ds in datasets],
87
+ "count": len(datasets),
88
+ },
89
+ )
62
90
 
63
91
 
64
92
  # from fastmcp import Context
@@ -21,7 +21,6 @@ from nat.data_models.api_server import ChatRequest
21
21
  from nat.data_models.api_server import ChatResponse
22
22
  from nat.data_models.intermediate_step import IntermediateStep
23
23
  from nat.data_models.intermediate_step import IntermediateStepType
24
- from nat.runtime.loader import load_workflow
25
24
  from nat.utils.type_utils import StrPath
26
25
  from openai.types.chat import CompletionCreateParams
27
26
  from ragas import MultiTurnSample
@@ -34,6 +33,8 @@ from datarobot_genai.core.agents.base import InvokeReturn
34
33
  from datarobot_genai.core.agents.base import UsageMetrics
35
34
  from datarobot_genai.core.agents.base import extract_user_prompt_content
36
35
  from datarobot_genai.core.agents.base import is_streaming
36
+ from datarobot_genai.core.mcp.common import MCPConfig
37
+ from datarobot_genai.nat.helpers import load_workflow
37
38
 
38
39
  logger = logging.getLogger(__name__)
39
40
 
@@ -166,17 +167,24 @@ class NatAgent(BaseAgent[None]):
166
167
  # Print commands may need flush=True to ensure they are displayed in real-time.
167
168
  print("Running agent with user prompt:", chat_request.messages[0].content, flush=True)
168
169
 
170
+ mcp_config = MCPConfig(
171
+ authorization_context=self.authorization_context,
172
+ forwarded_headers=self.forwarded_headers,
173
+ )
174
+ server_config = mcp_config.server_config
175
+ headers = server_config["headers"] if server_config else None
176
+
169
177
  if is_streaming(completion_create_params):
170
178
 
171
179
  async def stream_generator() -> AsyncGenerator[
172
180
  tuple[str, MultiTurnSample | None, UsageMetrics], None
173
181
  ]:
174
- usage_metrics: UsageMetrics = {
182
+ default_usage_metrics: UsageMetrics = {
175
183
  "completion_tokens": 0,
176
184
  "prompt_tokens": 0,
177
185
  "total_tokens": 0,
178
186
  }
179
- async with load_workflow(self.workflow_path) as workflow:
187
+ async with load_workflow(self.workflow_path, headers=headers) as workflow:
180
188
  async with workflow.run(chat_request) as runner:
181
189
  intermediate_future = pull_intermediate_structured()
182
190
  async for result in runner.result_stream():
@@ -188,7 +196,7 @@ class NatAgent(BaseAgent[None]):
188
196
  yield (
189
197
  result_text,
190
198
  None,
191
- usage_metrics,
199
+ default_usage_metrics,
192
200
  )
193
201
 
194
202
  steps = await intermediate_future
@@ -197,6 +205,11 @@ class NatAgent(BaseAgent[None]):
197
205
  for step in steps
198
206
  if step.event_type == IntermediateStepType.LLM_END
199
207
  ]
208
+ usage_metrics: UsageMetrics = {
209
+ "completion_tokens": 0,
210
+ "prompt_tokens": 0,
211
+ "total_tokens": 0,
212
+ }
200
213
  for step in llm_end_steps:
201
214
  if step.usage_info:
202
215
  token_usage = step.usage_info.token_usage
@@ -210,7 +223,7 @@ class NatAgent(BaseAgent[None]):
210
223
  return stream_generator()
211
224
 
212
225
  # Create and invoke the NAT (Nemo Agent Toolkit) Agentic Workflow with the inputs
213
- result, steps = await self.run_nat_workflow(self.workflow_path, chat_request)
226
+ result, steps = await self.run_nat_workflow(self.workflow_path, chat_request, headers)
214
227
 
215
228
  llm_end_steps = [step for step in steps if step.event_type == IntermediateStepType.LLM_END]
216
229
  usage_metrics: UsageMetrics = {
@@ -234,7 +247,7 @@ class NatAgent(BaseAgent[None]):
234
247
  return result_text, pipeline_interactions, usage_metrics
235
248
 
236
249
  async def run_nat_workflow(
237
- self, workflow_path: StrPath, chat_request: ChatRequest
250
+ self, workflow_path: StrPath, chat_request: ChatRequest, headers: dict[str, str] | None
238
251
  ) -> tuple[ChatResponse | str, list[IntermediateStep]]:
239
252
  """Run the NAT workflow with the provided config file and input string.
240
253
 
@@ -247,7 +260,7 @@ class NatAgent(BaseAgent[None]):
247
260
  ChatResponse | str: The result from the NAT workflow
248
261
  list[IntermediateStep]: The list of intermediate steps
249
262
  """
250
- async with load_workflow(workflow_path) as workflow:
263
+ async with load_workflow(workflow_path, headers=headers) as workflow:
251
264
  async with workflow.run(chat_request) as runner:
252
265
  intermediate_future = pull_intermediate_structured()
253
266
  runner_outputs = await runner.result()
@@ -14,6 +14,7 @@
14
14
 
15
15
  from collections.abc import AsyncGenerator
16
16
  from typing import Any
17
+ from typing import TypeVar
17
18
 
18
19
  from crewai import LLM
19
20
  from langchain_openai import ChatOpenAI
@@ -22,12 +23,32 @@ from llama_index.llms.litellm import LiteLLM
22
23
  from nat.builder.builder import Builder
23
24
  from nat.builder.framework_enum import LLMFrameworkEnum
24
25
  from nat.cli.register_workflow import register_llm_client
26
+ from nat.data_models.llm import LLMBaseConfig
27
+ from nat.data_models.retry_mixin import RetryMixin
28
+ from nat.plugins.langchain.llm import (
29
+ _patch_llm_based_on_config as langchain_patch_llm_based_on_config,
30
+ )
31
+ from nat.utils.exception_handlers.automatic_retries import patch_with_retry
25
32
 
26
33
  from ..nat.datarobot_llm_providers import DataRobotLLMComponentModelConfig
27
34
  from ..nat.datarobot_llm_providers import DataRobotLLMDeploymentModelConfig
28
35
  from ..nat.datarobot_llm_providers import DataRobotLLMGatewayModelConfig
29
36
  from ..nat.datarobot_llm_providers import DataRobotNIMModelConfig
30
37
 
38
+ ModelType = TypeVar("ModelType")
39
+
40
+
41
+ def _patch_llm_based_on_config(client: ModelType, llm_config: LLMBaseConfig) -> ModelType:
42
+ if isinstance(llm_config, RetryMixin):
43
+ client = patch_with_retry(
44
+ client,
45
+ retries=llm_config.num_retries,
46
+ retry_codes=llm_config.retry_on_status_codes,
47
+ retry_on_messages=llm_config.retry_on_errors,
48
+ )
49
+
50
+ return client
51
+
31
52
 
32
53
  class DataRobotChatOpenAI(ChatOpenAI):
33
54
  def _get_request_payload(
@@ -77,7 +98,8 @@ async def datarobot_llm_gateway_langchain(
77
98
  config["base_url"] = config["base_url"] + "/genai/llmgw"
78
99
  config["stream_options"] = {"include_usage": True}
79
100
  config["model"] = config["model"].removeprefix("datarobot/")
80
- yield DataRobotChatOpenAI(**config)
101
+ client = DataRobotChatOpenAI(**config)
102
+ yield langchain_patch_llm_based_on_config(client, config)
81
103
 
82
104
 
83
105
  @register_llm_client(
@@ -90,7 +112,8 @@ async def datarobot_llm_gateway_crewai(
90
112
  if not config["model"].startswith("datarobot/"):
91
113
  config["model"] = "datarobot/" + config["model"]
92
114
  config["base_url"] = config["base_url"].removesuffix("/api/v2")
93
- yield LLM(**config)
115
+ client = LLM(**config)
116
+ yield _patch_llm_based_on_config(client, config)
94
117
 
95
118
 
96
119
  @register_llm_client(
@@ -103,7 +126,8 @@ async def datarobot_llm_gateway_llamaindex(
103
126
  if not config["model"].startswith("datarobot/"):
104
127
  config["model"] = "datarobot/" + config["model"]
105
128
  config["api_base"] = config.pop("base_url").removesuffix("/api/v2")
106
- yield DataRobotLiteLLM(**config)
129
+ client = DataRobotLiteLLM(**config)
130
+ yield _patch_llm_based_on_config(client, config)
107
131
 
108
132
 
109
133
  @register_llm_client(
@@ -119,7 +143,8 @@ async def datarobot_llm_deployment_langchain(
119
143
  )
120
144
  config["stream_options"] = {"include_usage": True}
121
145
  config["model"] = config["model"].removeprefix("datarobot/")
122
- yield DataRobotChatOpenAI(**config)
146
+ client = DataRobotChatOpenAI(**config)
147
+ yield langchain_patch_llm_based_on_config(client, config)
123
148
 
124
149
 
125
150
  @register_llm_client(
@@ -136,7 +161,8 @@ async def datarobot_llm_deployment_crewai(
136
161
  if not config["model"].startswith("datarobot/"):
137
162
  config["model"] = "datarobot/" + config["model"]
138
163
  config["api_base"] = config.pop("base_url") + "/chat/completions"
139
- yield LLM(**config)
164
+ client = LLM(**config)
165
+ yield _patch_llm_based_on_config(client, config)
140
166
 
141
167
 
142
168
  @register_llm_client(
@@ -153,7 +179,8 @@ async def datarobot_llm_deployment_llamaindex(
153
179
  if not config["model"].startswith("datarobot/"):
154
180
  config["model"] = "datarobot/" + config["model"]
155
181
  config["api_base"] = config.pop("base_url") + "/chat/completions"
156
- yield DataRobotLiteLLM(**config)
182
+ client = DataRobotLiteLLM(**config)
183
+ yield _patch_llm_based_on_config(client, config)
157
184
 
158
185
 
159
186
  @register_llm_client(config_type=DataRobotNIMModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
@@ -167,7 +194,8 @@ async def datarobot_nim_langchain(
167
194
  )
168
195
  config["stream_options"] = {"include_usage": True}
169
196
  config["model"] = config["model"].removeprefix("datarobot/")
170
- yield DataRobotChatOpenAI(**config)
197
+ client = DataRobotChatOpenAI(**config)
198
+ yield langchain_patch_llm_based_on_config(client, config)
171
199
 
172
200
 
173
201
  @register_llm_client(config_type=DataRobotNIMModelConfig, wrapper_type=LLMFrameworkEnum.CREWAI)
@@ -182,7 +210,8 @@ async def datarobot_nim_crewai(
182
210
  if not config["model"].startswith("datarobot/"):
183
211
  config["model"] = "datarobot/" + config["model"]
184
212
  config["api_base"] = config.pop("base_url") + "/chat/completions"
185
- yield LLM(**config)
213
+ client = LLM(**config)
214
+ yield _patch_llm_based_on_config(client, config)
186
215
 
187
216
 
188
217
  @register_llm_client(config_type=DataRobotNIMModelConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX)
@@ -197,7 +226,8 @@ async def datarobot_nim_llamaindex(
197
226
  if not config["model"].startswith("datarobot/"):
198
227
  config["model"] = "datarobot/" + config["model"]
199
228
  config["api_base"] = config.pop("base_url") + "/chat/completions"
200
- yield DataRobotLiteLLM(**config)
229
+ client = DataRobotLiteLLM(**config)
230
+ yield _patch_llm_based_on_config(client, config)
201
231
 
202
232
 
203
233
  @register_llm_client(
@@ -212,7 +242,8 @@ async def datarobot_llm_component_langchain(
212
242
  config["stream_options"] = {"include_usage": True}
213
243
  config["model"] = config["model"].removeprefix("datarobot/")
214
244
  config.pop("use_datarobot_llm_gateway")
215
- yield DataRobotChatOpenAI(**config)
245
+ client = DataRobotChatOpenAI(**config)
246
+ yield langchain_patch_llm_based_on_config(client, config)
216
247
 
217
248
 
218
249
  @register_llm_client(
@@ -229,7 +260,8 @@ async def datarobot_llm_component_crewai(
229
260
  else:
230
261
  config["api_base"] = config.pop("base_url") + "/chat/completions"
231
262
  config.pop("use_datarobot_llm_gateway")
232
- yield LLM(**config)
263
+ client = LLM(**config)
264
+ yield _patch_llm_based_on_config(client, config)
233
265
 
234
266
 
235
267
  @register_llm_client(
@@ -246,4 +278,5 @@ async def datarobot_llm_component_llamaindex(
246
278
  else:
247
279
  config["api_base"] = config.pop("base_url") + "/chat/completions"
248
280
  config.pop("use_datarobot_llm_gateway")
249
- yield DataRobotLiteLLM(**config)
281
+ client = DataRobotLiteLLM(**config)
282
+ yield _patch_llm_based_on_config(client, config)
@@ -0,0 +1,87 @@
1
+ # Copyright 2025 DataRobot, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from collections.abc import AsyncGenerator
16
+ from contextlib import asynccontextmanager
17
+
18
+ from nat.builder.workflow import Workflow
19
+ from nat.builder.workflow_builder import WorkflowBuilder
20
+ from nat.data_models.config import Config
21
+ from nat.runtime.loader import PluginTypes
22
+ from nat.runtime.loader import discover_and_register_plugins
23
+ from nat.runtime.session import SessionManager
24
+ from nat.utils.data_models.schema_validator import validate_schema
25
+ from nat.utils.io.yaml_tools import yaml_load
26
+ from nat.utils.type_utils import StrPath
27
+
28
+
29
+ def load_config(config_file: StrPath, headers: dict[str, str] | None = None) -> Config:
30
+ """
31
+ Load a NAT configuration file with injected headers. It ensures that all plugins are
32
+ loaded and then validates the configuration file against the Config schema.
33
+
34
+ Parameters
35
+ ----------
36
+ config_file : StrPath
37
+ The path to the configuration file
38
+
39
+ Returns
40
+ -------
41
+ Config
42
+ The validated Config object
43
+ """
44
+ # Ensure all of the plugins are loaded
45
+ discover_and_register_plugins(PluginTypes.CONFIG_OBJECT)
46
+
47
+ config_yaml = yaml_load(config_file)
48
+
49
+ add_headers_to_datarobot_mcp_auth(config_yaml, headers)
50
+
51
+ # Validate configuration adheres to NAT schemas
52
+ validated_nat_config = validate_schema(config_yaml, Config)
53
+
54
+ return validated_nat_config
55
+
56
+
57
+ def add_headers_to_datarobot_mcp_auth(config_yaml: dict, headers: dict[str, str] | None) -> None:
58
+ if headers:
59
+ if authentication := config_yaml.get("authentication"):
60
+ for auth_name in authentication:
61
+ auth_config = authentication[auth_name]
62
+ if auth_config.get("_type") == "datarobot_mcp_auth":
63
+ auth_config["headers"] = headers
64
+
65
+
66
+ @asynccontextmanager
67
+ async def load_workflow(
68
+ config_file: StrPath, max_concurrency: int = -1, headers: dict[str, str] | None = None
69
+ ) -> AsyncGenerator[Workflow, None]:
70
+ """
71
+ Load the NAT configuration file and create a Runner object. This is the primary entry point for
72
+ running NAT workflows with injected headers.
73
+
74
+ Parameters
75
+ ----------
76
+ config_file : StrPath
77
+ The path to the configuration file
78
+ max_concurrency : int, optional
79
+ The maximum number of parallel workflow invocations to support. Specifying 0 or -1 will
80
+ allow an unlimited count, by default -1
81
+ """
82
+ # Load the config object
83
+ config = load_config(config_file, headers=headers)
84
+
85
+ # Must yield the workflow function otherwise it cleans up
86
+ async with WorkflowBuilder.from_config(config=config) as workflow:
87
+ yield SessionManager(await workflow.build(), max_concurrency=max_concurrency)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: datarobot-genai
3
- Version: 0.2.11
3
+ Version: 0.2.19
4
4
  Summary: Generic helpers for GenAI
5
5
  Project-URL: Homepage, https://github.com/datarobot-oss/datarobot-genai
6
6
  Author: DataRobot, Inc.