datarobot-genai 0.2.0 (datarobot_genai-0.2.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_genai/__init__.py +19 -0
- datarobot_genai/core/__init__.py +0 -0
- datarobot_genai/core/agents/__init__.py +43 -0
- datarobot_genai/core/agents/base.py +195 -0
- datarobot_genai/core/chat/__init__.py +19 -0
- datarobot_genai/core/chat/auth.py +146 -0
- datarobot_genai/core/chat/client.py +178 -0
- datarobot_genai/core/chat/responses.py +297 -0
- datarobot_genai/core/cli/__init__.py +18 -0
- datarobot_genai/core/cli/agent_environment.py +47 -0
- datarobot_genai/core/cli/agent_kernel.py +211 -0
- datarobot_genai/core/custom_model.py +141 -0
- datarobot_genai/core/mcp/__init__.py +0 -0
- datarobot_genai/core/mcp/common.py +218 -0
- datarobot_genai/core/telemetry_agent.py +126 -0
- datarobot_genai/core/utils/__init__.py +3 -0
- datarobot_genai/core/utils/auth.py +234 -0
- datarobot_genai/core/utils/urls.py +64 -0
- datarobot_genai/crewai/__init__.py +24 -0
- datarobot_genai/crewai/agent.py +42 -0
- datarobot_genai/crewai/base.py +159 -0
- datarobot_genai/crewai/events.py +117 -0
- datarobot_genai/crewai/mcp.py +59 -0
- datarobot_genai/drmcp/__init__.py +78 -0
- datarobot_genai/drmcp/core/__init__.py +13 -0
- datarobot_genai/drmcp/core/auth.py +165 -0
- datarobot_genai/drmcp/core/clients.py +180 -0
- datarobot_genai/drmcp/core/config.py +250 -0
- datarobot_genai/drmcp/core/config_utils.py +174 -0
- datarobot_genai/drmcp/core/constants.py +18 -0
- datarobot_genai/drmcp/core/credentials.py +190 -0
- datarobot_genai/drmcp/core/dr_mcp_server.py +316 -0
- datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
- datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
- datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
- datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +128 -0
- datarobot_genai/drmcp/core/dynamic_prompts/register.py +206 -0
- datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
- datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
- datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
- datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
- datarobot_genai/drmcp/core/exceptions.py +25 -0
- datarobot_genai/drmcp/core/logging.py +98 -0
- datarobot_genai/drmcp/core/mcp_instance.py +542 -0
- datarobot_genai/drmcp/core/mcp_server_tools.py +129 -0
- datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
- datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
- datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
- datarobot_genai/drmcp/core/routes.py +436 -0
- datarobot_genai/drmcp/core/routes_utils.py +30 -0
- datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
- datarobot_genai/drmcp/core/telemetry.py +424 -0
- datarobot_genai/drmcp/core/tool_filter.py +108 -0
- datarobot_genai/drmcp/core/utils.py +131 -0
- datarobot_genai/drmcp/server.py +19 -0
- datarobot_genai/drmcp/test_utils/__init__.py +13 -0
- datarobot_genai/drmcp/test_utils/integration_mcp_server.py +102 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +96 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +94 -0
- datarobot_genai/drmcp/test_utils/openai_llm_mcp_client.py +234 -0
- datarobot_genai/drmcp/test_utils/tool_base_ete.py +151 -0
- datarobot_genai/drmcp/test_utils/utils.py +91 -0
- datarobot_genai/drmcp/tools/__init__.py +14 -0
- datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
- datarobot_genai/drmcp/tools/predictive/data.py +97 -0
- datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
- datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
- datarobot_genai/drmcp/tools/predictive/model.py +148 -0
- datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
- datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
- datarobot_genai/drmcp/tools/predictive/project.py +72 -0
- datarobot_genai/drmcp/tools/predictive/training.py +651 -0
- datarobot_genai/langgraph/__init__.py +0 -0
- datarobot_genai/langgraph/agent.py +341 -0
- datarobot_genai/langgraph/mcp.py +73 -0
- datarobot_genai/llama_index/__init__.py +16 -0
- datarobot_genai/llama_index/agent.py +50 -0
- datarobot_genai/llama_index/base.py +299 -0
- datarobot_genai/llama_index/mcp.py +79 -0
- datarobot_genai/nat/__init__.py +0 -0
- datarobot_genai/nat/agent.py +258 -0
- datarobot_genai/nat/datarobot_llm_clients.py +249 -0
- datarobot_genai/nat/datarobot_llm_providers.py +130 -0
- datarobot_genai/py.typed +0 -0
- datarobot_genai-0.2.0.dist-info/METADATA +139 -0
- datarobot_genai-0.2.0.dist-info/RECORD +101 -0
- datarobot_genai-0.2.0.dist-info/WHEEL +4 -0
- datarobot_genai-0.2.0.dist-info/entry_points.txt +3 -0
- datarobot_genai-0.2.0.dist-info/licenses/AUTHORS +2 -0
- datarobot_genai-0.2.0.dist-info/licenses/LICENSE +201 -0
datarobot_genai/core/chat/responses.py
@@ -0,0 +1,297 @@
+# Copyright 2025 DataRobot, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""OpenAI-compatible response helpers for chat interactions."""
+
+import asyncio
+import queue
+import time
+import traceback as tb
+import uuid
+from asyncio import AbstractEventLoop
+from collections.abc import AsyncGenerator
+from collections.abc import AsyncIterator
+from collections.abc import Iterator
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any
+from typing import TypeVar
+
+from ag_ui.core import BaseEvent
+from ag_ui.core import Event
+from ag_ui.core import TextMessageChunkEvent
+from ag_ui.core import TextMessageContentEvent
+from openai.types import CompletionUsage
+from openai.types.chat import ChatCompletion
+from openai.types.chat import ChatCompletionChunk
+from openai.types.chat import ChatCompletionMessage
+from openai.types.chat.chat_completion import Choice
+from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
+from openai.types.chat.chat_completion_chunk import ChoiceDelta
+from ragas import MultiTurnSample
+
+from datarobot_genai.core.agents import default_usage_metrics
+
+
+class CustomModelChatResponse(ChatCompletion):
+    pipeline_interactions: str | None = None
+
+
+class CustomModelStreamingResponse(ChatCompletionChunk):
+    pipeline_interactions: str | None = None
+    event: Event | None = None
+
+
+def to_custom_model_chat_response(
+    response_text: str,
+    pipeline_interactions: MultiTurnSample | None,
+    usage_metrics: dict[str, int],
+    model: str | object | None,
+) -> CustomModelChatResponse:
+    """Convert the OpenAI ChatCompletion response to CustomModelChatResponse."""
+    choice = Choice(
+        index=0,
+        message=ChatCompletionMessage(role="assistant", content=response_text),
+        finish_reason="stop",
+    )
+
+    if model is None:
+        model = "unspecified-model"
+    else:
+        model = str(model)
+
+    required_usage_metrics: dict[str, int] = {
+        "completion_tokens": 0,
+        "prompt_tokens": 0,
+        "total_tokens": 0,
+    }
+
+    return CustomModelChatResponse(
+        id=str(uuid.uuid4()),
+        object="chat.completion",
+        choices=[choice],
+        created=int(time.time()),
+        model=model,
+        usage=CompletionUsage.model_validate(required_usage_metrics | usage_metrics),
+        pipeline_interactions=pipeline_interactions.model_dump_json()
+        if pipeline_interactions
+        else None,
+    )
+
+
+def to_custom_model_streaming_response(
+    thread_pool_executor: ThreadPoolExecutor,
+    event_loop: AbstractEventLoop,
+    streaming_response_generator: AsyncGenerator[
+        tuple[str | Event, MultiTurnSample | None, dict[str, int]], None
+    ],
+    model: str | object | None,
+) -> Iterator[CustomModelStreamingResponse]:
+    """Convert the OpenAI ChatCompletionChunk response to CustomModelStreamingResponse."""
+    completion_id = str(uuid.uuid4())
+    created = int(time.time())
+
+    last_pipeline_interactions = None
+    last_usage_metrics = None
+
+    if model is None:
+        model = "unspecified-model"
+    else:
+        model = str(model)
+
+    required_usage_metrics = default_usage_metrics()
+    try:
+        agent_response = aiter(streaming_response_generator)
+        while True:
+            try:
+                (
+                    response_text_or_event,
+                    pipeline_interactions,
+                    usage_metrics,
+                ) = thread_pool_executor.submit(
+                    event_loop.run_until_complete, anext(agent_response)
+                ).result()
+                last_pipeline_interactions = pipeline_interactions
+                last_usage_metrics = usage_metrics
+
+                if isinstance(response_text_or_event, str) and response_text_or_event:
+                    choice = ChunkChoice(
+                        index=0,
+                        delta=ChoiceDelta(role="assistant", content=response_text_or_event),
+                        finish_reason=None,
+                    )
+                    yield CustomModelStreamingResponse(
+                        id=completion_id,
+                        object="chat.completion.chunk",
+                        created=created,
+                        model=model,
+                        choices=[choice],
+                        usage=CompletionUsage.model_validate(required_usage_metrics | usage_metrics)
+                        if usage_metrics
+                        else None,
+                    )
+                elif isinstance(response_text_or_event, BaseEvent):
+                    content = ""
+                    if isinstance(
+                        response_text_or_event, (TextMessageContentEvent, TextMessageChunkEvent)
+                    ):
+                        content = response_text_or_event.delta or content
+                    choice = ChunkChoice(
+                        index=0,
+                        delta=ChoiceDelta(role="assistant", content=content),
+                        finish_reason=None,
+                    )
+
+                    yield CustomModelStreamingResponse(
+                        id=completion_id,
+                        object="chat.completion.chunk",
+                        created=created,
+                        model=model,
+                        choices=[choice],
+                        usage=CompletionUsage.model_validate(required_usage_metrics | usage_metrics)
+                        if usage_metrics
+                        else None,
+                        event=response_text_or_event,
+                    )
+            except StopAsyncIteration:
+                break
+        event_loop.run_until_complete(streaming_response_generator.aclose())
+        # Yield final chunk indicating end of stream
+        choice = ChunkChoice(
+            index=0,
+            delta=ChoiceDelta(role="assistant"),
+            finish_reason="stop",
+        )
+        yield CustomModelStreamingResponse(
+            id=completion_id,
+            object="chat.completion.chunk",
+            created=created,
+            model=model,
+            choices=[choice],
+            usage=CompletionUsage.model_validate(required_usage_metrics | last_usage_metrics)
+            if last_usage_metrics
+            else None,
+            pipeline_interactions=last_pipeline_interactions.model_dump_json()
+            if last_pipeline_interactions
+            else None,
+        )
+    except Exception as e:
+        tb.print_exc()
+        created = int(time.time())
+        choice = ChunkChoice(
+            index=0,
+            delta=ChoiceDelta(role="assistant", content=str(e), refusal="error"),
+            finish_reason="stop",
+        )
+        yield CustomModelStreamingResponse(
+            id=completion_id,
+            object="chat.completion.chunk",
+            created=created,
+            model=model,
+            choices=[choice],
+            usage=None,
+        )
+
+
+def streaming_iterator_to_custom_model_streaming_response(
+    streaming_response_iterator: Iterator[tuple[str, MultiTurnSample | None, dict[str, int]]],
+    model: str | object | None,
+) -> Iterator[CustomModelStreamingResponse]:
+    """Convert the OpenAI ChatCompletionChunk response to CustomModelStreamingResponse."""
+    completion_id = str(uuid.uuid4())
+    created = int(time.time())
+
+    last_pipeline_interactions = None
+    last_usage_metrics = None
+
+    while True:
+        try:
+            (
+                response_text,
+                pipeline_interactions,
+                usage_metrics,
+            ) = next(streaming_response_iterator)
+            last_pipeline_interactions = pipeline_interactions
+            last_usage_metrics = usage_metrics
+
+            if response_text:
+                choice = ChunkChoice(
+                    index=0,
+                    delta=ChoiceDelta(role="assistant", content=response_text),
+                    finish_reason=None,
+                )
+                yield CustomModelStreamingResponse(
+                    id=completion_id,
+                    object="chat.completion.chunk",
+                    created=created,
+                    model=model,
+                    choices=[choice],
+                    usage=CompletionUsage(**usage_metrics) if usage_metrics else None,
+                )
+        except StopIteration:
+            break
+    # Yield final chunk indicating end of stream
+    choice = ChunkChoice(
+        index=0,
+        delta=ChoiceDelta(role="assistant"),
+        finish_reason="stop",
+    )
+    yield CustomModelStreamingResponse(
+        id=completion_id,
+        object="chat.completion.chunk",
+        created=created,
+        model=model,
+        choices=[choice],
+        usage=CompletionUsage(**last_usage_metrics) if last_usage_metrics else None,
+        pipeline_interactions=last_pipeline_interactions.model_dump_json()
+        if last_pipeline_interactions
+        else None,
+    )
+
+
+T = TypeVar("T")
+
+
+def async_gen_to_sync_thread(
+    async_iterator: AsyncIterator[T],
+    thread_pool_executor: ThreadPoolExecutor,
+    event_loop: asyncio.AbstractEventLoop,
+) -> Iterator[T]:
+    """Run an async iterator in a separate thread and provide a sync iterator."""
+    # A thread-safe queue for communication
+    sync_queue: queue.Queue[Any] = queue.Queue()
+    # A sentinel object to signal the end of the async generator
+    SENTINEL = object()  # noqa: N806
+
+    async def run_async_to_queue() -> None:
+        """Run in the separate thread's event loop."""
+        try:
+            async for item in async_iterator:
+                sync_queue.put(item)
+        except Exception as e:
+            # Put the exception on the queue to be re-raised in the main thread
+            sync_queue.put(e)
+        finally:
+            # Signal the end of iteration
+            sync_queue.put(SENTINEL)
+
+    thread_pool_executor.submit(event_loop.run_until_complete, run_async_to_queue()).result()
+
+    # The main thread consumes items synchronously
+    while True:
+        item = sync_queue.get()
+        if item is SENTINEL:
+            break
+        if isinstance(item, Exception):
+            raise item
+        yield item
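Taken together, these helpers bridge an async agent loop into the synchronous, OpenAI-shaped responses a DataRobot custom model returns. A minimal sketch of that wiring follows; the `fake_agent` generator and the token counts are invented for illustration, while `async_gen_to_sync_thread` and `to_custom_model_chat_response` are the module's own functions:

import asyncio
from collections.abc import AsyncGenerator
from concurrent.futures import ThreadPoolExecutor

from datarobot_genai.core.chat.responses import async_gen_to_sync_thread
from datarobot_genai.core.chat.responses import to_custom_model_chat_response


async def fake_agent() -> AsyncGenerator[str, None]:
    # Hypothetical agent output: text fragments produced asynchronously.
    for part in ["Hello", ", ", "world"]:
        yield part


def main() -> None:
    loop = asyncio.new_event_loop()
    with ThreadPoolExecutor(max_workers=1) as pool:
        # Drain the async generator on a pool thread's run_until_complete,
        # then consume the results synchronously here.
        parts = list(async_gen_to_sync_thread(fake_agent(), pool, loop))
    response = to_custom_model_chat_response(
        response_text="".join(parts),
        pipeline_interactions=None,  # no ragas MultiTurnSample in this sketch
        usage_metrics={"prompt_tokens": 3, "completion_tokens": 5, "total_tokens": 8},
        model="example-model",  # placeholder model name
    )
    print(response.choices[0].message.content)  # prints "Hello, world"


if __name__ == "__main__":
    main()

Note that `async_gen_to_sync_thread` blocks on `.result()` until the whole generator has been drained into the queue, so this pattern buffers the full output rather than interleaving production and consumption.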
datarobot_genai/core/cli/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2025 DataRobot, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .agent_environment import AgentEnvironment
+from .agent_kernel import AgentKernel
+
+__all__ = ["AgentEnvironment", "AgentKernel"]
datarobot_genai/core/cli/agent_environment.py
@@ -0,0 +1,47 @@
+# Copyright 2025 DataRobot, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from datarobot_genai.core.cli.agent_kernel import AgentKernel
+
+
+class AgentEnvironment:
+    def __init__(
+        self,
+        api_token: str | None = None,
+        base_url: str | None = None,
+    ):
+        self.api_token = os.environ.get("DATAROBOT_API_TOKEN") or api_token
+        if not self.api_token:
+            raise ValueError(
+                "Missing DataRobot API token. Set the DATAROBOT_API_TOKEN "
+                "environment variable or provide it explicitly."
+            )
+        self.base_url = (
+            os.environ.get("DATAROBOT_ENDPOINT") or base_url or "https://app.datarobot.com"
+        )
+        if not self.base_url:
+            raise ValueError(
+                "Missing DataRobot endpoint. Set the DATAROBOT_ENDPOINT environment "
+                "variable or provide it explicitly."
+            )
+        self.base_url = self.base_url.replace("/api/v2", "")
+
+    @property
+    def interface(self) -> AgentKernel:
+        return AgentKernel(
+            api_token=str(self.api_token),
+            base_url=str(self.base_url),
+        )
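A short usage sketch, assuming the usual environment-variable setup (the token value is a placeholder; `DATAROBOT_ENDPOINT` is assumed to carry the `/api/v2` suffix, which the constructor strips):

import os

from datarobot_genai.core.cli import AgentEnvironment

# Placeholder credentials for illustration only.
os.environ["DATAROBOT_API_TOKEN"] = "<your-api-token>"
os.environ["DATAROBOT_ENDPOINT"] = "https://app.datarobot.com/api/v2"

env = AgentEnvironment()
print(env.base_url)  # prints "https://app.datarobot.com" ("/api/v2" stripped)
kernel = env.interface  # an AgentKernel bound to the same token and base URL

Because the environment variables are checked first, they take precedence over any values passed to the constructor.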
datarobot_genai/core/cli/agent_kernel.py
@@ -0,0 +1,211 @@
+# Copyright 2025 DataRobot, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import os
+import time
+from typing import Any
+from typing import cast
+
+import requests
+from openai import OpenAI
+from openai import Stream
+from openai.types.chat import ChatCompletion
+from openai.types.chat import ChatCompletionChunk
+from openai.types.chat import ChatCompletionSystemMessageParam
+from openai.types.chat import ChatCompletionUserMessageParam
+from openai.types.chat.completion_create_params import CompletionCreateParamsNonStreaming
+from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming
+
+
+class AgentKernel:
+    def __init__(
+        self,
+        api_token: str,
+        base_url: str,
+    ):
+        self.base_url = base_url
+        self.api_token = api_token
+
+    @property
+    def headers(self) -> dict[str, str]:
+        return {
+            "Authorization": f"Token {self.api_token}",
+        }
+
+    def load_completion_json(self, completion_json: str) -> CompletionCreateParamsNonStreaming:
+        """Load the completion JSON from a file or return an empty prompt."""
+        if not os.path.exists(completion_json):
+            raise FileNotFoundError(f"Completion JSON file not found: {completion_json}")
+
+        with open(completion_json) as f:
+            completion_data = json.load(f)
+
+        completion_create_params = CompletionCreateParamsNonStreaming(
+            **completion_data,  # type: ignore[typeddict-item]
+        )
+        return cast(CompletionCreateParamsNonStreaming, completion_create_params)
+
+    def construct_prompt(
+        self, user_prompt: str, verbose: bool, stream: bool = False
+    ) -> CompletionCreateParamsNonStreaming | CompletionCreateParamsStreaming:
+        extra_body = {
+            "api_key": self.api_token,
+            "api_base": self.base_url,
+            "verbose": verbose,
+        }
+        if stream:
+            return CompletionCreateParamsStreaming(
+                model="datarobot-deployed-llm",
+                messages=[
+                    ChatCompletionSystemMessageParam(
+                        content="You are a helpful assistant",
+                        role="system",
+                    ),
+                    ChatCompletionUserMessageParam(
+                        content=user_prompt,
+                        role="user",
+                    ),
+                ],
+                n=1,
+                temperature=1,
+                stream=True,
+                extra_body=extra_body,  # type: ignore[typeddict-unknown-key]
+            )
+        else:
+            return CompletionCreateParamsNonStreaming(
+                model="datarobot-deployed-llm",
+                messages=[
+                    ChatCompletionSystemMessageParam(
+                        content="You are a helpful assistant",
+                        role="system",
+                    ),
+                    ChatCompletionUserMessageParam(
+                        content=user_prompt,
+                        role="user",
+                    ),
+                ],
+                n=1,
+                temperature=1,
+                extra_body=extra_body,  # type: ignore[typeddict-unknown-key]
+            )
+
+    def local(
+        self,
+        user_prompt: str,
+        completion_json: str = "",
+        stream: bool = False,
+        config: Any | None = None,
+    ) -> ChatCompletion | Stream[ChatCompletionChunk]:
+        chat_api_url = config.agent_endpoint if config else self.base_url
+        print(chat_api_url)
+
+        return self._do_chat_completion(chat_api_url, user_prompt, completion_json, stream=stream)
+
+    def custom_model(self, custom_model_id: str, user_prompt: str, timeout: float = 300) -> str:
+        chat_api_url = (
+            f"{self.base_url}/api/v2/genai/agents/fromCustomModel/{custom_model_id}/chat/"
+        )
+        print(chat_api_url)
+
+        headers = {
+            "Authorization": f"Bearer {os.environ['DATAROBOT_API_TOKEN']}",
+            "Content-Type": "application/json",
+        }
+        data = {"messages": [{"role": "user", "content": user_prompt}]}
+
+        print(f'Querying custom model with prompt: "{data}"')
+        print(
+            "Please wait... This may take 1-2 minutes the first time "
+            "you run this as a codespace is provisioned "
+            "for the custom model to execute."
+        )
+        response = requests.post(
+            chat_api_url,
+            headers=headers,
+            json=data,
+        )
+
+        if not response.ok or not response.headers.get("Location"):
+            raise Exception(response.text)
+        # Wait for the agent to complete
+        status_location = response.headers["Location"]
+        while response.ok:
+            time.sleep(1)
+            response = requests.get(
+                status_location, headers=headers, allow_redirects=False, timeout=timeout
+            )
+            if response.status_code == 303:
+                agent_response = requests.get(response.headers["Location"], headers=headers).json()
+                # Show the agent response
+                break
+            else:
+                status_response = response.json()
+                if status_response["status"] in ["ERROR", "ABORTED"]:
+                    raise Exception(status_response)
+        else:
+            raise Exception(response.content)
+
+        if "errorMessage" in agent_response and agent_response["errorMessage"]:
+            return (
+                f"Error: "
+                f"{agent_response.get('errorMessage', 'No error message available')}"
+                f"Error details:"
+                f"{agent_response.get('errorDetails', 'No details available')}"
+            )
+        elif "choices" in agent_response:
+            return str(agent_response["choices"][0]["message"]["content"])
+        else:
+            return str(agent_response)
+
+    def deployment(
+        self,
+        deployment_id: str,
+        user_prompt: str,
+        completion_json: str = "",
+        stream: bool = False,
+    ) -> ChatCompletion | Stream[ChatCompletionChunk]:
+        chat_api_url = f"{self.base_url}/api/v2/deployments/{deployment_id}/"
+        print(chat_api_url)
+
+        return self._do_chat_completion(chat_api_url, user_prompt, completion_json, stream=stream)
+
+    def _do_chat_completion(
+        self,
+        url: str,
+        user_prompt: str,
+        completion_json: str = "",
+        stream: bool = False,
+    ) -> ChatCompletion | Stream[ChatCompletionChunk]:
+        if len(user_prompt) > 0:
+            completion_create_params = self.construct_prompt(
+                user_prompt, stream=stream, verbose=True
+            )
+        else:
+            completion_create_params = self.load_completion_json(completion_json)
+
+        openai_client = OpenAI(
+            base_url=url,
+            api_key=self.api_token,
+            _strict_response_validation=False,
+        )
+
+        print(f'Querying deployment with prompt: "{completion_create_params}"')
+        print(
+            "Please wait for the agent to complete the response. "
+            "This may take a few seconds to minutes "
+            "depending on the complexity of the agent workflow."
+        )
+
+        completion = openai_client.chat.completions.create(**completion_create_params)
+        return completion
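A hedged end-to-end sketch of driving a deployed agent with this kernel (the deployment ID and prompt are placeholders; per `construct_prompt`, the `stream` flag decides whether an OpenAI `ChatCompletion` or a `Stream[ChatCompletionChunk]` comes back):

from openai import Stream
from openai.types.chat import ChatCompletion

from datarobot_genai.core.cli import AgentEnvironment

# Assumes DATAROBOT_API_TOKEN and DATAROBOT_ENDPOINT are set, as above.
kernel = AgentEnvironment().interface

# Non-streaming: one ChatCompletion carrying the full answer.
result = kernel.deployment("<deployment-id>", "Summarize our Q3 results")
if isinstance(result, ChatCompletion):
    print(result.choices[0].message.content)

# Streaming: iterate chunks as the agent produces them.
result = kernel.deployment("<deployment-id>", "Summarize our Q3 results", stream=True)
if isinstance(result, Stream):
    for chunk in result:
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end="", flush=True)

Passing an empty `user_prompt` together with `completion_json="params.json"` would instead load the full completion request from that file via `load_completion_json`.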