kagent-adk 0.6.8.tar.gz → 0.6.10.tar.gz
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of kagent-adk has been flagged as potentially problematic.
- {kagent_adk-0.6.8 → kagent_adk-0.6.10}/PKG-INFO +1 -1
- {kagent_adk-0.6.8 → kagent_adk-0.6.10}/pyproject.toml +6 -1
- {kagent_adk-0.6.8 → kagent_adk-0.6.10}/src/kagent/adk/__init__.py +1 -1
- kagent_adk-0.6.10/src/kagent/adk/models/__init__.py +3 -0
- kagent_adk-0.6.10/src/kagent/adk/models/_openai.py +410 -0
- kagent_adk-0.6.8/src/kagent/adk/models.py → kagent_adk-0.6.10/src/kagent/adk/types.py +5 -2
- kagent_adk-0.6.10/tests/__init__.py +0 -0
- kagent_adk-0.6.10/tests/unittests/__init__.py +0 -0
- kagent_adk-0.6.10/tests/unittests/models/__init__.py +0 -0
- kagent_adk-0.6.10/tests/unittests/models/test_openai.py +338 -0
- {kagent_adk-0.6.8 → kagent_adk-0.6.10}/.gitignore +0 -0
- {kagent_adk-0.6.8 → kagent_adk-0.6.10}/.python-version +0 -0
- {kagent_adk-0.6.8 → kagent_adk-0.6.10}/README.md +0 -0
- {kagent_adk-0.6.8 → kagent_adk-0.6.10}/src/kagent/adk/_a2a.py +0 -0
- {kagent_adk-0.6.8 → kagent_adk-0.6.10}/src/kagent/adk/_agent_executor.py +0 -0
- {kagent_adk-0.6.8 → kagent_adk-0.6.10}/src/kagent/adk/_session_service.py +0 -0
- {kagent_adk-0.6.8 → kagent_adk-0.6.10}/src/kagent/adk/_token.py +0 -0
- {kagent_adk-0.6.8 → kagent_adk-0.6.10}/src/kagent/adk/cli.py +0 -0
- {kagent_adk-0.6.8 → kagent_adk-0.6.10}/src/kagent/adk/converters/__init__.py +0 -0
- {kagent_adk-0.6.8 → kagent_adk-0.6.10}/src/kagent/adk/converters/event_converter.py +0 -0
- {kagent_adk-0.6.8 → kagent_adk-0.6.10}/src/kagent/adk/converters/part_converter.py +0 -0
- {kagent_adk-0.6.8 → kagent_adk-0.6.10}/src/kagent/adk/converters/request_converter.py +0 -0
{kagent_adk-0.6.8 → kagent_adk-0.6.10}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "kagent-adk"
-version = "0.6.8"
+version = "0.6.10"
 description = "kagent-adk is an sdk for integrating adk agents with kagent"
 readme = "README.md"
 requires-python = ">=3.12.11"
@@ -50,3 +50,8 @@ packages = ["src/kagent"]
 
 [tool.ruff]
 extend = "../../pyproject.toml"
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+asyncio_default_fixture_loop_scope = "function"
+asyncio_mode = "auto"
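The new [tool.pytest.ini_options] table wires pytest up for the test suite added below, and asyncio_mode = "auto" means pytest-asyncio collects plain async def tests without a per-test @pytest.mark.asyncio mark (some of the new tests still carry the mark, which is harmless in auto mode). A minimal sketch of what auto mode permits, with a hypothetical file name:

# tests/test_auto_mode.py -- hypothetical illustration; under asyncio_mode = "auto",
# pytest-asyncio runs this coroutine in a function-scoped event loop with no marker.
import asyncio


async def test_awaits_without_marker():
    await asyncio.sleep(0)
    assert True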
kagent_adk-0.6.10/src/kagent/adk/models/_openai.py (new file)

@@ -0,0 +1,410 @@
+from __future__ import annotations
+
+import base64
+import json
+import os
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, Literal, Optional
+
+from google.adk.models import BaseLlm
+from google.adk.models.llm_response import LlmResponse
+from google.genai import types
+from openai import AsyncAzureOpenAI, AsyncOpenAI
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionAssistantMessageParam,
+    ChatCompletionContentPartImageParam,
+    ChatCompletionContentPartTextParam,
+    ChatCompletionMessageParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionToolMessageParam,
+    ChatCompletionToolParam,
+    ChatCompletionUserMessageParam,
+)
+from openai.types.chat.chat_completion_message_tool_call_param import (
+    ChatCompletionMessageToolCallParam,
+)
+from openai.types.chat.chat_completion_message_tool_call_param import (
+    Function as ToolCallFunction,
+)
+from openai.types.shared_params import FunctionDefinition
+from pydantic import Field
+
+if TYPE_CHECKING:
+    from google.adk.models.llm_request import LlmRequest
+
+
+def _convert_role_to_openai(role: Optional[str]) -> str:
+    """Convert google.genai role to OpenAI role."""
+    if role in ["model", "assistant"]:
+        return "assistant"
+    elif role == "system":
+        return "system"
+    else:
+        return "user"
+
+
+def _convert_content_to_openai_messages(
+    contents: list[types.Content], system_instruction: Optional[str] = None
+) -> list[ChatCompletionMessageParam]:
+    """Convert google.genai Content list to OpenAI messages format."""
+    messages: list[ChatCompletionMessageParam] = []
+
+    # Add system message if provided
+    if system_instruction:
+        system_message: ChatCompletionSystemMessageParam = {"role": "system", "content": system_instruction}
+        messages.append(system_message)
+
+    # First pass: collect all function responses to match with tool calls
+    all_function_responses = {}
+    for content in contents:
+        for part in content.parts or []:
+            if part.function_response:
+                tool_call_id = part.function_response.id or "call_1"
+                all_function_responses[tool_call_id] = part.function_response
+
+    for content in contents:
+        role = _convert_role_to_openai(content.role)
+
+        # Separate different types of parts
+        text_parts = []
+        function_calls = []
+        function_responses = []
+        image_parts = []
+
+        for part in content.parts or []:
+            if part.text:
+                text_parts.append(part.text)
+            elif part.function_call:
+                function_calls.append(part)
+            elif part.function_response:
+                function_responses.append(part)
+            elif part.inline_data and part.inline_data.mime_type and part.inline_data.mime_type.startswith("image"):
+                if part.inline_data.data:
+                    image_data = base64.b64encode(part.inline_data.data).decode()
+                    image_part: ChatCompletionContentPartImageParam = {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:{part.inline_data.mime_type};base64,{image_data}"},
+                    }
+                    image_parts.append(image_part)
+
+        # Function responses are now handled together with function calls
+        # This ensures proper pairing and prevents orphaned tool messages
+
+        # Handle function calls (assistant messages with tool_calls)
+        if function_calls:
+            tool_calls = []
+            tool_response_messages = []
+
+            for func_call in function_calls:
+                tool_call_function: ToolCallFunction = {
+                    "name": func_call.function_call.name or "",
+                    "arguments": str(func_call.function_call.args) if func_call.function_call.args else "{}",
+                }
+                tool_call_id = func_call.function_call.id or "call_1"
+                tool_call: ChatCompletionMessageToolCallParam = {
+                    "id": tool_call_id,
+                    "type": "function",
+                    "function": tool_call_function,
+                }
+                tool_calls.append(tool_call)
+
+                # Check if we have a response for this tool call
+                if tool_call_id in all_function_responses:
+                    func_response = all_function_responses[tool_call_id]
+                    tool_message: ChatCompletionToolMessageParam = {
+                        "role": "tool",
+                        "tool_call_id": tool_call_id,
+                        "content": str(func_response.response.get("result", "")) if func_response.response else "",
+                    }
+                    tool_response_messages.append(tool_message)
+                else:
+                    # If no response is available, create a placeholder response
+                    # This prevents the OpenAI API error
+                    tool_message: ChatCompletionToolMessageParam = {
+                        "role": "tool",
+                        "tool_call_id": tool_call_id,
+                        "content": "No response available for this function call.",
+                    }
+                    tool_response_messages.append(tool_message)
+
+            # Create assistant message with tool calls
+            text_content = "\n".join(text_parts) if text_parts else None
+            assistant_message: ChatCompletionAssistantMessageParam = {
+                "role": "assistant",
+                "content": text_content,
+                "tool_calls": tool_calls,
+            }
+            messages.append(assistant_message)
+
+            # Add all tool response messages immediately after the assistant message
+            messages.extend(tool_response_messages)
+
+        # Handle regular text/image messages (only if no function calls)
+        elif text_parts or image_parts:
+            if role == "user":
+                if image_parts and text_parts:
+                    # Multi-modal content
+                    text_part: ChatCompletionContentPartTextParam = {"type": "text", "text": "\n".join(text_parts)}
+                    content_parts = [text_part] + image_parts
+                    user_message: ChatCompletionUserMessageParam = {"role": "user", "content": content_parts}
+                elif image_parts:
+                    # Image only
+                    user_message: ChatCompletionUserMessageParam = {"role": "user", "content": image_parts}
+                else:
+                    # Text only
+                    user_message: ChatCompletionUserMessageParam = {"role": "user", "content": "\n".join(text_parts)}
+                messages.append(user_message)
+            elif role == "assistant":
+                # Assistant messages with text (no tool calls)
+                assistant_message: ChatCompletionAssistantMessageParam = {
+                    "role": "assistant",
+                    "content": "\n".join(text_parts),
+                }
+                messages.append(assistant_message)
+
+    return messages
+
+
+def _update_type_string(value_dict: dict[str, Any]):
+    """Updates 'type' field to expected JSON schema format."""
+    if "type" in value_dict:
+        value_dict["type"] = value_dict["type"].lower()
+
+    if "items" in value_dict:
+        # 'type' field could exist for items as well, this would be the case if
+        # items represent primitive types.
+        _update_type_string(value_dict["items"])
+
+        if "properties" in value_dict["items"]:
+            # There could be properties as well on the items, especially if the items
+            # are complex object themselves. We recursively traverse each individual
+            # property as well and fix the "type" value.
+            for _, value in value_dict["items"]["properties"].items():
+                _update_type_string(value)
+
+    if "properties" in value_dict:
+        # Handle nested properties
+        for _, value in value_dict["properties"].items():
+            _update_type_string(value)
+
+
+def _convert_tools_to_openai(tools: list[types.Tool]) -> list[ChatCompletionToolParam]:
+    """Convert google.genai Tools to OpenAI tools format."""
+    openai_tools: list[ChatCompletionToolParam] = []
+
+    for tool in tools:
+        if tool.function_declarations:
+            for func_decl in tool.function_declarations:
+                # Build function definition
+                function_def: FunctionDefinition = {
+                    "name": func_decl.name or "",
+                    "description": func_decl.description or "",
+                }
+
+                # Always include parameters field, even if empty
+                properties = {}
+                required = []
+
+                if func_decl.parameters:
+                    if func_decl.parameters.properties:
+                        for prop_name, prop_schema in func_decl.parameters.properties.items():
+                            value_dict = prop_schema.model_dump(exclude_none=True)
+                            _update_type_string(value_dict)
+                            properties[prop_name] = value_dict
+
+                    if func_decl.parameters.required:
+                        required = func_decl.parameters.required
+
+                function_def["parameters"] = {"type": "object", "properties": properties, "required": required}
+
+                # Create the tool param
+                openai_tool: ChatCompletionToolParam = {"type": "function", "function": function_def}
+                openai_tools.append(openai_tool)
+
+    return openai_tools
+
+
+def _convert_openai_response_to_llm_response(response: ChatCompletion) -> LlmResponse:
+    """Convert OpenAI response to LlmResponse."""
+    choice = response.choices[0]
+    message = choice.message
+
+    parts = []
+
+    # Handle text content
+    if message.content:
+        parts.append(types.Part.from_text(text=message.content))
+
+    # Handle function calls
+    if hasattr(message, "tool_calls") and message.tool_calls:
+        for tool_call in message.tool_calls:
+            if tool_call.type == "function":
+                try:
+                    args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
+                except json.JSONDecodeError:
+                    args = {}
+
+                part = types.Part.from_function_call(name=tool_call.function.name, args=args)
+                if part.function_call:
+                    part.function_call.id = tool_call.id
+                parts.append(part)
+
+    content = types.Content(role="model", parts=parts)
+
+    # Handle usage metadata
+    usage_metadata = None
+    if hasattr(response, "usage") and response.usage:
+        usage_metadata = types.GenerateContentResponseUsageMetadata(
+            prompt_token_count=response.usage.prompt_tokens,
+            candidates_token_count=response.usage.completion_tokens,
+            total_token_count=response.usage.total_tokens,
+        )
+
+    # Handle finish reason
+    finish_reason = types.FinishReason.STOP
+    if choice.finish_reason == "length":
+        finish_reason = types.FinishReason.MAX_TOKENS
+    elif choice.finish_reason == "content_filter":
+        finish_reason = types.FinishReason.SAFETY
+
+    return LlmResponse(content=content, usage_metadata=usage_metadata, finish_reason=finish_reason)
+
+
+class BaseOpenAI(BaseLlm):
+    """Base class for OpenAI-compatible models."""
+
+    model: str
+    base_url: Optional[str] = None
+    api_key: Optional[str] = Field(default=None, exclude=True)
+    max_tokens: Optional[int] = None
+    temperature: Optional[float] = None
+
+    @classmethod
+    def supported_models(cls) -> list[str]:
+        """Returns a list of supported models in regex for LlmRegistry."""
+        return [r"gpt-.*", r"o1-.*"]
+
+    @cached_property
+    def _client(self) -> AsyncOpenAI:
+        """Get the OpenAI client."""
+        kwargs = {}
+        if self.base_url:
+            kwargs["base_url"] = self.base_url
+        if self.api_key:
+            kwargs["api_key"] = self.api_key
+
+        return AsyncOpenAI(**kwargs)
+
+    async def generate_content_async(
+        self, llm_request: LlmRequest, stream: bool = False
+    ) -> AsyncGenerator[LlmResponse, None]:
+        """Generate content using OpenAI API."""
+
+        # Convert messages
+        system_instruction = None
+        if llm_request.config and llm_request.config.system_instruction:
+            if isinstance(llm_request.config.system_instruction, str):
+                system_instruction = llm_request.config.system_instruction
+            elif hasattr(llm_request.config.system_instruction, "parts"):
+                # Handle Content type system instruction
+                text_parts = []
+                parts = getattr(llm_request.config.system_instruction, "parts", [])
+                if parts:
+                    for part in parts:
+                        if hasattr(part, "text") and part.text:
+                            text_parts.append(part.text)
+                system_instruction = "\n".join(text_parts)
+
+        messages = _convert_content_to_openai_messages(llm_request.contents, system_instruction)
+
+        # Prepare request parameters
+        kwargs = {
+            "model": llm_request.model or self.model,
+            "messages": messages,
+        }
+
+        if self.max_tokens:
+            kwargs["max_tokens"] = self.max_tokens
+        if self.temperature is not None:
+            kwargs["temperature"] = self.temperature
+
+        # Handle tools
+        if llm_request.config and llm_request.config.tools:
+            # Filter to only google.genai.types.Tool objects
+            genai_tools = []
+            for tool in llm_request.config.tools:
+                if hasattr(tool, "function_declarations"):
+                    genai_tools.append(tool)
+
+            if genai_tools:
+                openai_tools = _convert_tools_to_openai(genai_tools)
+                if openai_tools:
+                    kwargs["tools"] = openai_tools
+                    kwargs["tool_choice"] = "auto"
+
+        try:
+            if stream:
+                # Handle streaming
+                async for chunk in await self._client.chat.completions.create(stream=True, **kwargs):
+                    if chunk.choices and chunk.choices[0].delta:
+                        delta = chunk.choices[0].delta
+                        if delta.content:
+                            content = types.Content(role="model", parts=[types.Part.from_text(text=delta.content)])
+                            yield LlmResponse(
+                                content=content, partial=True, turn_complete=chunk.choices[0].finish_reason is not None
+                            )
+            else:
+                # Handle non-streaming
+                response = await self._client.chat.completions.create(stream=False, **kwargs)
+                yield _convert_openai_response_to_llm_response(response)
+
+        except Exception as e:
+            yield LlmResponse(error_code="API_ERROR", error_message=str(e))
+
+
+class OpenAI(BaseOpenAI):
+    """OpenAI model implementation."""
+
+    type: Literal["openai"]
+
+    @cached_property
+    def _client(self) -> AsyncOpenAI:
+        """Get the OpenAI client."""
+        kwargs = {}
+        if self.base_url:
+            kwargs["base_url"] = self.base_url
+        if self.api_key:
+            kwargs["api_key"] = self.api_key
+        elif "OPENAI_API_KEY" in os.environ:
+            kwargs["api_key"] = os.environ["OPENAI_API_KEY"]
+
+        return AsyncOpenAI(**kwargs)
+
+
+class AzureOpenAI(BaseOpenAI):
+    """Azure OpenAI model implementation."""
+
+    type: Literal["azure_openai"]
+    api_version: Optional[str] = None
+    azure_endpoint: Optional[str] = None
+    azure_deployment: Optional[str] = None
+
+    @cached_property
+    def _client(self) -> AsyncAzureOpenAI:
+        """Get the Azure OpenAI client."""
+        api_version = self.api_version or os.environ.get("AZURE_OPENAI_API_VERSION", "2024-02-15-preview")
+        azure_endpoint = self.azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
+        api_key = self.api_key or os.environ.get("AZURE_OPENAI_API_KEY")
+
+        if not azure_endpoint:
+            raise ValueError(
+                "Azure endpoint must be provided either via azure_endpoint parameter or AZURE_OPENAI_ENDPOINT environment variable"
+            )
+
+        if not api_key:
+            raise ValueError(
+                "API key must be provided either via api_key parameter or AZURE_OPENAI_API_KEY environment variable"
+            )
+
+        return AsyncAzureOpenAI(api_version=api_version, azure_endpoint=azure_endpoint, api_key=api_key)
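Two helpers do most of the work in this new module: _convert_content_to_openai_messages maps google.genai conversation history to OpenAI chat messages, and _convert_tools_to_openai maps function declarations to OpenAI tool schemas. A small self-contained check of both; the tool-conversion case mirrors the first parametrized unit test added later in this diff, and the message-conversion case follows directly from the code above:

from google.genai import types

from kagent.adk.models._openai import (
    _convert_content_to_openai_messages,
    _convert_tools_to_openai,
)

# Content -> chat messages; the system instruction becomes the leading message.
contents = [types.Content(role="user", parts=[types.Part.from_text(text="Hello")])]
messages = _convert_content_to_openai_messages(contents, system_instruction="You are helpful")
assert messages == [
    {"role": "system", "content": "You are helpful"},
    {"role": "user", "content": "Hello"},
]

# Tool declarations -> OpenAI function tools; schema types are lowercased,
# and a "parameters" object is always emitted, even when empty.
tool = types.Tool(
    function_declarations=[
        types.FunctionDeclaration(name="get_current_time", description="Gets the current time.")
    ]
)
assert _convert_tools_to_openai([tool]) == [
    {
        "type": "function",
        "function": {
            "name": "get_current_time",
            "description": "Gets the current time.",
            "parameters": {"type": "object", "properties": {}, "required": []},
        },
    }
]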
kagent_adk-0.6.8/src/kagent/adk/models.py → kagent_adk-0.6.10/src/kagent/adk/types.py

@@ -12,6 +12,9 @@ from google.adk.tools.agent_tool import AgentTool
 from google.adk.tools.mcp_tool import MCPToolset, SseConnectionParams, StreamableHTTPConnectionParams
 from pydantic import BaseModel, Field
 
+from .models import AzureOpenAI as OpenAIAzure
+from .models import OpenAI as OpenAINative
+
 logger = logging.getLogger(__name__)
 
 
@@ -97,7 +100,7 @@ class AgentConfig(BaseModel):
             mcp_toolsets.append(AgentTool(agent=remote_agent, skip_summarization=True))
 
         if self.model.type == "openai":
-            model =
+            model = OpenAINative(model=self.model.model, base_url=self.model.base_url, type="openai")
         elif self.model.type == "anthropic":
             model = LiteLlm(model=f"anthropic/{self.model.model}", base_url=self.model.base_url)
         elif self.model.type == "gemini_vertex_ai":
@@ -107,7 +110,7 @@ class AgentConfig(BaseModel):
         elif self.model.type == "ollama":
            model = LiteLlm(model=f"ollama_chat/{self.model.model}")
         elif self.model.type == "azure_openai":
-            model =
+            model = OpenAIAzure(model=self.model.model, type="azure_openai")
         elif self.model.type == "gemini":
             model = self.model.model
         else:
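With these imports in place, the "openai" and "azure_openai" branches construct the native clients directly instead of routing through LiteLLM. A standalone sketch of what each branch produces (model and deployment names are illustrative; credentials resolve lazily from OPENAI_API_KEY and the AZURE_OPENAI_* environment variables when the client is first created):

from kagent.adk.models import AzureOpenAI, OpenAI

# Equivalent of the "openai" branch: optional base_url supports
# OpenAI-compatible endpoints.
model = OpenAI(model="gpt-3.5-turbo", base_url=None, type="openai")

# Equivalent of the "azure_openai" branch: api_version, azure_endpoint, and
# api_key fall back to AZURE_OPENAI_API_VERSION / AZURE_OPENAI_ENDPOINT /
# AZURE_OPENAI_API_KEY; a missing endpoint or key raises ValueError.
azure_model = AzureOpenAI(model="gpt-35-turbo", type="azure_openai")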
kagent_adk-0.6.10/tests/unittests/models/test_openai.py (new file)

@@ -0,0 +1,338 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest import mock
+
+import pytest
+from google.adk.models.llm_request import LlmRequest
+from google.adk.models.llm_response import LlmResponse
+from google.genai import types
+from google.genai.types import Content, Part
+from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
+
+from kagent.adk.models import OpenAI
+from kagent.adk.models._openai import _convert_tools_to_openai
+
+
+@pytest.fixture
+def generate_content_response():
+    # Create a mock response object
+    class MockUsage:
+        def __init__(self):
+            self.completion_tokens = 12
+            self.prompt_tokens = 13
+            self.total_tokens = 25
+
+    class MockMessage:
+        def __init__(self):
+            self.content = "Hi! How can I help you today?"
+            self.role = "assistant"
+
+    class MockChoice:
+        def __init__(self):
+            self.finish_reason = "stop"
+            self.index = 0
+            self.message = MockMessage()
+
+    class MockResponse:
+        def __init__(self):
+            self.id = "chatcmpl-testid"
+            self.choices = [MockChoice()]
+            self.created = 1234567890
+            self.model = "gpt-3.5-turbo"
+            self.object = "chat.completion"
+            self.usage = MockUsage()
+
+    return MockResponse()
+
+
+@pytest.fixture
+def generate_llm_response():
+    return LlmResponse.create(
+        types.GenerateContentResponse(
+            candidates=[
+                types.Candidate(
+                    content=Content(
+                        role="model",
+                        parts=[Part.from_text(text="Hello, how can I help you?")],
+                    ),
+                    finish_reason=types.FinishReason.STOP,
+                )
+            ]
+        )
+    )
+
+
+@pytest.fixture
+def openai_llm():
+    return OpenAI(model="gpt-3.5-turbo", type="openai")
+
+
+@pytest.fixture
+def llm_request():
+    return LlmRequest(
+        model="gpt-3.5-turbo",
+        contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
+        config=types.GenerateContentConfig(
+            temperature=0.1,
+            response_modalities=[types.Modality.TEXT],
+            system_instruction="You are a helpful assistant",
+        ),
+    )
+
+
+def test_supported_models():
+    models = OpenAI.supported_models()
+    assert len(models) == 2
+    assert models[0] == r"gpt-.*"
+    assert models[1] == r"o1-.*"
+
+
+function_declaration_test_cases = [
+    (
+        "function_with_no_parameters",
+        types.FunctionDeclaration(
+            name="get_current_time",
+            description="Gets the current time.",
+        ),
+        ChatCompletionToolParam(
+            type="function",
+            function={
+                "name": "get_current_time",
+                "description": "Gets the current time.",
+                "parameters": {"type": "object", "properties": {}, "required": []},
+            },
+        ),
+    ),
+    (
+        "function_with_one_optional_parameter",
+        types.FunctionDeclaration(
+            name="get_weather",
+            description="Gets weather information for a given location.",
+            parameters=types.Schema(
+                type=types.Type.OBJECT,
+                properties={
+                    "location": types.Schema(
+                        type=types.Type.STRING,
+                        description="City and state, e.g., San Francisco, CA",
+                    )
+                },
+            ),
+        ),
+        ChatCompletionToolParam(
+            type="function",
+            function={
+                "name": "get_weather",
+                "description": "Gets weather information for a given location.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "City and state, e.g., San Francisco, CA",
+                        }
+                    },
+                    "required": [],
+                },
+            },
+        ),
+    ),
+    (
+        "function_with_one_required_parameter",
+        types.FunctionDeclaration(
+            name="get_stock_price",
+            description="Gets the current price for a stock ticker.",
+            parameters=types.Schema(
+                type=types.Type.OBJECT,
+                properties={
+                    "ticker": types.Schema(
+                        type=types.Type.STRING,
+                        description="The stock ticker, e.g., AAPL",
+                    )
+                },
+                required=["ticker"],
+            ),
+        ),
+        ChatCompletionToolParam(
+            type="function",
+            function={
+                "name": "get_stock_price",
+                "description": "Gets the current price for a stock ticker.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "ticker": {
+                            "type": "string",
+                            "description": "The stock ticker, e.g., AAPL",
+                        }
+                    },
+                    "required": ["ticker"],
+                },
+            },
+        ),
+    ),
+    (
+        "function_with_multiple_mixed_parameters",
+        types.FunctionDeclaration(
+            name="submit_order",
+            description="Submits a product order.",
+            parameters=types.Schema(
+                type=types.Type.OBJECT,
+                properties={
+                    "product_id": types.Schema(type=types.Type.STRING, description="The product ID"),
+                    "quantity": types.Schema(
+                        type=types.Type.INTEGER,
+                        description="The order quantity",
+                    ),
+                    "notes": types.Schema(
+                        type=types.Type.STRING,
+                        description="Optional order notes",
+                    ),
+                },
+                required=["product_id", "quantity"],
+            ),
+        ),
+        ChatCompletionToolParam(
+            type="function",
+            function={
+                "name": "submit_order",
+                "description": "Submits a product order.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "product_id": {
+                            "type": "string",
+                            "description": "The product ID",
+                        },
+                        "quantity": {
+                            "type": "integer",
+                            "description": "The order quantity",
+                        },
+                        "notes": {
+                            "type": "string",
+                            "description": "Optional order notes",
+                        },
+                    },
+                    "required": ["product_id", "quantity"],
+                },
+            },
+        ),
+    ),
+    (
+        "function_with_complex_nested_parameter",
+        types.FunctionDeclaration(
+            name="create_playlist",
+            description="Creates a playlist from a list of songs.",
+            parameters=types.Schema(
+                type=types.Type.OBJECT,
+                properties={
+                    "playlist_name": types.Schema(
+                        type=types.Type.STRING,
+                        description="The name for the new playlist",
+                    ),
+                    "songs": types.Schema(
+                        type=types.Type.ARRAY,
+                        description="A list of songs to add to the playlist",
+                        items=types.Schema(
+                            type=types.Type.OBJECT,
+                            properties={
+                                "title": types.Schema(type=types.Type.STRING),
+                                "artist": types.Schema(type=types.Type.STRING),
+                            },
+                            required=["title", "artist"],
+                        ),
+                    ),
+                },
+                required=["playlist_name", "songs"],
+            ),
+        ),
+        ChatCompletionToolParam(
+            type="function",
+            function={
+                "name": "create_playlist",
+                "description": "Creates a playlist from a list of songs.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "playlist_name": {
+                            "type": "string",
+                            "description": "The name for the new playlist",
+                        },
+                        "songs": {
+                            "type": "array",
+                            "description": "A list of songs to add to the playlist",
+                            "items": {
+                                "type": "object",
+                                "properties": {
+                                    "title": {"type": "string"},
+                                    "artist": {"type": "string"},
+                                },
+                                "required": ["title", "artist"],
+                            },
+                        },
+                    },
+                    "required": ["playlist_name", "songs"],
+                },
+            },
+        ),
+    ),
+]
+
+
+@pytest.mark.parametrize(
+    "_, function_declaration, expected_tool_param",
+    function_declaration_test_cases,
+    ids=[case[0] for case in function_declaration_test_cases],
+)
+async def test_function_declaration_to_tool_param(_, function_declaration, expected_tool_param):
+    """Test _convert_tools_to_openai function."""
+    tool = types.Tool(function_declarations=[function_declaration])
+    result = _convert_tools_to_openai([tool])
+    assert len(result) == 1
+    assert result[0] == expected_tool_param
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async(openai_llm, llm_request, generate_content_response, generate_llm_response):
+    with mock.patch.object(openai_llm, "_client") as mock_client:
+        # Create a mock coroutine that returns the generate_content_response.
+        async def mock_coro(*args, **kwargs):
+            return generate_content_response
+
+        # Assign the coroutine to the mocked method
+        mock_client.chat.completions.create.return_value = mock_coro()
+
+        responses = [resp async for resp in openai_llm.generate_content_async(llm_request, stream=False)]
+        assert len(responses) == 1
+        assert isinstance(responses[0], LlmResponse)
+        assert responses[0].content is not None
+        assert len(responses[0].content.parts) > 0
+        assert responses[0].content.parts[0].text == "Hi! How can I help you today?"
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_with_max_tokens(llm_request, generate_content_response, generate_llm_response):
+    openai_llm = OpenAI(model="gpt-3.5-turbo", max_tokens=4096, type="openai")
+    with mock.patch.object(openai_llm, "_client") as mock_client:
+        # Create a mock coroutine that returns the generate_content_response.
+        async def mock_coro(*args, **kwargs):
+            return generate_content_response
+
+        # Assign the coroutine to the mocked method
+        mock_client.chat.completions.create.return_value = mock_coro()
+
+        _ = [resp async for resp in openai_llm.generate_content_async(llm_request, stream=False)]
+        mock_client.chat.completions.create.assert_called_once()
+        _, kwargs = mock_client.chat.completions.create.call_args
+        assert kwargs["max_tokens"] == 4096
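These tests exercise only the non-streaming path. A hedged sketch of the streaming path that generate_content_async also implements, which needs a live OPENAI_API_KEY and is therefore not unit-tested here (prompt and model name are illustrative):

import asyncio

from google.adk.models.llm_request import LlmRequest
from google.genai import types
from google.genai.types import Content, Part

from kagent.adk.models import OpenAI


async def main():
    llm = OpenAI(model="gpt-3.5-turbo", type="openai")
    request = LlmRequest(
        model="gpt-3.5-turbo",
        contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
        config=types.GenerateContentConfig(system_instruction="You are a helpful assistant"),
    )
    # stream=True yields partial LlmResponse chunks; API failures surface as a
    # final response with error_code="API_ERROR" rather than an exception.
    async for response in llm.generate_content_async(request, stream=True):
        if response.content and response.content.parts:
            print(response.content.parts[0].text, end="", flush=True)


asyncio.run(main())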