datarobot-genai 0.2.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_genai/__init__.py +19 -0
- datarobot_genai/core/__init__.py +0 -0
- datarobot_genai/core/agents/__init__.py +43 -0
- datarobot_genai/core/agents/base.py +195 -0
- datarobot_genai/core/chat/__init__.py +19 -0
- datarobot_genai/core/chat/auth.py +146 -0
- datarobot_genai/core/chat/client.py +178 -0
- datarobot_genai/core/chat/responses.py +297 -0
- datarobot_genai/core/cli/__init__.py +18 -0
- datarobot_genai/core/cli/agent_environment.py +47 -0
- datarobot_genai/core/cli/agent_kernel.py +211 -0
- datarobot_genai/core/custom_model.py +141 -0
- datarobot_genai/core/mcp/__init__.py +0 -0
- datarobot_genai/core/mcp/common.py +218 -0
- datarobot_genai/core/telemetry_agent.py +126 -0
- datarobot_genai/core/utils/__init__.py +3 -0
- datarobot_genai/core/utils/auth.py +234 -0
- datarobot_genai/core/utils/urls.py +64 -0
- datarobot_genai/crewai/__init__.py +24 -0
- datarobot_genai/crewai/agent.py +42 -0
- datarobot_genai/crewai/base.py +159 -0
- datarobot_genai/crewai/events.py +117 -0
- datarobot_genai/crewai/mcp.py +59 -0
- datarobot_genai/drmcp/__init__.py +78 -0
- datarobot_genai/drmcp/core/__init__.py +13 -0
- datarobot_genai/drmcp/core/auth.py +165 -0
- datarobot_genai/drmcp/core/clients.py +180 -0
- datarobot_genai/drmcp/core/config.py +364 -0
- datarobot_genai/drmcp/core/config_utils.py +174 -0
- datarobot_genai/drmcp/core/constants.py +18 -0
- datarobot_genai/drmcp/core/credentials.py +190 -0
- datarobot_genai/drmcp/core/dr_mcp_server.py +350 -0
- datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
- datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
- datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
- datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +70 -0
- datarobot_genai/drmcp/core/dynamic_prompts/register.py +205 -0
- datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
- datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
- datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
- datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
- datarobot_genai/drmcp/core/exceptions.py +25 -0
- datarobot_genai/drmcp/core/logging.py +98 -0
- datarobot_genai/drmcp/core/mcp_instance.py +515 -0
- datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
- datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
- datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
- datarobot_genai/drmcp/core/routes.py +439 -0
- datarobot_genai/drmcp/core/routes_utils.py +30 -0
- datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
- datarobot_genai/drmcp/core/telemetry.py +424 -0
- datarobot_genai/drmcp/core/tool_config.py +111 -0
- datarobot_genai/drmcp/core/tool_filter.py +117 -0
- datarobot_genai/drmcp/core/utils.py +138 -0
- datarobot_genai/drmcp/server.py +19 -0
- datarobot_genai/drmcp/test_utils/__init__.py +13 -0
- datarobot_genai/drmcp/test_utils/clients/__init__.py +0 -0
- datarobot_genai/drmcp/test_utils/clients/anthropic.py +68 -0
- datarobot_genai/drmcp/test_utils/clients/base.py +300 -0
- datarobot_genai/drmcp/test_utils/clients/dr_gateway.py +58 -0
- datarobot_genai/drmcp/test_utils/clients/openai.py +68 -0
- datarobot_genai/drmcp/test_utils/elicitation_test_tool.py +89 -0
- datarobot_genai/drmcp/test_utils/integration_mcp_server.py +109 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +133 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +107 -0
- datarobot_genai/drmcp/test_utils/test_interactive.py +205 -0
- datarobot_genai/drmcp/test_utils/tool_base_ete.py +220 -0
- datarobot_genai/drmcp/test_utils/utils.py +91 -0
- datarobot_genai/drmcp/tools/__init__.py +14 -0
- datarobot_genai/drmcp/tools/clients/__init__.py +14 -0
- datarobot_genai/drmcp/tools/clients/atlassian.py +188 -0
- datarobot_genai/drmcp/tools/clients/confluence.py +584 -0
- datarobot_genai/drmcp/tools/clients/gdrive.py +832 -0
- datarobot_genai/drmcp/tools/clients/jira.py +334 -0
- datarobot_genai/drmcp/tools/clients/microsoft_graph.py +479 -0
- datarobot_genai/drmcp/tools/clients/s3.py +28 -0
- datarobot_genai/drmcp/tools/confluence/__init__.py +14 -0
- datarobot_genai/drmcp/tools/confluence/tools.py +321 -0
- datarobot_genai/drmcp/tools/gdrive/__init__.py +0 -0
- datarobot_genai/drmcp/tools/gdrive/tools.py +347 -0
- datarobot_genai/drmcp/tools/jira/__init__.py +14 -0
- datarobot_genai/drmcp/tools/jira/tools.py +243 -0
- datarobot_genai/drmcp/tools/microsoft_graph/__init__.py +13 -0
- datarobot_genai/drmcp/tools/microsoft_graph/tools.py +198 -0
- datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
- datarobot_genai/drmcp/tools/predictive/data.py +133 -0
- datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
- datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
- datarobot_genai/drmcp/tools/predictive/model.py +148 -0
- datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
- datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
- datarobot_genai/drmcp/tools/predictive/project.py +90 -0
- datarobot_genai/drmcp/tools/predictive/training.py +661 -0
- datarobot_genai/langgraph/__init__.py +0 -0
- datarobot_genai/langgraph/agent.py +341 -0
- datarobot_genai/langgraph/mcp.py +73 -0
- datarobot_genai/llama_index/__init__.py +16 -0
- datarobot_genai/llama_index/agent.py +50 -0
- datarobot_genai/llama_index/base.py +299 -0
- datarobot_genai/llama_index/mcp.py +79 -0
- datarobot_genai/nat/__init__.py +0 -0
- datarobot_genai/nat/agent.py +275 -0
- datarobot_genai/nat/datarobot_auth_provider.py +110 -0
- datarobot_genai/nat/datarobot_llm_clients.py +318 -0
- datarobot_genai/nat/datarobot_llm_providers.py +130 -0
- datarobot_genai/nat/datarobot_mcp_client.py +266 -0
- datarobot_genai/nat/helpers.py +87 -0
- datarobot_genai/py.typed +0 -0
- datarobot_genai-0.2.31.dist-info/METADATA +145 -0
- datarobot_genai-0.2.31.dist-info/RECORD +125 -0
- datarobot_genai-0.2.31.dist-info/WHEEL +4 -0
- datarobot_genai-0.2.31.dist-info/entry_points.txt +5 -0
- datarobot_genai-0.2.31.dist-info/licenses/AUTHORS +2 -0
- datarobot_genai-0.2.31.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from importlib.metadata import PackageNotFoundError
|
|
2
|
+
from importlib.metadata import version
|
|
3
|
+
|
|
4
|
+
from datarobot_genai.core.utils.urls import get_api_base
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
Only add imports with the core dependencies here at the top level.
|
|
8
|
+
For the optional extras, these need to be imported from their respective sub packages.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"get_api_base",
|
|
13
|
+
"__version__",
|
|
14
|
+
]
|
|
15
|
+
|
|
16
|
+
# Expose the installed distribution's version as ``__version__``. When the
# distribution metadata cannot be found, fall back to the "0.0.0" placeholder
# instead of failing at import time.
try:
    __version__ = version("datarobot-genai")
except PackageNotFoundError:  # pragma: no cover - during local dev without install
    __version__ = "0.0.0"
|
|
File without changes
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc. and its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Reusable agent utilities and base classes for end-user templates.
|
|
15
|
+
|
|
16
|
+
This package provides:
|
|
17
|
+
- BaseAgent: common initialization for agent env/config fields
|
|
18
|
+
- Common helpers: make_system_prompt, extract_user_prompt_content
|
|
19
|
+
- Framework utilities (optional extras):
|
|
20
|
+
- crewai: build_llm, create_pipeline_interactions_from_messages
|
|
21
|
+
- langgraph: create_pipeline_interactions_from_events
|
|
22
|
+
- llamaindex: DataRobotLiteLLM, create_pipeline_interactions_from_events
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from ..mcp.common import MCPConfig
|
|
26
|
+
from .base import BaseAgent
|
|
27
|
+
from .base import InvokeReturn
|
|
28
|
+
from .base import UsageMetrics
|
|
29
|
+
from .base import default_usage_metrics
|
|
30
|
+
from .base import extract_user_prompt_content
|
|
31
|
+
from .base import is_streaming
|
|
32
|
+
from .base import make_system_prompt
|
|
33
|
+
|
|
34
|
+
__all__ = [
|
|
35
|
+
"BaseAgent",
|
|
36
|
+
"make_system_prompt",
|
|
37
|
+
"extract_user_prompt_content",
|
|
38
|
+
"default_usage_metrics",
|
|
39
|
+
"is_streaming",
|
|
40
|
+
"InvokeReturn",
|
|
41
|
+
"UsageMetrics",
|
|
42
|
+
"MCPConfig",
|
|
43
|
+
]
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc. and its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import abc
|
|
16
|
+
import json
|
|
17
|
+
import os
|
|
18
|
+
from collections.abc import AsyncGenerator
|
|
19
|
+
from collections.abc import Mapping
|
|
20
|
+
from typing import Any
|
|
21
|
+
from typing import Generic
|
|
22
|
+
from typing import TypedDict
|
|
23
|
+
from typing import TypeVar
|
|
24
|
+
from typing import cast
|
|
25
|
+
|
|
26
|
+
from ag_ui.core import Event
|
|
27
|
+
from openai.types.chat import CompletionCreateParams
|
|
28
|
+
from ragas import MultiTurnSample
|
|
29
|
+
|
|
30
|
+
from datarobot_genai.core.utils.urls import get_api_base
|
|
31
|
+
|
|
32
|
+
TTool = TypeVar("TTool")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class BaseAgent(Generic[TTool], abc.ABC):
    """Shared foundation for agent templates.

    The constructor resolves connection settings (falling back to environment
    variables) and stores per-request context, so concrete agents only need to
    implement ``invoke``.

    Attributes
    ----------
    api_key : str | None
        DataRobot API token; falls back to ``DATAROBOT_API_TOKEN``.
    api_base : str
        DataRobot endpoint; falls back to ``DATAROBOT_ENDPOINT`` and then to
        ``https://app.datarobot.com``.
    model : str | None
        Preferred model name.
    timeout : int
        Request timeout; ``None`` is replaced by 90.
    verbose : bool
        Verbosity flag.

    The authorization context, forwarded headers and MCP tools are exposed
    through read-only properties.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        api_base: str | None = None,
        model: str | None = None,
        verbose: bool | str | None = True,
        timeout: int | None = 90,
        authorization_context: dict[str, Any] | None = None,
        forwarded_headers: dict[str, str] | None = None,
        **_: Any,
    ) -> None:
        """Store agent settings; unknown keyword arguments are accepted and ignored."""
        # Explicit arguments win; environment variables are only a fallback.
        env = os.environ
        self.api_key = api_key or env.get("DATAROBOT_API_TOKEN")
        endpoint = api_base or env.get("DATAROBOT_ENDPOINT")
        self.api_base = endpoint if endpoint else "https://app.datarobot.com"
        self.model = model
        self.timeout = 90 if timeout is None else timeout

        # ``verbose`` may arrive as text: only the case-insensitive string
        # "true" enables it. ``None`` means "use the default", which is on.
        if verbose is None:
            self.verbose = True
        elif isinstance(verbose, str):
            self.verbose = verbose.lower() == "true"
        else:
            self.verbose = bool(verbose)

        self._mcp_tools: list[TTool] = []
        self._authorization_context = authorization_context if authorization_context else {}
        self._forwarded_headers: dict[str, str] = forwarded_headers if forwarded_headers else {}

    def set_mcp_tools(self, tools: list[TTool]) -> None:
        """Replace the MCP tools attached to this agent (the list is stored as-is)."""
        self._mcp_tools = tools

    @property
    def mcp_tools(self) -> list[TTool]:
        """Return the MCP tools currently attached to this agent.

        Subclasses can read this while building their workflow to wire the
        tools into the underlying framework.
        """
        return self._mcp_tools

    @property
    def authorization_context(self) -> dict[str, Any]:
        """Return the authorization context supplied at construction (``{}`` if none)."""
        return self._authorization_context

    @property
    def forwarded_headers(self) -> dict[str, str]:
        """Return the forwarded headers supplied at construction (``{}`` if none)."""
        return self._forwarded_headers

    def litellm_api_base(self, deployment_id: str | None) -> str:
        """Return ``get_api_base`` applied to this agent's ``api_base`` and ``deployment_id``."""
        return get_api_base(self.api_base, deployment_id)

    @abc.abstractmethod
    async def invoke(self, completion_create_params: CompletionCreateParams) -> "InvokeReturn":
        """Run the agent for one request; concrete agents must override this."""
        raise NotImplementedError("Not implemented")

    @classmethod
    def create_pipeline_interactions_from_events(
        cls,
        events: list[Any] | None,
    ) -> MultiTurnSample | None:
        """Wrap generic events/messages in a ``MultiTurnSample``.

        Returns ``None`` when ``events`` is ``None`` or empty.
        """
        return MultiTurnSample(user_input=events) if events else None
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def extract_user_prompt_content(
|
|
116
|
+
completion_create_params: CompletionCreateParams | Mapping[str, Any],
|
|
117
|
+
) -> Any:
|
|
118
|
+
"""Extract first user message content from OpenAI messages."""
|
|
119
|
+
params = cast(Mapping[str, Any], completion_create_params)
|
|
120
|
+
user_messages = [msg for msg in params.get("messages", []) if msg.get("role") == "user"]
|
|
121
|
+
# Get the last user message
|
|
122
|
+
user_prompt = user_messages[-1] if user_messages else {}
|
|
123
|
+
content = user_prompt.get("content", {})
|
|
124
|
+
# Try converting prompt from json to a dict
|
|
125
|
+
if isinstance(content, str):
|
|
126
|
+
try:
|
|
127
|
+
content = json.loads(content)
|
|
128
|
+
except json.JSONDecodeError:
|
|
129
|
+
pass
|
|
130
|
+
|
|
131
|
+
return content
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def make_system_prompt(suffix: str = "", *, prefix: str | None = None) -> str:
|
|
135
|
+
"""Build a system prompt with optional prefix and suffix.
|
|
136
|
+
|
|
137
|
+
Parameters
|
|
138
|
+
----------
|
|
139
|
+
suffix : str, default ""
|
|
140
|
+
Text appended after the prefix. If non-empty, it is placed on a new line.
|
|
141
|
+
prefix : str | None, keyword-only, default None
|
|
142
|
+
Custom prefix text. When ``None``, a default collaborative assistant
|
|
143
|
+
instruction is used.
|
|
144
|
+
|
|
145
|
+
Returns
|
|
146
|
+
-------
|
|
147
|
+
str
|
|
148
|
+
The composed system prompt string.
|
|
149
|
+
"""
|
|
150
|
+
default_prefix = (
|
|
151
|
+
"You are a helpful AI assistant, collaborating with other assistants."
|
|
152
|
+
" Use the provided tools to progress towards answering the question."
|
|
153
|
+
" If you are unable to fully answer, that's OK, another assistant with different tools "
|
|
154
|
+
" will help where you left off. Execute what you can to make progress."
|
|
155
|
+
)
|
|
156
|
+
head = prefix if prefix is not None else default_prefix
|
|
157
|
+
if suffix:
|
|
158
|
+
return head + "\n" + suffix
|
|
159
|
+
return head
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
# Structured type for token usage metrics in responses
class UsageMetrics(TypedDict):
    """Token usage counters returned alongside an agent response.

    All three keys are required; ``default_usage_metrics`` builds an instance
    with every counter set to zero.
    """

    # Tokens counted for the generated completion.
    completion_tokens: int
    # Tokens counted for the prompt.
    prompt_tokens: int
    # Overall token count (presumably prompt + completion; not enforced here).
    total_tokens: int
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
# Canonical return type for DRUM-compatible invoke implementations.
# Two shapes are allowed:
#   * an async generator yielding (text or ag_ui Event, optional MultiTurnSample,
#     usage metrics) tuples -- presumably the streaming case;
#   * a single (text, optional MultiTurnSample, usage metrics) tuple.
InvokeReturn = (
    AsyncGenerator[tuple[str | Event, MultiTurnSample | None, UsageMetrics], None]
    | tuple[str, MultiTurnSample | None, UsageMetrics]
)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def default_usage_metrics() -> UsageMetrics:
    """Build a fresh, zero-initialised usage record.

    Returns
    -------
    UsageMetrics
        A new dict on every call with ``completion_tokens``, ``prompt_tokens``
        and ``total_tokens`` all set to 0 -- the keys required for
        OpenAI-compatible responses.
    """
    return UsageMetrics(completion_tokens=0, prompt_tokens=0, total_tokens=0)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def is_streaming(completion_create_params: CompletionCreateParams | Mapping[str, Any]) -> bool:
|
|
186
|
+
"""Return True when the request asks for streaming, False otherwise.
|
|
187
|
+
|
|
188
|
+
Accepts both pydantic types and plain dictionaries.
|
|
189
|
+
"""
|
|
190
|
+
params = cast(Mapping[str, Any], completion_create_params)
|
|
191
|
+
value = params.get("stream", False)
|
|
192
|
+
# Handle non-bool truthy values defensively (e.g., "true")
|
|
193
|
+
if isinstance(value, str):
|
|
194
|
+
return value.lower() == "true"
|
|
195
|
+
return bool(value)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Chat helpers and client utilities."""
|
|
2
|
+
|
|
3
|
+
from .auth import initialize_authorization_context
|
|
4
|
+
from .auth import resolve_authorization_context
|
|
5
|
+
from .client import ToolClient
|
|
6
|
+
from .responses import CustomModelChatResponse
|
|
7
|
+
from .responses import CustomModelStreamingResponse
|
|
8
|
+
from .responses import to_custom_model_chat_response
|
|
9
|
+
from .responses import to_custom_model_streaming_response
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"CustomModelChatResponse",
|
|
13
|
+
"CustomModelStreamingResponse",
|
|
14
|
+
"to_custom_model_chat_response",
|
|
15
|
+
"to_custom_model_streaming_response",
|
|
16
|
+
"ToolClient",
|
|
17
|
+
"resolve_authorization_context",
|
|
18
|
+
"initialize_authorization_context",
|
|
19
|
+
]
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc. and its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Authorization context helpers for chat flows."""
|
|
16
|
+
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
from datarobot.models.genai.agent.auth import set_authorization_context
|
|
20
|
+
from openai.types import CompletionCreateParams
|
|
21
|
+
from openai.types.chat.completion_create_params import CompletionCreateParamsNonStreaming
|
|
22
|
+
from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming
|
|
23
|
+
|
|
24
|
+
from datarobot_genai.core.utils.auth import AuthContextHeaderHandler
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _get_authorization_context_from_headers(
    headers: dict[str, str],
    secret_key: str | None = None,
) -> dict[str, Any] | None:
    """Read the authorization context carried in HTTP headers.

    Parameters
    ----------
    headers : dict[str, str]
        HTTP headers to inspect.
    secret_key : str | None
        Secret key for JWT decoding, passed to ``AuthContextHeaderHandler``.
        If None, it is retrieved from an environment variable.

    Returns
    -------
    dict[str, Any] | None
        The context dumped to a plain dict, or None when the handler finds
        no context in the headers.
    """
    context = AuthContextHeaderHandler(secret_key=secret_key).get_context(headers)
    if not context:
        return None
    return context.model_dump()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _get_authorization_context_from_params(
|
|
52
|
+
completion_create_params: CompletionCreateParams
|
|
53
|
+
| CompletionCreateParamsNonStreaming
|
|
54
|
+
| CompletionCreateParamsStreaming,
|
|
55
|
+
) -> dict[str, Any] | None:
|
|
56
|
+
"""Extract authorization context from completion create parameters.
|
|
57
|
+
|
|
58
|
+
Parameters
|
|
59
|
+
----------
|
|
60
|
+
completion_create_params : CompletionCreateParams
|
|
61
|
+
The parameters used to create the completion.
|
|
62
|
+
|
|
63
|
+
Returns
|
|
64
|
+
-------
|
|
65
|
+
dict[str, Any] | None
|
|
66
|
+
The extracted authorization context, or None if not found.
|
|
67
|
+
"""
|
|
68
|
+
return completion_create_params.get("authorization_context", None)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def resolve_authorization_context(
    completion_create_params: CompletionCreateParams
    | CompletionCreateParamsNonStreaming
    | CompletionCreateParamsStreaming,
    **kwargs: Any,
) -> dict[str, Any]:
    """Resolve the authorization context for the agent.

    Authorization context is required for propagating information needed by downstream
    agents and tools to retrieve access tokens to connect to external services. This method
    extracts the authorization context from either the incoming HTTP headers or the completion
    create parameters.

    Parameters
    ----------
    completion_create_params : CompletionCreateParams | CompletionCreateParamsNonStreaming |
        CompletionCreateParamsStreaming
        Parameters supplied to the completion API. May include a fallback
        ``authorization_context`` mapping under the same key.
    **kwargs : Any
        Additional keyword arguments. Expected to include a ``headers`` key
        containing incoming HTTP headers as ``dict[str, str]``. A missing or
        ``None`` value is treated as no headers.

    Returns
    -------
    dict[str, Any]
        The resolved authorization context, or an empty dict when neither the
        headers nor the completion parameters provide one.
    """
    # ``headers`` may be absent or explicitly ``None``; normalise both to an
    # empty mapping so the header handler always receives a dict.
    incoming_headers = kwargs.get("headers") or {}

    # Recommended way of propagating authorization context is via headers
    # with JWT encoding/decoding for additional security. The completion params
    # is used as a fallback for backward compatibility only and may be removed in
    # the future.
    authorization_context: dict[str, Any] = (
        _get_authorization_context_from_headers(incoming_headers)
        or _get_authorization_context_from_params(completion_create_params)
        or {}
    )

    return authorization_context
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def initialize_authorization_context(
    completion_create_params: CompletionCreateParams
    | CompletionCreateParamsNonStreaming
    | CompletionCreateParamsStreaming,
    **kwargs: Any,
) -> None:
    """Resolve the authorization context for this request and make it current.

    Authorization context carries the information downstream agents and tools
    need to retrieve access tokens for external services. Once set, it is
    propagated when using the MCP Server component or the ToolClient class.

    Parameters
    ----------
    completion_create_params : CompletionCreateParams | CompletionCreateParamsNonStreaming |
        CompletionCreateParamsStreaming
        Parameters supplied to the completion API. May include a fallback
        ``authorization_context`` mapping under the same key.
    **kwargs : Any
        Additional keyword arguments. Expected to include a ``headers`` key
        containing incoming HTTP headers as ``dict[str, str]``.
    """
    resolved = resolve_authorization_context(completion_create_params, **kwargs)

    # Note: authorization context internally uses contextvars, which are
    # thread-safe and async-safe.
    set_authorization_context(resolved)
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc. and its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Client for interacting with Agent Tools deployments for chat and scoring."""
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
import os
|
|
19
|
+
from collections.abc import Iterator
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
import datarobot as dr
|
|
23
|
+
import openai
|
|
24
|
+
import pandas as pd
|
|
25
|
+
from datarobot.models.genai.agent.auth import get_authorization_context
|
|
26
|
+
from datarobot_predict.deployment import PredictionResult
|
|
27
|
+
from datarobot_predict.deployment import UnstructuredPredictionResult
|
|
28
|
+
from datarobot_predict.deployment import predict
|
|
29
|
+
from datarobot_predict.deployment import predict_unstructured
|
|
30
|
+
from openai.types import CompletionCreateParams
|
|
31
|
+
from openai.types.chat import ChatCompletion
|
|
32
|
+
from openai.types.chat import ChatCompletionChunk
|
|
33
|
+
|
|
34
|
+
from ..utils.urls import get_api_base
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ToolClient:
    """Client for interacting with Agent Tools Deployments.

    This class provides methods to call the custom model tool using various hooks:
    `score`, `score_unstructured`, and `chat`. When the `authorization_context` is set,
    the client automatically propagates it to the agent tool. The `authorization_context`
    is required for retrieving access tokens to connect to external services.
    """

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        authorization_context: dict[str, Any] | None = None,
    ) -> None:
        """Initialize the ToolClient.

        Args:
            api_key (str | None): API key for authentication. Defaults to the
                environment variable `DATAROBOT_API_TOKEN`.
            base_url (str | None): Base URL for the DataRobot API. Defaults to the
                environment variable `DATAROBOT_ENDPOINT` or 'https://app.datarobot.com'.
                The value is normalized through `get_api_base`.
            authorization_context (dict[str, Any] | None): Authorization context to use
                for tool calls. If None, will attempt to get from ContextVar (for backward
                compatibility).
        """
        self.api_key = api_key or os.getenv("DATAROBOT_API_TOKEN")
        base_url = base_url or os.getenv("DATAROBOT_ENDPOINT") or "https://app.datarobot.com"
        self.base_url = get_api_base(base_url, deployment_id=None)
        self._authorization_context = authorization_context

    @property
    def datarobot_api_endpoint(self) -> str:
        """Return the ``api/v2`` endpoint derived from ``base_url``."""
        # Join with exactly one slash whether or not ``base_url`` ends in one.
        return self.base_url.rstrip("/") + "/api/v2"

    def _resolve_authorization_context(
        self, authorization_context: dict[str, Any] | None
    ) -> dict[str, Any]:
        """Pick the authorization context to send with a tool call.

        Order of precedence: the explicit (truthy) argument, then the context
        given at construction, then the ambient ContextVar. An empty dict is
        used when the ContextVar lookup raises ``LookupError``.
        """
        auth_ctx = authorization_context or self._authorization_context
        if auth_ctx is None:
            try:
                auth_ctx = get_authorization_context()
            except LookupError:
                auth_ctx = {}
        return auth_ctx

    def get_deployment(self, deployment_id: str) -> dr.Deployment:
        """Retrieve a deployment by its ID.

        Args:
            deployment_id (str): The ID of the deployment.

        Returns:
            dr.Deployment: The deployment object.
        """
        # Instantiating the client configures the DataRobot SDK with this
        # client's credentials before the lookup.
        dr.Client(self.api_key, self.datarobot_api_endpoint)
        return dr.Deployment.get(deployment_id=deployment_id)

    def call(
        self,
        deployment_id: str,
        payload: dict[str, Any],
        authorization_context: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> UnstructuredPredictionResult:
        """Run the custom model tool using score_unstructured hook.

        Args:
            deployment_id (str): The ID of the deployment.
            payload (dict[str, Any]): The input payload; must be JSON-serializable.
            authorization_context (dict[str, Any] | None): Authorization context to use.
                If None, uses the context from initialization or falls back to ContextVar.
            **kwargs: Additional keyword arguments forwarded to `predict_unstructured`.

        Returns:
            UnstructuredPredictionResult: The response content and headers.
        """
        data = {
            "payload": payload,
            "authorization_context": self._resolve_authorization_context(authorization_context),
        }
        return predict_unstructured(
            deployment=self.get_deployment(deployment_id),
            data=json.dumps(data),
            content_type="application/json",
            **kwargs,
        )

    def score(
        self, deployment_id: str, data_frame: pd.DataFrame, **kwargs: Any
    ) -> PredictionResult:
        """Run the custom model tool using score hook.

        Args:
            deployment_id (str): The ID of the deployment.
            data_frame (pd.DataFrame): The input data frame.
            **kwargs: Additional keyword arguments forwarded to `predict`.

        Returns:
            PredictionResult: The response content and headers.
        """
        return predict(
            deployment=self.get_deployment(deployment_id),
            data_frame=data_frame,
            **kwargs,
        )

    def chat(
        self,
        completion_create_params: CompletionCreateParams,
        model: str,
        authorization_context: dict[str, Any] | None = None,
    ) -> ChatCompletion | Iterator[ChatCompletionChunk]:
        """Run the custom model tool with the chat hook.

        Args:
            completion_create_params (CompletionCreateParams): Parameters for the chat completion.
            model (str): The model to use. Overrides any ``model`` entry already
                present in ``completion_create_params``.
            authorization_context (dict[str, Any] | None): Authorization context to use.
                If None, uses the context from initialization or falls back to ContextVar.

        Returns:
            ChatCompletion | Iterator[ChatCompletionChunk]: The chat completion response.
        """
        extra_body = {
            "authorization_context": self._resolve_authorization_context(authorization_context),
        }
        # The params normally carry their own ``model`` entry. Passing it via
        # ``**`` together with ``model=`` raises "got multiple values for
        # keyword argument 'model'", so merge first and let the explicit
        # argument win.
        request_params: dict[str, Any] = {**completion_create_params, "model": model}
        # NOTE(review): this uses the module-level ``openai`` client, so the
        # api_key / base_url held by this ToolClient are not applied here --
        # confirm callers configure ``openai`` globally.
        return openai.chat.completions.create(
            **request_params,
            extra_body=extra_body,
        )
|