datarobot-genai 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. datarobot_genai/__init__.py +19 -0
  2. datarobot_genai/core/__init__.py +0 -0
  3. datarobot_genai/core/agents/__init__.py +43 -0
  4. datarobot_genai/core/agents/base.py +195 -0
  5. datarobot_genai/core/chat/__init__.py +19 -0
  6. datarobot_genai/core/chat/auth.py +146 -0
  7. datarobot_genai/core/chat/client.py +178 -0
  8. datarobot_genai/core/chat/responses.py +297 -0
  9. datarobot_genai/core/cli/__init__.py +18 -0
  10. datarobot_genai/core/cli/agent_environment.py +47 -0
  11. datarobot_genai/core/cli/agent_kernel.py +211 -0
  12. datarobot_genai/core/custom_model.py +141 -0
  13. datarobot_genai/core/mcp/__init__.py +0 -0
  14. datarobot_genai/core/mcp/common.py +218 -0
  15. datarobot_genai/core/telemetry_agent.py +126 -0
  16. datarobot_genai/core/utils/__init__.py +3 -0
  17. datarobot_genai/core/utils/auth.py +234 -0
  18. datarobot_genai/core/utils/urls.py +64 -0
  19. datarobot_genai/crewai/__init__.py +24 -0
  20. datarobot_genai/crewai/agent.py +42 -0
  21. datarobot_genai/crewai/base.py +159 -0
  22. datarobot_genai/crewai/events.py +117 -0
  23. datarobot_genai/crewai/mcp.py +59 -0
  24. datarobot_genai/drmcp/__init__.py +78 -0
  25. datarobot_genai/drmcp/core/__init__.py +13 -0
  26. datarobot_genai/drmcp/core/auth.py +165 -0
  27. datarobot_genai/drmcp/core/clients.py +180 -0
  28. datarobot_genai/drmcp/core/config.py +250 -0
  29. datarobot_genai/drmcp/core/config_utils.py +174 -0
  30. datarobot_genai/drmcp/core/constants.py +18 -0
  31. datarobot_genai/drmcp/core/credentials.py +190 -0
  32. datarobot_genai/drmcp/core/dr_mcp_server.py +316 -0
  33. datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
  34. datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
  35. datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
  36. datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +128 -0
  37. datarobot_genai/drmcp/core/dynamic_prompts/register.py +206 -0
  38. datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
  39. datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
  40. datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
  41. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
  42. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
  43. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
  44. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
  45. datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
  46. datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
  47. datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
  48. datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
  49. datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
  50. datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
  51. datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
  52. datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
  53. datarobot_genai/drmcp/core/exceptions.py +25 -0
  54. datarobot_genai/drmcp/core/logging.py +98 -0
  55. datarobot_genai/drmcp/core/mcp_instance.py +542 -0
  56. datarobot_genai/drmcp/core/mcp_server_tools.py +129 -0
  57. datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
  58. datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
  59. datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
  60. datarobot_genai/drmcp/core/routes.py +436 -0
  61. datarobot_genai/drmcp/core/routes_utils.py +30 -0
  62. datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
  63. datarobot_genai/drmcp/core/telemetry.py +424 -0
  64. datarobot_genai/drmcp/core/tool_filter.py +108 -0
  65. datarobot_genai/drmcp/core/utils.py +131 -0
  66. datarobot_genai/drmcp/server.py +19 -0
  67. datarobot_genai/drmcp/test_utils/__init__.py +13 -0
  68. datarobot_genai/drmcp/test_utils/integration_mcp_server.py +102 -0
  69. datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +96 -0
  70. datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +94 -0
  71. datarobot_genai/drmcp/test_utils/openai_llm_mcp_client.py +234 -0
  72. datarobot_genai/drmcp/test_utils/tool_base_ete.py +151 -0
  73. datarobot_genai/drmcp/test_utils/utils.py +91 -0
  74. datarobot_genai/drmcp/tools/__init__.py +14 -0
  75. datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
  76. datarobot_genai/drmcp/tools/predictive/data.py +97 -0
  77. datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
  78. datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
  79. datarobot_genai/drmcp/tools/predictive/model.py +148 -0
  80. datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
  81. datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
  82. datarobot_genai/drmcp/tools/predictive/project.py +72 -0
  83. datarobot_genai/drmcp/tools/predictive/training.py +651 -0
  84. datarobot_genai/langgraph/__init__.py +0 -0
  85. datarobot_genai/langgraph/agent.py +341 -0
  86. datarobot_genai/langgraph/mcp.py +73 -0
  87. datarobot_genai/llama_index/__init__.py +16 -0
  88. datarobot_genai/llama_index/agent.py +50 -0
  89. datarobot_genai/llama_index/base.py +299 -0
  90. datarobot_genai/llama_index/mcp.py +79 -0
  91. datarobot_genai/nat/__init__.py +0 -0
  92. datarobot_genai/nat/agent.py +258 -0
  93. datarobot_genai/nat/datarobot_llm_clients.py +249 -0
  94. datarobot_genai/nat/datarobot_llm_providers.py +130 -0
  95. datarobot_genai/py.typed +0 -0
  96. datarobot_genai-0.2.0.dist-info/METADATA +139 -0
  97. datarobot_genai-0.2.0.dist-info/RECORD +101 -0
  98. datarobot_genai-0.2.0.dist-info/WHEEL +4 -0
  99. datarobot_genai-0.2.0.dist-info/entry_points.txt +3 -0
  100. datarobot_genai-0.2.0.dist-info/licenses/AUTHORS +2 -0
  101. datarobot_genai-0.2.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,19 @@
1
+ from importlib.metadata import PackageNotFoundError
2
+ from importlib.metadata import version
3
+
4
+ from datarobot_genai.core.utils.urls import get_api_base
5
+
6
+ """
7
+ Only add imports with the core dependencies here at the top level.
8
+ For the optional extras, these need to be imported from their respective sub packages.
9
+ """
10
+
11
# Names re-exported as the public API of the top-level package.
__all__ = [
    "get_api_base",
    "__version__",
]

# Look up the version of the installed "datarobot-genai" distribution.
# If no such distribution is installed, fall back to the "0.0.0" placeholder.
try:
    __version__ = version("datarobot-genai")
except PackageNotFoundError:  # pragma: no cover - during local dev without install
    __version__ = "0.0.0"
File without changes
@@ -0,0 +1,43 @@
1
+ # Copyright 2025 DataRobot, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Reusable agent utilities and base classes for end-user templates.
15
+
16
+ This package provides:
17
+ - BaseAgent: common initialization for agent env/config fields
18
+ - Common helpers: make_system_prompt, extract_user_prompt_content
19
+ - Framework utilities (optional extras):
20
+ - crewai: build_llm, create_pipeline_interactions_from_messages
21
+ - langgraph: create_pipeline_interactions_from_events
22
+ - llamaindex: DataRobotLiteLLM, create_pipeline_interactions_from_events
23
+ """
24
+
25
+ from ..mcp.common import MCPConfig
26
+ from .base import BaseAgent
27
+ from .base import InvokeReturn
28
+ from .base import UsageMetrics
29
+ from .base import default_usage_metrics
30
+ from .base import extract_user_prompt_content
31
+ from .base import is_streaming
32
+ from .base import make_system_prompt
33
+
34
+ __all__ = [
35
+ "BaseAgent",
36
+ "make_system_prompt",
37
+ "extract_user_prompt_content",
38
+ "default_usage_metrics",
39
+ "is_streaming",
40
+ "InvokeReturn",
41
+ "UsageMetrics",
42
+ "MCPConfig",
43
+ ]
@@ -0,0 +1,195 @@
1
+ # Copyright 2025 DataRobot, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import abc
16
+ import json
17
+ import os
18
+ from collections.abc import AsyncGenerator
19
+ from collections.abc import Mapping
20
+ from typing import Any
21
+ from typing import Generic
22
+ from typing import TypedDict
23
+ from typing import TypeVar
24
+ from typing import cast
25
+
26
+ from ag_ui.core import Event
27
+ from openai.types.chat import CompletionCreateParams
28
+ from ragas import MultiTurnSample
29
+
30
+ from datarobot_genai.core.utils.urls import get_api_base
31
+
32
+ TTool = TypeVar("TTool")
33
+
34
+
35
class BaseAgent(Generic[TTool], abc.ABC):
    """Abstract base that gathers the configuration shared by agent templates.

    Attributes populated by ``__init__``:
    - api_key: DataRobot API token (argument, else ``DATAROBOT_API_TOKEN``)
    - api_base: DataRobot endpoint (argument, else ``DATAROBOT_ENDPOINT``,
      else ``https://app.datarobot.com``)
    - model: Preferred model name
    - timeout: Request timeout; ``None`` is replaced by 90
    - verbose: Verbosity flag, normalized to ``bool``
    - authorization_context: Authorization context for downstream agents/tools
    - forwarded_headers: Headers to forward, exposed read-only
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        api_base: str | None = None,
        model: str | None = None,
        verbose: bool | str | None = True,
        timeout: int | None = 90,
        authorization_context: dict[str, Any] | None = None,
        forwarded_headers: dict[str, str] | None = None,
        **_: Any,
    ) -> None:
        # Explicit arguments win; environment variables are the fallback.
        env = os.environ
        self.api_key = api_key or env.get("DATAROBOT_API_TOKEN")
        self.api_base = api_base or env.get("DATAROBOT_ENDPOINT") or "https://app.datarobot.com"
        self.model = model
        self.timeout = 90 if timeout is None else timeout

        # ``verbose`` may arrive as a bool, a string such as "true"/"false",
        # or None (treated as enabled).
        if verbose is None:
            self.verbose = True
        elif isinstance(verbose, str):
            self.verbose = verbose.lower() == "true"
        else:
            self.verbose = bool(verbose)

        self._mcp_tools: list[TTool] = []
        self._authorization_context = authorization_context if authorization_context else {}
        self._forwarded_headers: dict[str, str] = forwarded_headers if forwarded_headers else {}

    def set_mcp_tools(self, tools: list[TTool]) -> None:
        """Replace the list of MCP tools held by this agent."""
        self._mcp_tools = tools

    @property
    def mcp_tools(self) -> list[TTool]:
        """Return the MCP tools currently attached to this agent.

        Subclasses can read this while constructing their workflow to wire
        the tools into the underlying framework.
        """
        return self._mcp_tools

    @property
    def authorization_context(self) -> dict[str, Any]:
        """Return the authorization context given at construction (or ``{}``)."""
        return self._authorization_context

    @property
    def forwarded_headers(self) -> dict[str, str]:
        """Return the forwarded headers given at construction (or ``{}``)."""
        return self._forwarded_headers

    def litellm_api_base(self, deployment_id: str | None) -> str:
        """Return ``get_api_base`` applied to this agent's ``api_base`` and the deployment id."""
        return get_api_base(self.api_base, deployment_id)

    @abc.abstractmethod
    async def invoke(self, completion_create_params: CompletionCreateParams) -> "InvokeReturn":
        """Run the agent for the given completion parameters; subclasses must implement."""
        raise NotImplementedError("Not implemented")

    @classmethod
    def create_pipeline_interactions_from_events(
        cls,
        events: list[Any] | None,
    ) -> MultiTurnSample | None:
        """Wrap generic events/messages in a ``MultiTurnSample``.

        Returns ``None`` when ``events`` is ``None`` or empty.
        """
        return MultiTurnSample(user_input=events) if events else None
113
+
114
+
115
def extract_user_prompt_content(
    completion_create_params: "CompletionCreateParams | Mapping[str, Any]",
) -> Any:
    """Extract the content of the last user message from OpenAI-style messages.

    Parameters
    ----------
    completion_create_params : CompletionCreateParams | Mapping[str, Any]
        Mapping that may contain a ``messages`` list of message mappings.

    Returns
    -------
    Any
        - ``{}`` when there is no message with ``role == "user"`` (including
          a missing or ``None`` ``messages`` entry) or when the last user
          message has no ``content`` key.
        - The decoded object when the content is a string holding a JSON
          object or array.
        - Otherwise the content unchanged.
    """
    params = cast(Mapping[str, Any], completion_create_params)
    # ``messages`` may be absent or explicitly None; both mean "no messages".
    messages = params.get("messages") or []
    user_messages = [msg for msg in messages if msg.get("role") == "user"]
    if not user_messages:
        return {}

    # The most recent user turn is the one the agent should answer.
    content = user_messages[-1].get("content", {})
    if not isinstance(content, str):
        return content

    try:
        decoded = json.loads(content)
    except json.JSONDecodeError:
        return content

    # Only structured payloads are converted. Plain-text prompts that happen
    # to be valid JSON scalars ("42", "null", "true") must stay text instead
    # of silently turning into int/None/bool.
    if isinstance(decoded, (dict, list)):
        return decoded
    return content
132
+
133
+
134
+ def make_system_prompt(suffix: str = "", *, prefix: str | None = None) -> str:
135
+ """Build a system prompt with optional prefix and suffix.
136
+
137
+ Parameters
138
+ ----------
139
+ suffix : str, default ""
140
+ Text appended after the prefix. If non-empty, it is placed on a new line.
141
+ prefix : str | None, keyword-only, default None
142
+ Custom prefix text. When ``None``, a default collaborative assistant
143
+ instruction is used.
144
+
145
+ Returns
146
+ -------
147
+ str
148
+ The composed system prompt string.
149
+ """
150
+ default_prefix = (
151
+ "You are a helpful AI assistant, collaborating with other assistants."
152
+ " Use the provided tools to progress towards answering the question."
153
+ " If you are unable to fully answer, that's OK, another assistant with different tools "
154
+ " will help where you left off. Execute what you can to make progress."
155
+ )
156
+ head = prefix if prefix is not None else default_prefix
157
+ if suffix:
158
+ return head + "\n" + suffix
159
+ return head
160
+
161
+
162
+ # Structured type for token usage metrics in responses
163
class UsageMetrics(TypedDict):
    """Token usage counters attached to agent responses."""

    # Tokens counted for the generated completion.
    completion_tokens: int
    # Tokens counted for the prompt.
    prompt_tokens: int
    # Overall token count. NOTE(review): presumably prompt + completion;
    # nothing in this module enforces that.
    total_tokens: int
167
+
168
+
169
+ # Canonical return type for DRUM-compatible invoke implementations
170
# ``BaseAgent.invoke`` is annotated to return this alias. It is a union of:
#   - an async generator yielding
#     ``(str | Event, MultiTurnSample | None, UsageMetrics)`` tuples, or
#   - a single ``(str, MultiTurnSample | None, UsageMetrics)`` tuple.
# NOTE(review): presumably the generator form is the streaming path and the
# plain tuple the non-streaming path - confirm against implementations.
InvokeReturn = (
    AsyncGenerator[tuple[str | Event, MultiTurnSample | None, UsageMetrics], None]
    | tuple[str, MultiTurnSample | None, UsageMetrics]
)
174
+
175
+
176
def default_usage_metrics() -> UsageMetrics:
    """Build a fresh usage-metrics dict with every counter set to zero.

    A new dict is created on each call, so callers may mutate the result.
    """
    return UsageMetrics(
        completion_tokens=0,
        prompt_tokens=0,
        total_tokens=0,
    )
183
+
184
+
185
def is_streaming(completion_create_params: CompletionCreateParams | Mapping[str, Any]) -> bool:
    """Tell whether the request asks for a streamed response.

    Reads the ``stream`` entry of the given mapping (missing means False).
    A string value counts as streaming only when it equals "true",
    case-insensitively; any other value is judged by its truthiness.
    """
    request = cast(Mapping[str, Any], completion_create_params)
    stream_flag = request.get("stream", False)
    if not isinstance(stream_flag, str):
        return bool(stream_flag)
    # String flags (e.g. "true") are compared textually, not by truthiness,
    # so "false" is correctly treated as not streaming.
    return stream_flag.lower() == "true"
@@ -0,0 +1,19 @@
1
+ """Chat helpers and client utilities."""
2
+
3
+ from .auth import initialize_authorization_context
4
+ from .auth import resolve_authorization_context
5
+ from .client import ToolClient
6
+ from .responses import CustomModelChatResponse
7
+ from .responses import CustomModelStreamingResponse
8
+ from .responses import to_custom_model_chat_response
9
+ from .responses import to_custom_model_streaming_response
10
+
11
+ __all__ = [
12
+ "CustomModelChatResponse",
13
+ "CustomModelStreamingResponse",
14
+ "to_custom_model_chat_response",
15
+ "to_custom_model_streaming_response",
16
+ "ToolClient",
17
+ "resolve_authorization_context",
18
+ "initialize_authorization_context",
19
+ ]
@@ -0,0 +1,146 @@
1
+ # Copyright 2025 DataRobot, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Authorization context helpers for chat flows."""
16
+
17
+ from typing import Any
18
+
19
+ from datarobot.models.genai.agent.auth import set_authorization_context
20
+ from openai.types import CompletionCreateParams
21
+ from openai.types.chat.completion_create_params import CompletionCreateParamsNonStreaming
22
+ from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming
23
+
24
+ from datarobot_genai.core.utils.auth import AuthContextHeaderHandler
25
+
26
+
27
def _get_authorization_context_from_headers(
    headers: dict[str, str],
    secret_key: str | None = None,
) -> dict[str, Any] | None:
    """Read the authorization context carried by HTTP headers.

    Parameters
    ----------
    headers : dict[str, str]
        HTTP headers to inspect.
    secret_key : str | None
        Secret key handed to ``AuthContextHeaderHandler``. The handler is
        documented upstream as falling back to an environment variable
        when this is None.

    Returns
    -------
    dict[str, Any] | None
        ``model_dump()`` of the context found by the handler, or None when
        the handler returns nothing (or a falsy context).
    """
    context = AuthContextHeaderHandler(secret_key=secret_key).get_context(headers)
    if not context:
        return None
    return context.model_dump()
49
+
50
+
51
def _get_authorization_context_from_params(
    completion_create_params: CompletionCreateParams
    | CompletionCreateParamsNonStreaming
    | CompletionCreateParamsStreaming,
) -> dict[str, Any] | None:
    """Read the authorization context embedded in the completion parameters.

    Parameters
    ----------
    completion_create_params : CompletionCreateParams
        The parameters used to create the completion.

    Returns
    -------
    dict[str, Any] | None
        The value stored under ``"authorization_context"``, or None when the
        key is absent.
    """
    embedded_context = completion_create_params.get("authorization_context")
    return embedded_context
69
+
70
+
71
def resolve_authorization_context(
    completion_create_params: CompletionCreateParams
    | CompletionCreateParamsNonStreaming
    | CompletionCreateParamsStreaming,
    **kwargs: Any,
) -> dict[str, Any]:
    """Work out the authorization context for the current request.

    Downstream agents and tools use the authorization context to obtain
    access tokens for external services. The context is taken from the
    incoming HTTP headers when available, and otherwise from the completion
    create parameters.

    Parameters
    ----------
    completion_create_params : CompletionCreateParams | CompletionCreateParamsNonStreaming |
        CompletionCreateParamsStreaming
        Parameters supplied to the completion API. May include a fallback
        ``authorization_context`` mapping under the same key.
    **kwargs : Any
        Additional keyword arguments. Expected to include a ``headers`` key
        containing incoming HTTP headers as ``dict[str, str]``.

    Returns
    -------
    dict[str, Any]
        The resolved authorization context; ``{}`` when neither source
        provides a non-empty one.
    """
    incoming_headers = kwargs.get("headers", {})

    # Headers are the recommended transport: the context travels with JWT
    # encoding/decoding for additional security.
    header_context = _get_authorization_context_from_headers(incoming_headers)
    if header_context:
        return header_context

    # The completion params are consulted only as a backward-compatibility
    # fallback and may be removed in the future.
    params_context = _get_authorization_context_from_params(completion_create_params)
    if params_context:
        return params_context

    return {}
112
+
113
+
114
def initialize_authorization_context(
    completion_create_params: CompletionCreateParams
    | CompletionCreateParamsNonStreaming
    | CompletionCreateParamsStreaming,
    **kwargs: Any,
) -> None:
    """Resolve the authorization context and install it for the agent.

    Downstream agents and tools use the authorization context to obtain
    access tokens for external services. Once set, it is propagated when
    using the MCP Server component or the ToolClient class.

    Parameters
    ----------
    completion_create_params : CompletionCreateParams | CompletionCreateParamsNonStreaming |
        CompletionCreateParamsStreaming
        Parameters supplied to the completion API. May include a fallback
        ``authorization_context`` mapping under the same key.
    **kwargs : Any
        Additional keyword arguments. Expected to include a ``headers`` key
        containing incoming HTTP headers as ``dict[str, str]``.

    """
    # Note: authorization context internally uses contextvars, which are
    # thread-safe and async-safe.
    set_authorization_context(
        resolve_authorization_context(completion_create_params, **kwargs)
    )
@@ -0,0 +1,178 @@
1
+ # Copyright 2025 DataRobot, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Client for interacting with Agent Tools deployments for chat and scoring."""
16
+
17
+ import json
18
+ import os
19
+ from collections.abc import Iterator
20
+ from typing import Any
21
+
22
+ import datarobot as dr
23
+ import openai
24
+ import pandas as pd
25
+ from datarobot.models.genai.agent.auth import get_authorization_context
26
+ from datarobot_predict.deployment import PredictionResult
27
+ from datarobot_predict.deployment import UnstructuredPredictionResult
28
+ from datarobot_predict.deployment import predict
29
+ from datarobot_predict.deployment import predict_unstructured
30
+ from openai.types import CompletionCreateParams
31
+ from openai.types.chat import ChatCompletion
32
+ from openai.types.chat import ChatCompletionChunk
33
+
34
+ from ..utils.urls import get_api_base
35
+
36
+
37
class ToolClient:
    """Client for interacting with Agent Tools Deployments.

    This class provides methods to call the custom model tool using various hooks:
    `score`, `score_unstructured`, and `chat`. When the `authorization_context` is set,
    the client automatically propagates it to the agent tool. The `authorization_context`
    is required for retrieving access tokens to connect to external services.
    """

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        authorization_context: dict[str, Any] | None = None,
    ):
        """Initialize the ToolClient.

        Args:
            api_key (str | None): API key for authentication. Defaults to the
                environment variable `DATAROBOT_API_TOKEN`.
            base_url (str | None): Base URL for the DataRobot API. Defaults to the
                environment variable `DATAROBOT_ENDPOINT` or 'https://app.datarobot.com'.
                The value is normalized through `get_api_base`.
            authorization_context (dict[str, Any] | None): Authorization context to use
                for tool calls. If None, will attempt to get from ContextVar (for backward
                compatibility).
        """
        self.api_key = api_key or os.getenv("DATAROBOT_API_TOKEN")
        base_url = base_url or os.getenv("DATAROBOT_ENDPOINT") or "https://app.datarobot.com"
        self.base_url = get_api_base(base_url, deployment_id=None)
        self._authorization_context = authorization_context

    @property
    def datarobot_api_endpoint(self) -> str:
        """Return the DataRobot REST endpoint derived from `base_url`."""
        # NOTE(review): plain concatenation assumes `get_api_base` returns a
        # URL ending in "/" - confirm against its implementation.
        return self.base_url + "api/v2"

    def _resolve_authorization_context(
        self, authorization_context: dict[str, Any] | None
    ) -> dict[str, Any]:
        """Pick the authorization context to send with a tool call.

        Precedence: a non-empty explicit argument, then the context given at
        construction, then the ContextVar-backed context; `{}` when the
        ContextVar lookup raises `LookupError`.
        """
        auth_ctx = authorization_context or self._authorization_context
        if auth_ctx is None:
            try:
                auth_ctx = get_authorization_context()
            except LookupError:
                auth_ctx = {}
        return auth_ctx

    def get_deployment(self, deployment_id: str) -> dr.Deployment:
        """Retrieve a deployment by its ID.

        Args:
            deployment_id (str): The ID of the deployment.

        Returns
        -------
        dr.Deployment: The deployment object.
        """
        # Instantiating dr.Client configures the client used by the
        # subsequent dr.Deployment.get call.
        dr.Client(self.api_key, self.datarobot_api_endpoint)
        return dr.Deployment.get(deployment_id=deployment_id)

    def call(
        self,
        deployment_id: str,
        payload: dict[str, Any],
        authorization_context: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> UnstructuredPredictionResult:
        """Run the custom model tool using score_unstructured hook.

        Args:
            deployment_id (str): The ID of the deployment.
            payload (dict[str, Any]): The input payload.
            authorization_context (dict[str, Any] | None): Authorization context to use.
                If None, uses the context from initialization or falls back to ContextVar.
            **kwargs: Additional keyword arguments forwarded to `predict_unstructured`.

        Returns
        -------
        UnstructuredPredictionResult: The response content and headers.
        """
        data = {
            "payload": payload,
            "authorization_context": self._resolve_authorization_context(authorization_context),
        }
        return predict_unstructured(
            deployment=self.get_deployment(deployment_id),
            data=json.dumps(data),
            content_type="application/json",
            **kwargs,
        )

    def score(
        self, deployment_id: str, data_frame: pd.DataFrame, **kwargs: Any
    ) -> PredictionResult:
        """Run the custom model tool using score hook.

        Args:
            deployment_id (str): The ID of the deployment.
            data_frame (pd.DataFrame): The input data frame.
            **kwargs: Additional keyword arguments forwarded to `predict`.

        Returns
        -------
        PredictionResult: The response content and headers.
        """
        return predict(
            deployment=self.get_deployment(deployment_id),
            data_frame=data_frame,
            **kwargs,
        )

    def chat(
        self,
        completion_create_params: CompletionCreateParams,
        model: str,
        authorization_context: dict[str, Any] | None = None,
    ) -> ChatCompletion | Iterator[ChatCompletionChunk]:
        """Run the custom model tool with the chat hook.

        Args:
            completion_create_params (CompletionCreateParams): Parameters for the chat completion.
            model (str): The model to use. Overrides any ``model`` entry already
                present in ``completion_create_params``.
            authorization_context (dict[str, Any] | None): Authorization context to use.
                If None, uses the context from initialization or falls back to ContextVar.

        Returns
        -------
        Union[ChatCompletion, Iterator[ChatCompletionChunk]]: The chat completion response.
        """
        extra_body = {
            "authorization_context": self._resolve_authorization_context(authorization_context),
        }
        # Merge instead of passing ``model=`` next to ``**completion_create_params``:
        # if the params already carry a "model" key, the call would fail with
        # "TypeError: got multiple values for keyword argument 'model'".
        request_params = {**completion_create_params, "model": model}
        # NOTE(review): this goes through the module-level openai client, so
        # `self.api_key` / `self.base_url` are not applied here - confirm intended.
        return openai.chat.completions.create(
            **request_params,
            extra_body=extra_body,
        )