datarobot-genai 0.2.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_genai/__init__.py +19 -0
- datarobot_genai/core/__init__.py +0 -0
- datarobot_genai/core/agents/__init__.py +43 -0
- datarobot_genai/core/agents/base.py +195 -0
- datarobot_genai/core/chat/__init__.py +19 -0
- datarobot_genai/core/chat/auth.py +146 -0
- datarobot_genai/core/chat/client.py +178 -0
- datarobot_genai/core/chat/responses.py +297 -0
- datarobot_genai/core/cli/__init__.py +18 -0
- datarobot_genai/core/cli/agent_environment.py +47 -0
- datarobot_genai/core/cli/agent_kernel.py +211 -0
- datarobot_genai/core/custom_model.py +141 -0
- datarobot_genai/core/mcp/__init__.py +0 -0
- datarobot_genai/core/mcp/common.py +218 -0
- datarobot_genai/core/telemetry_agent.py +126 -0
- datarobot_genai/core/utils/__init__.py +3 -0
- datarobot_genai/core/utils/auth.py +234 -0
- datarobot_genai/core/utils/urls.py +64 -0
- datarobot_genai/crewai/__init__.py +24 -0
- datarobot_genai/crewai/agent.py +42 -0
- datarobot_genai/crewai/base.py +159 -0
- datarobot_genai/crewai/events.py +117 -0
- datarobot_genai/crewai/mcp.py +59 -0
- datarobot_genai/drmcp/__init__.py +78 -0
- datarobot_genai/drmcp/core/__init__.py +13 -0
- datarobot_genai/drmcp/core/auth.py +165 -0
- datarobot_genai/drmcp/core/clients.py +180 -0
- datarobot_genai/drmcp/core/config.py +364 -0
- datarobot_genai/drmcp/core/config_utils.py +174 -0
- datarobot_genai/drmcp/core/constants.py +18 -0
- datarobot_genai/drmcp/core/credentials.py +190 -0
- datarobot_genai/drmcp/core/dr_mcp_server.py +350 -0
- datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
- datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
- datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
- datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +70 -0
- datarobot_genai/drmcp/core/dynamic_prompts/register.py +205 -0
- datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
- datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
- datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
- datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
- datarobot_genai/drmcp/core/exceptions.py +25 -0
- datarobot_genai/drmcp/core/logging.py +98 -0
- datarobot_genai/drmcp/core/mcp_instance.py +515 -0
- datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
- datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
- datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
- datarobot_genai/drmcp/core/routes.py +439 -0
- datarobot_genai/drmcp/core/routes_utils.py +30 -0
- datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
- datarobot_genai/drmcp/core/telemetry.py +424 -0
- datarobot_genai/drmcp/core/tool_config.py +111 -0
- datarobot_genai/drmcp/core/tool_filter.py +117 -0
- datarobot_genai/drmcp/core/utils.py +138 -0
- datarobot_genai/drmcp/server.py +19 -0
- datarobot_genai/drmcp/test_utils/__init__.py +13 -0
- datarobot_genai/drmcp/test_utils/clients/__init__.py +0 -0
- datarobot_genai/drmcp/test_utils/clients/anthropic.py +68 -0
- datarobot_genai/drmcp/test_utils/clients/base.py +300 -0
- datarobot_genai/drmcp/test_utils/clients/dr_gateway.py +58 -0
- datarobot_genai/drmcp/test_utils/clients/openai.py +68 -0
- datarobot_genai/drmcp/test_utils/elicitation_test_tool.py +89 -0
- datarobot_genai/drmcp/test_utils/integration_mcp_server.py +109 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +133 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +107 -0
- datarobot_genai/drmcp/test_utils/test_interactive.py +205 -0
- datarobot_genai/drmcp/test_utils/tool_base_ete.py +220 -0
- datarobot_genai/drmcp/test_utils/utils.py +91 -0
- datarobot_genai/drmcp/tools/__init__.py +14 -0
- datarobot_genai/drmcp/tools/clients/__init__.py +14 -0
- datarobot_genai/drmcp/tools/clients/atlassian.py +188 -0
- datarobot_genai/drmcp/tools/clients/confluence.py +584 -0
- datarobot_genai/drmcp/tools/clients/gdrive.py +832 -0
- datarobot_genai/drmcp/tools/clients/jira.py +334 -0
- datarobot_genai/drmcp/tools/clients/microsoft_graph.py +479 -0
- datarobot_genai/drmcp/tools/clients/s3.py +28 -0
- datarobot_genai/drmcp/tools/confluence/__init__.py +14 -0
- datarobot_genai/drmcp/tools/confluence/tools.py +321 -0
- datarobot_genai/drmcp/tools/gdrive/__init__.py +0 -0
- datarobot_genai/drmcp/tools/gdrive/tools.py +347 -0
- datarobot_genai/drmcp/tools/jira/__init__.py +14 -0
- datarobot_genai/drmcp/tools/jira/tools.py +243 -0
- datarobot_genai/drmcp/tools/microsoft_graph/__init__.py +13 -0
- datarobot_genai/drmcp/tools/microsoft_graph/tools.py +198 -0
- datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
- datarobot_genai/drmcp/tools/predictive/data.py +133 -0
- datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
- datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
- datarobot_genai/drmcp/tools/predictive/model.py +148 -0
- datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
- datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
- datarobot_genai/drmcp/tools/predictive/project.py +90 -0
- datarobot_genai/drmcp/tools/predictive/training.py +661 -0
- datarobot_genai/langgraph/__init__.py +0 -0
- datarobot_genai/langgraph/agent.py +341 -0
- datarobot_genai/langgraph/mcp.py +73 -0
- datarobot_genai/llama_index/__init__.py +16 -0
- datarobot_genai/llama_index/agent.py +50 -0
- datarobot_genai/llama_index/base.py +299 -0
- datarobot_genai/llama_index/mcp.py +79 -0
- datarobot_genai/nat/__init__.py +0 -0
- datarobot_genai/nat/agent.py +275 -0
- datarobot_genai/nat/datarobot_auth_provider.py +110 -0
- datarobot_genai/nat/datarobot_llm_clients.py +318 -0
- datarobot_genai/nat/datarobot_llm_providers.py +130 -0
- datarobot_genai/nat/datarobot_mcp_client.py +266 -0
- datarobot_genai/nat/helpers.py +87 -0
- datarobot_genai/py.typed +0 -0
- datarobot_genai-0.2.31.dist-info/METADATA +145 -0
- datarobot_genai-0.2.31.dist-info/RECORD +125 -0
- datarobot_genai-0.2.31.dist-info/WHEEL +4 -0
- datarobot_genai-0.2.31.dist-info/entry_points.txt +5 -0
- datarobot_genai-0.2.31.dist-info/licenses/AUTHORS +2 -0
- datarobot_genai-0.2.31.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc. and its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
from collections.abc import AsyncGenerator
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
from datarobot.core.config import DataRobotAppFrameworkBaseSettings
|
|
18
|
+
from nat.authentication.api_key.api_key_auth_provider import APIKeyAuthProvider
|
|
19
|
+
from nat.authentication.api_key.api_key_auth_provider_config import APIKeyAuthProviderConfig
|
|
20
|
+
from nat.authentication.interfaces import AuthProviderBase
|
|
21
|
+
from nat.builder.builder import Builder
|
|
22
|
+
from nat.cli.register_workflow import register_auth_provider
|
|
23
|
+
from nat.data_models.authentication import AuthProviderBaseConfig
|
|
24
|
+
from nat.data_models.authentication import AuthResult
|
|
25
|
+
from nat.data_models.authentication import HeaderCred
|
|
26
|
+
from pydantic import Field
|
|
27
|
+
|
|
28
|
+
from datarobot_genai.core.mcp.common import MCPConfig
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class Config(DataRobotAppFrameworkBaseSettings):
    """Settings for the DataRobot auth providers in this module.

    Finds variables in the priority order of: env
    variables (including Runtime Parameters), .env, file_secrets, then
    Pulumi output variables.
    """

    # DataRobot API token used as the default raw API-key credential;
    # None when no token is configured in the environment.
    datarobot_api_token: str | None = None
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Module-level settings instance, resolved once at import time and used as
# the source of Field defaults below.
config = Config()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class DataRobotAPIKeyAuthProviderConfig(APIKeyAuthProviderConfig, name="datarobot_api_key"):  # type: ignore[call-arg]
    """API-key auth provider config registered as ``datarobot_api_key``.

    Sources its default credential from the module-level ``Config`` above.
    """

    # Defaults to the DataRobot API token resolved by Config (may be None
    # when no token is configured).
    raw_key: str = Field(
        description=(
            "Raw API token or credential to be injected into the request parameter. "
            "Used for 'bearer','x-api-key','custom', and other schemes. "
        ),
        default=config.datarobot_api_token,
    )
    # User id assumed when the caller does not supply one.
    default_user_id: str | None = Field(default="default-user", description="Default user ID")
    allow_default_user_id_for_tool_calls: bool = Field(
        default=True, description="Allow default user ID for tool calls"
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@register_auth_provider(config_type=DataRobotAPIKeyAuthProviderConfig)
async def datarobot_api_key_client(
    config: DataRobotAPIKeyAuthProviderConfig, builder: Builder
) -> AsyncGenerator[APIKeyAuthProvider]:
    """Yield an ``APIKeyAuthProvider`` built from the DataRobot API-key config."""
    provider = APIKeyAuthProvider(config=config)
    yield provider
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# MCP server connection settings resolved at import time; None when no MCP
# server is configured, which makes the header default below None as well.
mcp_config = MCPConfig().server_config
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class DataRobotMCPAuthProviderConfig(AuthProviderBaseConfig, name="datarobot_mcp_auth"):  # type: ignore[call-arg]
    """Auth provider config registered as ``datarobot_mcp_auth``.

    Carries pre-built HTTP headers that are forwarded as credentials.
    """

    # Header name/value pairs sent as credentials; defaults to the headers
    # from the MCP server config when one exists, otherwise None.
    headers: dict[str, str] | None = Field(
        description=("Headers to be used for authentication. "),
        default=mcp_config["headers"] if mcp_config else None,
    )
    # User id assumed when the caller does not supply one.
    default_user_id: str | None = Field(default="default-user", description="Default user ID")
    allow_default_user_id_for_tool_calls: bool = Field(
        default=True, description="Allow default user ID for tool calls"
    )
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class DataRobotMCPAuthProvider(AuthProviderBase[DataRobotMCPAuthProviderConfig]):
    """Auth provider that presents the configured headers as credentials."""

    def __init__(
        self, config: DataRobotMCPAuthProviderConfig, config_name: str | None = None
    ) -> None:
        # Raise instead of assert: asserts are stripped under ``python -O``,
        # which would silently skip this validation.
        if not isinstance(config, DataRobotMCPAuthProviderConfig):
            raise TypeError("Config is not DataRobotMCPAuthProviderConfig")
        super().__init__(config)

    async def authenticate(self, user_id: str | None = None, **kwargs: Any) -> AuthResult | None:
        """
        Authenticate the user using the configured header credentials.

        Args:
            user_id (str): The user ID to authenticate.

        Returns
        -------
        AuthResult: One ``HeaderCred`` per configured header (empty when no
        headers are configured).
        """
        # ``headers`` is Optional with a None default; treat a missing mapping
        # as "no credentials" instead of crashing on ``None.items()``.
        headers = self.config.headers or {}
        return AuthResult(
            credentials=[HeaderCred(name=name, value=value) for name, value in headers.items()]
        )
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@register_auth_provider(config_type=DataRobotMCPAuthProviderConfig)
async def datarobot_mcp_auth_provider(
    config: DataRobotMCPAuthProviderConfig, builder: Builder
) -> AsyncGenerator[DataRobotMCPAuthProvider]:
    """Yield a ``DataRobotMCPAuthProvider`` for the given MCP auth config."""
    provider = DataRobotMCPAuthProvider(config=config)
    yield provider
|
|
@@ -0,0 +1,318 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc. and its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from collections.abc import AsyncGenerator
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
19
|
+
from typing import Any
|
|
20
|
+
from typing import TypeVar
|
|
21
|
+
|
|
22
|
+
from nat.builder.builder import Builder
|
|
23
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
24
|
+
from nat.cli.register_workflow import register_llm_client
|
|
25
|
+
from nat.data_models.llm import LLMBaseConfig
|
|
26
|
+
from nat.data_models.retry_mixin import RetryMixin
|
|
27
|
+
from nat.utils.exception_handlers.automatic_retries import patch_with_retry
|
|
28
|
+
|
|
29
|
+
from ..nat.datarobot_llm_providers import DataRobotLLMComponentModelConfig
|
|
30
|
+
from ..nat.datarobot_llm_providers import DataRobotLLMDeploymentModelConfig
|
|
31
|
+
from ..nat.datarobot_llm_providers import DataRobotLLMGatewayModelConfig
|
|
32
|
+
from ..nat.datarobot_llm_providers import DataRobotNIMModelConfig
|
|
33
|
+
|
|
34
|
+
if TYPE_CHECKING:
|
|
35
|
+
from crewai import LLM
|
|
36
|
+
from langchain_openai import ChatOpenAI
|
|
37
|
+
from llama_index.llms.litellm import LiteLLM
|
|
38
|
+
|
|
39
|
+
ModelType = TypeVar("ModelType")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _patch_llm_based_on_config(client: ModelType, llm_config: LLMBaseConfig) -> ModelType:
    """Wrap ``client`` with automatic retries when the config opts into them.

    Configs that do not mix in ``RetryMixin`` are returned unchanged.
    """
    if not isinstance(llm_config, RetryMixin):
        return client
    return patch_with_retry(
        client,
        retries=llm_config.num_retries,
        retry_codes=llm_config.retry_on_status_codes,
        retry_on_messages=llm_config.retry_on_errors,
    )
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _create_datarobot_chat_openai(config: dict[str, Any]) -> Any:
    """Build a ``ChatOpenAI`` subclass adjusted for DataRobot endpoints.

    Streaming calls need ``stream_options={"include_usage": True}``, but the
    endpoint returns a 400 when ``stream_options`` is present on a
    non-streaming call, so the subclass strips it from non-streaming payloads.
    """
    from langchain_openai import ChatOpenAI  # noqa: PLC0415

    class DataRobotChatOpenAI(ChatOpenAI):
        def _get_request_payload(  # type: ignore[override]
            self,
            *args: Any,
            **kwargs: Any,
        ) -> dict:
            # We need to default to include_usage=True for streaming but we get 400 response
            # if stream_options is present for a non-streaming call.
            payload = super()._get_request_payload(*args, **kwargs)
            if not payload.get("stream"):
                payload.pop("stream_options", None)
            return payload

    return DataRobotChatOpenAI(**config)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _create_datarobot_litellm(config: dict[str, Any]) -> Any:
    """Build a ``LiteLLM`` subclass whose metadata enables chat/function calling."""
    from llama_index.core.base.llms.types import LLMMetadata  # noqa: PLC0415
    from llama_index.llms.litellm import LiteLLM  # noqa: PLC0415

    class DataRobotLiteLLM(LiteLLM):  # type: ignore[misc]
        """DataRobotLiteLLM is a small LiteLLM wrapper class that makes all LiteLLM endpoints
        compatible with the LlamaIndex library.
        """

        @property
        def metadata(self) -> LLMMetadata:
            """Returns the metadata for the LLM.

            This is required to enable the is_chat_model and is_function_calling_model, which are
            mandatory for LlamaIndex agents. By default, LlamaIndex assumes these are false unless
            each individual model config in LiteLLM explicitly sets them to true. To use custom LLM
            endpoints with LlamaIndex agents, you must override this method to return the
            appropriate metadata.
            """
            # NOTE(review): 128000 is a hard-coded context-window assumption —
            # confirm it matches the deployed model.
            return LLMMetadata(
                context_window=128000,
                num_output=self.max_tokens or -1,
                is_chat_model=True,
                is_function_calling_model=True,
                model_name=self.model,
            )

    return DataRobotLiteLLM(**config)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@register_llm_client(
    config_type=DataRobotLLMGatewayModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN
)
async def datarobot_llm_gateway_langchain(
    llm_config: DataRobotLLMGatewayModelConfig, builder: Builder
) -> AsyncGenerator[ChatOpenAI]:
    """Yield a LangChain ``ChatOpenAI`` client for the DataRobot LLM gateway."""
    from nat.plugins.langchain.llm import (  # noqa: PLC0415
        _patch_llm_based_on_config as langchain_patch_llm_based_on_config,
    )

    config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
    config["base_url"] = config["base_url"] + "/genai/llmgw"
    config["stream_options"] = {"include_usage": True}
    config["model"] = config["model"].removeprefix("datarobot/")
    client = _create_datarobot_chat_openai(config)
    # Pass the pydantic config (not the dumped dict): retry patching checks
    # ``isinstance(llm_config, RetryMixin)``, which is never true for a dict,
    # so passing the dict silently disabled retries.
    yield langchain_patch_llm_based_on_config(client, llm_config)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@register_llm_client(
    config_type=DataRobotLLMGatewayModelConfig, wrapper_type=LLMFrameworkEnum.CREWAI
)
async def datarobot_llm_gateway_crewai(
    llm_config: DataRobotLLMGatewayModelConfig, builder: Builder
) -> AsyncGenerator[LLM]:
    """Yield a CrewAI ``LLM`` client for the DataRobot LLM gateway."""
    from crewai import LLM  # noqa: PLC0415

    config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
    if not config["model"].startswith("datarobot/"):
        config["model"] = "datarobot/" + config["model"]
    config["base_url"] = config["base_url"].removesuffix("/api/v2")
    client = LLM(**config)
    # Pass the pydantic config (not the dumped dict): retry patching checks
    # ``isinstance(llm_config, RetryMixin)``, which is never true for a dict.
    yield _patch_llm_based_on_config(client, llm_config)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@register_llm_client(
    config_type=DataRobotLLMGatewayModelConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX
)
async def datarobot_llm_gateway_llamaindex(
    llm_config: DataRobotLLMGatewayModelConfig, builder: Builder
) -> AsyncGenerator[LiteLLM]:
    """Yield a LlamaIndex ``LiteLLM`` client for the DataRobot LLM gateway."""
    config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
    if not config["model"].startswith("datarobot/"):
        config["model"] = "datarobot/" + config["model"]
    config["api_base"] = config.pop("base_url").removesuffix("/api/v2")
    client = _create_datarobot_litellm(config)
    # Pass the pydantic config (not the dumped dict): retry patching checks
    # ``isinstance(llm_config, RetryMixin)``, which is never true for a dict.
    yield _patch_llm_based_on_config(client, llm_config)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@register_llm_client(
    config_type=DataRobotLLMDeploymentModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN
)
async def datarobot_llm_deployment_langchain(
    llm_config: DataRobotLLMDeploymentModelConfig, builder: Builder
) -> AsyncGenerator[ChatOpenAI]:
    """Yield a LangChain ``ChatOpenAI`` client for a DataRobot LLM deployment."""
    from nat.plugins.langchain.llm import (  # noqa: PLC0415
        _patch_llm_based_on_config as langchain_patch_llm_based_on_config,
    )

    config = llm_config.model_dump(
        exclude={"type", "thinking"},
        by_alias=True,
        exclude_none=True,
    )
    config["stream_options"] = {"include_usage": True}
    config["model"] = config["model"].removeprefix("datarobot/")
    client = _create_datarobot_chat_openai(config)
    # Pass the pydantic config (not the dumped dict) so configured retry
    # settings actually apply (the patcher checks ``isinstance(..., RetryMixin)``).
    yield langchain_patch_llm_based_on_config(client, llm_config)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
@register_llm_client(
    config_type=DataRobotLLMDeploymentModelConfig, wrapper_type=LLMFrameworkEnum.CREWAI
)
async def datarobot_llm_deployment_crewai(
    llm_config: DataRobotLLMDeploymentModelConfig, builder: Builder
) -> AsyncGenerator[LLM]:
    """Yield a CrewAI ``LLM`` client for a DataRobot LLM deployment."""
    from crewai import LLM  # noqa: PLC0415

    config = llm_config.model_dump(
        exclude={"type", "thinking"},
        by_alias=True,
        exclude_none=True,
    )
    if not config["model"].startswith("datarobot/"):
        config["model"] = "datarobot/" + config["model"]
    config["api_base"] = config.pop("base_url") + "/chat/completions"
    client = LLM(**config)
    # Pass the pydantic config (not the dumped dict): retry patching checks
    # ``isinstance(llm_config, RetryMixin)``, which is never true for a dict.
    yield _patch_llm_based_on_config(client, llm_config)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
@register_llm_client(
    config_type=DataRobotLLMDeploymentModelConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX
)
async def datarobot_llm_deployment_llamaindex(
    llm_config: DataRobotLLMDeploymentModelConfig, builder: Builder
) -> AsyncGenerator[LiteLLM]:
    """Yield a LlamaIndex ``LiteLLM`` client for a DataRobot LLM deployment."""
    config = llm_config.model_dump(
        exclude={"type", "thinking"},
        by_alias=True,
        exclude_none=True,
    )
    if not config["model"].startswith("datarobot/"):
        config["model"] = "datarobot/" + config["model"]
    config["api_base"] = config.pop("base_url") + "/chat/completions"
    client = _create_datarobot_litellm(config)
    # Pass the pydantic config (not the dumped dict): retry patching checks
    # ``isinstance(llm_config, RetryMixin)``, which is never true for a dict.
    yield _patch_llm_based_on_config(client, llm_config)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
@register_llm_client(config_type=DataRobotNIMModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
async def datarobot_nim_langchain(
    llm_config: DataRobotNIMModelConfig, builder: Builder
) -> AsyncGenerator[ChatOpenAI]:
    """Yield a LangChain ``ChatOpenAI`` client for a DataRobot NIM deployment."""
    from nat.plugins.langchain.llm import (  # noqa: PLC0415
        _patch_llm_based_on_config as langchain_patch_llm_based_on_config,
    )

    config = llm_config.model_dump(
        exclude={"type", "thinking"},
        by_alias=True,
        exclude_none=True,
    )
    config["stream_options"] = {"include_usage": True}
    config["model"] = config["model"].removeprefix("datarobot/")
    client = _create_datarobot_chat_openai(config)
    # Pass the pydantic config (not the dumped dict) so configured retry
    # settings actually apply (the patcher checks ``isinstance(..., RetryMixin)``).
    yield langchain_patch_llm_based_on_config(client, llm_config)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
@register_llm_client(config_type=DataRobotNIMModelConfig, wrapper_type=LLMFrameworkEnum.CREWAI)
async def datarobot_nim_crewai(
    llm_config: DataRobotNIMModelConfig, builder: Builder
) -> AsyncGenerator[LLM]:
    """Yield a CrewAI ``LLM`` client for a DataRobot NIM deployment."""
    from crewai import LLM  # noqa: PLC0415

    # ``max_retries`` is excluded because CrewAI's LLM does not accept it.
    config = llm_config.model_dump(
        exclude={"type", "thinking", "max_retries"},
        by_alias=True,
        exclude_none=True,
    )
    if not config["model"].startswith("datarobot/"):
        config["model"] = "datarobot/" + config["model"]
    config["api_base"] = config.pop("base_url") + "/chat/completions"
    client = LLM(**config)
    # Pass the pydantic config (not the dumped dict): retry patching checks
    # ``isinstance(llm_config, RetryMixin)``, which is never true for a dict.
    yield _patch_llm_based_on_config(client, llm_config)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
@register_llm_client(config_type=DataRobotNIMModelConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX)
async def datarobot_nim_llamaindex(
    llm_config: DataRobotNIMModelConfig, builder: Builder
) -> AsyncGenerator[LiteLLM]:
    """Yield a LlamaIndex ``LiteLLM`` client for a DataRobot NIM deployment."""
    config = llm_config.model_dump(
        exclude={"type", "thinking"},
        by_alias=True,
        exclude_none=True,
    )
    if not config["model"].startswith("datarobot/"):
        config["model"] = "datarobot/" + config["model"]
    config["api_base"] = config.pop("base_url") + "/chat/completions"
    client = _create_datarobot_litellm(config)
    # Pass the pydantic config (not the dumped dict): retry patching checks
    # ``isinstance(llm_config, RetryMixin)``, which is never true for a dict.
    yield _patch_llm_based_on_config(client, llm_config)
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
@register_llm_client(
    config_type=DataRobotLLMComponentModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN
)
async def datarobot_llm_component_langchain(
    llm_config: DataRobotLLMComponentModelConfig, builder: Builder
) -> AsyncGenerator[ChatOpenAI]:
    """Yield a LangChain ``ChatOpenAI`` client for the DataRobot LLM component."""
    from nat.plugins.langchain.llm import (  # noqa: PLC0415
        _patch_llm_based_on_config as langchain_patch_llm_based_on_config,
    )

    config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
    # Gateway mode targets the LLM-gateway path; deployment mode keeps the
    # deployment-specific base_url as-is.
    if config["use_datarobot_llm_gateway"]:
        config["base_url"] = config["base_url"] + "/genai/llmgw"
    config["stream_options"] = {"include_usage": True}
    config["model"] = config["model"].removeprefix("datarobot/")
    config.pop("use_datarobot_llm_gateway")
    client = _create_datarobot_chat_openai(config)
    # Pass the pydantic config (not the dumped dict) so configured retry
    # settings actually apply (the patcher checks ``isinstance(..., RetryMixin)``).
    yield langchain_patch_llm_based_on_config(client, llm_config)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
@register_llm_client(
    config_type=DataRobotLLMComponentModelConfig, wrapper_type=LLMFrameworkEnum.CREWAI
)
async def datarobot_llm_component_crewai(
    llm_config: DataRobotLLMComponentModelConfig, builder: Builder
) -> AsyncGenerator[LLM]:
    """Yield a CrewAI ``LLM`` client for the DataRobot LLM component."""
    from crewai import LLM  # noqa: PLC0415

    config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
    if not config["model"].startswith("datarobot/"):
        config["model"] = "datarobot/" + config["model"]
    # Gateway mode: strip the API suffix and let LiteLLM route the model;
    # deployment mode: hit the deployment's chat-completions endpoint.
    if config["use_datarobot_llm_gateway"]:
        config["base_url"] = config["base_url"].removesuffix("/api/v2")
    else:
        config["api_base"] = config.pop("base_url") + "/chat/completions"
    config.pop("use_datarobot_llm_gateway")
    client = LLM(**config)
    # Pass the pydantic config (not the dumped dict): retry patching checks
    # ``isinstance(llm_config, RetryMixin)``, which is never true for a dict.
    yield _patch_llm_based_on_config(client, llm_config)
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
@register_llm_client(
    config_type=DataRobotLLMComponentModelConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX
)
async def datarobot_llm_component_llamaindex(
    llm_config: DataRobotLLMComponentModelConfig, builder: Builder
) -> AsyncGenerator[LiteLLM]:
    """Yield a LlamaIndex ``LiteLLM`` client for the DataRobot LLM component."""
    config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
    if not config["model"].startswith("datarobot/"):
        config["model"] = "datarobot/" + config["model"]
    # Gateway mode: strip the API suffix; deployment mode: hit the
    # deployment's chat-completions endpoint.
    if config["use_datarobot_llm_gateway"]:
        config["api_base"] = config.pop("base_url").removesuffix("/api/v2")
    else:
        config["api_base"] = config.pop("base_url") + "/chat/completions"
    config.pop("use_datarobot_llm_gateway")
    client = _create_datarobot_litellm(config)
    # Pass the pydantic config (not the dumped dict): retry patching checks
    # ``isinstance(llm_config, RetryMixin)``, which is never true for a dict.
    yield _patch_llm_based_on_config(client, llm_config)
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc. and its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from datarobot.core.config import DataRobotAppFrameworkBaseSettings
|
|
16
|
+
from nat.builder.builder import Builder
|
|
17
|
+
from nat.builder.llm import LLMProviderInfo
|
|
18
|
+
from nat.cli.register_workflow import register_llm_provider
|
|
19
|
+
from nat.llm.openai_llm import OpenAIModelConfig
|
|
20
|
+
from pydantic import AliasChoices
|
|
21
|
+
from pydantic import Field
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Config(DataRobotAppFrameworkBaseSettings):
    """Settings for the DataRobot LLM providers in this module.

    Finds variables in the priority order of: env
    variables (including Runtime Parameters), .env, file_secrets, then
    Pulumi output variables.
    """

    # Base DataRobot API endpoint (may carry a trailing slash from the env).
    datarobot_endpoint: str = "https://app.datarobot.com/api/v2"
    # API token used as the default ``api_key`` for every provider config.
    datarobot_api_token: str | None = None
    # Deployment id of a DataRobot-hosted LLM deployment (None if unused).
    llm_deployment_id: str | None = None
    # Deployment id of a DataRobot-hosted NIM deployment (None if unused).
    nim_deployment_id: str | None = None
    # When True, the LLM-component provider targets the LLM gateway
    # instead of a specific deployment.
    use_datarobot_llm_gateway: bool = False
    # Optional default model-name override for the LLM-component provider.
    llm_default_model: str | None = None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Module-level settings instance, resolved once at import time and used as
# the source of Field defaults below.
config = Config()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class DataRobotLLMComponentModelConfig(OpenAIModelConfig, name="datarobot-llm-component"):  # type: ignore[call-arg]
    """A DataRobot LLM provider to be used with an LLM client.

    Targets either the DataRobot LLM gateway or a specific LLM deployment,
    depending on ``use_datarobot_llm_gateway``.
    """

    api_key: str | None = Field(
        default=config.datarobot_api_token, description="DataRobot API key."
    )
    # Normalize the endpoint in BOTH branches: previously only the gateway
    # branch stripped the trailing slash, so an endpoint ending in "/" built
    # ".../api/v2//deployments/<id>" in deployment mode.
    base_url: str | None = Field(
        default=config.datarobot_endpoint.rstrip("/")
        if config.use_datarobot_llm_gateway
        else config.datarobot_endpoint.rstrip("/") + f"/deployments/{config.llm_deployment_id}",
        description="DataRobot LLM URL.",
    )
    model_name: str = Field(
        validation_alias=AliasChoices("model_name", "model"),
        serialization_alias="model",
        description="The model name.",
        default=config.llm_default_model or "datarobot-deployed-llm",
    )
    # Carried on the config so LLM clients can choose the right URL scheme.
    use_datarobot_llm_gateway: bool = config.use_datarobot_llm_gateway
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@register_llm_provider(config_type=DataRobotLLMComponentModelConfig)
async def datarobot_llm_component(
    config: DataRobotLLMComponentModelConfig, _builder: Builder
) -> LLMProviderInfo:
    """Register the DataRobot LLM-component provider."""
    info = LLMProviderInfo(
        config=config, description="DataRobot LLM Component for use with an LLM client."
    )
    yield info
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class DataRobotLLMGatewayModelConfig(OpenAIModelConfig, name="datarobot-llm-gateway"):  # type: ignore[call-arg]
    """A DataRobot LLM provider to be used with an LLM client.

    Points at the DataRobot API root; clients append the gateway path.
    """

    api_key: str | None = Field(
        default=config.datarobot_api_token, description="DataRobot API key."
    )
    # rstrip keeps a trailing slash on the endpoint from producing "//" later.
    base_url: str | None = Field(
        default=config.datarobot_endpoint.rstrip("/"), description="DataRobot LLM gateway URL."
    )
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@register_llm_provider(config_type=DataRobotLLMGatewayModelConfig)
async def datarobot_llm_gateway(
    config: DataRobotLLMGatewayModelConfig, _builder: Builder
) -> LLMProviderInfo:
    """Register the DataRobot LLM-gateway provider."""
    info = LLMProviderInfo(
        config=config, description="DataRobot LLM Gateway for use with an LLM client."
    )
    yield info
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class DataRobotLLMDeploymentModelConfig(OpenAIModelConfig, name="datarobot-llm-deployment"):  # type: ignore[call-arg]
    """A DataRobot LLM provider to be used with an LLM client."""

    api_key: str | None = Field(
        default=config.datarobot_api_token, description="DataRobot API key."
    )
    # rstrip the endpoint so a trailing slash cannot produce
    # ".../api/v2//deployments/<id>", matching the gateway config's
    # normalization of the same setting.
    base_url: str | None = Field(
        default=config.datarobot_endpoint.rstrip("/") + f"/deployments/{config.llm_deployment_id}"
    )
    model_name: str = Field(
        validation_alias=AliasChoices("model_name", "model"),
        serialization_alias="model",
        description="The model name to pass through to the deployment.",
        default="datarobot-deployed-llm",
    )
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@register_llm_provider(config_type=DataRobotLLMDeploymentModelConfig)
async def datarobot_llm_deployment(
    config: DataRobotLLMDeploymentModelConfig, _builder: Builder
) -> LLMProviderInfo:
    """Register the DataRobot LLM-deployment provider."""
    info = LLMProviderInfo(
        config=config, description="DataRobot LLM deployment for use with an LLM client."
    )
    yield info
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class DataRobotNIMModelConfig(DataRobotLLMDeploymentModelConfig, name="datarobot-nim"):  # type: ignore[call-arg]
    """A DataRobot NIM LLM provider to be used with an LLM client."""

    # rstrip the endpoint so a trailing slash cannot produce
    # ".../api/v2//deployments/<id>" (same normalization as the gateway config).
    base_url: str | None = Field(
        default=config.datarobot_endpoint.rstrip("/") + f"/deployments/{config.nim_deployment_id}"
    )
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@register_llm_provider(config_type=DataRobotNIMModelConfig)
async def datarobot_nim(config: DataRobotNIMModelConfig, _builder: Builder) -> LLMProviderInfo:
    """Register the DataRobot NIM provider."""
    info = LLMProviderInfo(
        config=config, description="DataRobot NIM deployment for use with an LLM client."
    )
    yield info
|