kagent-adk 0.6.8__tar.gz → 0.6.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kagent-adk might be problematic.

Files changed (22)
  1. {kagent_adk-0.6.8 → kagent_adk-0.6.9}/PKG-INFO +1 -1
  2. {kagent_adk-0.6.8 → kagent_adk-0.6.9}/pyproject.toml +6 -1
  3. {kagent_adk-0.6.8 → kagent_adk-0.6.9}/src/kagent/adk/__init__.py +1 -1
  4. kagent_adk-0.6.9/src/kagent/adk/models/__init__.py +3 -0
  5. kagent_adk-0.6.9/src/kagent/adk/models/_openai.py +393 -0
  6. kagent_adk-0.6.8/src/kagent/adk/models.py → kagent_adk-0.6.9/src/kagent/adk/types.py +5 -2
  7. kagent_adk-0.6.9/tests/__init__.py +0 -0
  8. kagent_adk-0.6.9/tests/unittests/__init__.py +0 -0
  9. kagent_adk-0.6.9/tests/unittests/models/__init__.py +0 -0
  10. kagent_adk-0.6.9/tests/unittests/models/test_openai.py +338 -0
  11. {kagent_adk-0.6.8 → kagent_adk-0.6.9}/.gitignore +0 -0
  12. {kagent_adk-0.6.8 → kagent_adk-0.6.9}/.python-version +0 -0
  13. {kagent_adk-0.6.8 → kagent_adk-0.6.9}/README.md +0 -0
  14. {kagent_adk-0.6.8 → kagent_adk-0.6.9}/src/kagent/adk/_a2a.py +0 -0
  15. {kagent_adk-0.6.8 → kagent_adk-0.6.9}/src/kagent/adk/_agent_executor.py +0 -0
  16. {kagent_adk-0.6.8 → kagent_adk-0.6.9}/src/kagent/adk/_session_service.py +0 -0
  17. {kagent_adk-0.6.8 → kagent_adk-0.6.9}/src/kagent/adk/_token.py +0 -0
  18. {kagent_adk-0.6.8 → kagent_adk-0.6.9}/src/kagent/adk/cli.py +0 -0
  19. {kagent_adk-0.6.8 → kagent_adk-0.6.9}/src/kagent/adk/converters/__init__.py +0 -0
  20. {kagent_adk-0.6.8 → kagent_adk-0.6.9}/src/kagent/adk/converters/event_converter.py +0 -0
  21. {kagent_adk-0.6.8 → kagent_adk-0.6.9}/src/kagent/adk/converters/part_converter.py +0 -0
  22. {kagent_adk-0.6.8 → kagent_adk-0.6.9}/src/kagent/adk/converters/request_converter.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: kagent-adk
- Version: 0.6.8
+ Version: 0.6.9
  Summary: kagent-adk is an sdk for integrating adk agents with kagent
  Requires-Python: >=3.12.11
  Requires-Dist: a2a-sdk>=0.3.1
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
  [project]
  name = "kagent-adk"
- version = "0.6.8"
+ version = "0.6.9"
  description = "kagent-adk is an sdk for integrating adk agents with kagent"
  readme = "README.md"
  requires-python = ">=3.12.11"
@@ -50,3 +50,8 @@ packages = ["src/kagent"]
  [tool.ruff]
  extend = "../../pyproject.toml"
+
+ [tool.pytest.ini_options]
+ testpaths = ["tests"]
+ asyncio_default_fixture_loop_scope = "function"
+ asyncio_mode = "auto"
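Note: the new [tool.pytest.ini_options] block turns on pytest-asyncio's auto mode, so coroutine tests are collected and awaited without an explicit marker (the new test file below relies on this for its parametrized async test). A minimal, hypothetical test illustrating the effect — not one shipped in the package:

import asyncio

async def test_runs_without_marker():
    # With asyncio_mode = "auto", pytest-asyncio awaits this coroutine
    # directly; no @pytest.mark.asyncio decorator is needed.
    assert asyncio.get_running_loop() is not None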
@@ -1,7 +1,7 @@
  import importlib.metadata

  from ._a2a import KAgentApp
- from .models import AgentConfig
+ from .types import AgentConfig

  __version__ = importlib.metadata.version("kagent_adk")
@@ -0,0 +1,3 @@
+ from ._openai import AzureOpenAI, OpenAI
+
+ __all__ = ["OpenAI", "AzureOpenAI"]
@@ -0,0 +1,393 @@
+ from __future__ import annotations
+
+ import base64
+ import json
+ import os
+ from functools import cached_property
+ from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, Literal, Optional
+
+ from google.adk.models import BaseLlm
+ from google.adk.models.llm_response import LlmResponse
+ from google.genai import types
+ from openai import AsyncAzureOpenAI, AsyncOpenAI
+ from openai.types.chat import (
+     ChatCompletion,
+     ChatCompletionAssistantMessageParam,
+     ChatCompletionContentPartImageParam,
+     ChatCompletionContentPartTextParam,
+     ChatCompletionMessageParam,
+     ChatCompletionSystemMessageParam,
+     ChatCompletionToolMessageParam,
+     ChatCompletionToolParam,
+     ChatCompletionUserMessageParam,
+ )
+ from openai.types.chat.chat_completion_message_tool_call_param import (
+     ChatCompletionMessageToolCallParam,
+ )
+ from openai.types.chat.chat_completion_message_tool_call_param import (
+     Function as ToolCallFunction,
+ )
+ from openai.types.shared_params import FunctionDefinition
+ from pydantic import Field
+
+ if TYPE_CHECKING:
+     from google.adk.models.llm_request import LlmRequest
+
+
+ def _convert_role_to_openai(role: Optional[str]) -> str:
+     """Convert google.genai role to OpenAI role."""
+     if role in ["model", "assistant"]:
+         return "assistant"
+     elif role == "system":
+         return "system"
+     else:
+         return "user"
+
+
+ def _convert_content_to_openai_messages(
+     contents: list[types.Content], system_instruction: Optional[str] = None
+ ) -> list[ChatCompletionMessageParam]:
+     """Convert google.genai Content list to OpenAI messages format."""
+     messages: list[ChatCompletionMessageParam] = []
+
+     # Add system message if provided
+     if system_instruction:
+         system_message: ChatCompletionSystemMessageParam = {"role": "system", "content": system_instruction}
+         messages.append(system_message)
+
+     # Track tool calls to ensure proper flow
+     pending_tool_calls = set()
+
+     for content in contents:
+         role = _convert_role_to_openai(content.role)
+
+         # Separate different types of parts
+         text_parts = []
+         function_calls = []
+         function_responses = []
+         image_parts = []
+
+         for part in content.parts or []:
+             if part.text:
+                 text_parts.append(part.text)
+             elif part.function_call:
+                 function_calls.append(part)
+             elif part.function_response:
+                 function_responses.append(part)
+             elif part.inline_data and part.inline_data.mime_type and part.inline_data.mime_type.startswith("image"):
+                 if part.inline_data.data:
+                     image_data = base64.b64encode(part.inline_data.data).decode()
+                     image_part: ChatCompletionContentPartImageParam = {
+                         "type": "image_url",
+                         "image_url": {"url": f"data:{part.inline_data.mime_type};base64,{image_data}"},
+                     }
+                     image_parts.append(image_part)
+
+         # Handle function responses first (they should be tool messages)
+         for func_response in function_responses:
+             tool_call_id = func_response.function_response.id or "call_1"
+             if tool_call_id in pending_tool_calls:
+                 tool_message: ChatCompletionToolMessageParam = {
+                     "role": "tool",
+                     "tool_call_id": tool_call_id,
+                     "content": str(func_response.function_response.response.get("result", ""))
+                     if func_response.function_response.response
+                     else "",
+                 }
+                 messages.append(tool_message)
+                 pending_tool_calls.discard(tool_call_id)
+
+         # Handle function calls (assistant messages with tool_calls)
+         if function_calls:
+             tool_calls = []
+             for func_call in function_calls:
+                 tool_call_function: ToolCallFunction = {
+                     "name": func_call.function_call.name or "",
+                     "arguments": str(func_call.function_call.args) if func_call.function_call.args else "{}",
+                 }
+                 tool_call_id = func_call.function_call.id or "call_1"
+                 tool_call: ChatCompletionMessageToolCallParam = {
+                     "id": tool_call_id,
+                     "type": "function",
+                     "function": tool_call_function,
+                 }
+                 tool_calls.append(tool_call)
+                 pending_tool_calls.add(tool_call_id)
+
+             # Create assistant message with tool calls
+             text_content = "\n".join(text_parts) if text_parts else None
+             assistant_message: ChatCompletionAssistantMessageParam = {
+                 "role": "assistant",
+                 "content": text_content,
+                 "tool_calls": tool_calls,
+             }
+             messages.append(assistant_message)
+
+         # Handle regular text/image messages (only if no function calls)
+         elif text_parts or image_parts:
+             if role == "user":
+                 if image_parts and text_parts:
+                     # Multi-modal content
+                     text_part: ChatCompletionContentPartTextParam = {"type": "text", "text": "\n".join(text_parts)}
+                     content_parts = [text_part] + image_parts
+                     user_message: ChatCompletionUserMessageParam = {"role": "user", "content": content_parts}
+                 elif image_parts:
+                     # Image only
+                     user_message: ChatCompletionUserMessageParam = {"role": "user", "content": image_parts}
+                 else:
+                     # Text only
+                     user_message: ChatCompletionUserMessageParam = {"role": "user", "content": "\n".join(text_parts)}
+                 messages.append(user_message)
+             elif role == "assistant":
+                 # Assistant messages with text (no tool calls)
+                 assistant_message: ChatCompletionAssistantMessageParam = {
+                     "role": "assistant",
+                     "content": "\n".join(text_parts),
+                 }
+                 messages.append(assistant_message)
+
+     return messages
+
+
+ def _update_type_string(value_dict: dict[str, Any]):
+     """Updates 'type' field to expected JSON schema format."""
+     if "type" in value_dict:
+         value_dict["type"] = value_dict["type"].lower()
+
+     if "items" in value_dict:
+         # 'type' field could exist for items as well, this would be the case if
+         # items represent primitive types.
+         _update_type_string(value_dict["items"])
+
+         if "properties" in value_dict["items"]:
+             # There could be properties as well on the items, especially if the items
+             # are complex object themselves. We recursively traverse each individual
+             # property as well and fix the "type" value.
+             for _, value in value_dict["items"]["properties"].items():
+                 _update_type_string(value)
+
+     if "properties" in value_dict:
+         # Handle nested properties
+         for _, value in value_dict["properties"].items():
+             _update_type_string(value)
+
+
+ def _convert_tools_to_openai(tools: list[types.Tool]) -> list[ChatCompletionToolParam]:
+     """Convert google.genai Tools to OpenAI tools format."""
+     openai_tools: list[ChatCompletionToolParam] = []
+
+     for tool in tools:
+         if tool.function_declarations:
+             for func_decl in tool.function_declarations:
+                 # Build function definition
+                 function_def: FunctionDefinition = {
+                     "name": func_decl.name or "",
+                     "description": func_decl.description or "",
+                 }
+
+                 # Always include parameters field, even if empty
+                 properties = {}
+                 required = []
+
+                 if func_decl.parameters:
+                     if func_decl.parameters.properties:
+                         for prop_name, prop_schema in func_decl.parameters.properties.items():
+                             value_dict = prop_schema.model_dump(exclude_none=True)
+                             _update_type_string(value_dict)
+                             properties[prop_name] = value_dict
+
+                     if func_decl.parameters.required:
+                         required = func_decl.parameters.required
+
+                 function_def["parameters"] = {"type": "object", "properties": properties, "required": required}
+
+                 # Create the tool param
+                 openai_tool: ChatCompletionToolParam = {"type": "function", "function": function_def}
+                 openai_tools.append(openai_tool)
+
+     return openai_tools
+
+
+ def _convert_openai_response_to_llm_response(response: ChatCompletion) -> LlmResponse:
+     """Convert OpenAI response to LlmResponse."""
+     choice = response.choices[0]
+     message = choice.message
+
+     parts = []
+
+     # Handle text content
+     if message.content:
+         parts.append(types.Part.from_text(text=message.content))
+
+     # Handle function calls
+     if hasattr(message, "tool_calls") and message.tool_calls:
+         for tool_call in message.tool_calls:
+             if tool_call.type == "function":
+                 try:
+                     args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
+                 except json.JSONDecodeError:
+                     args = {}
+
+                 part = types.Part.from_function_call(name=tool_call.function.name, args=args)
+                 if part.function_call:
+                     part.function_call.id = tool_call.id
+                 parts.append(part)
+
+     content = types.Content(role="model", parts=parts)
+
+     # Handle usage metadata
+     usage_metadata = None
+     if hasattr(response, "usage") and response.usage:
+         usage_metadata = types.GenerateContentResponseUsageMetadata(
+             prompt_token_count=response.usage.prompt_tokens,
+             candidates_token_count=response.usage.completion_tokens,
+             total_token_count=response.usage.total_tokens,
+         )
+
+     # Handle finish reason
+     finish_reason = types.FinishReason.STOP
+     if choice.finish_reason == "length":
+         finish_reason = types.FinishReason.MAX_TOKENS
+     elif choice.finish_reason == "content_filter":
+         finish_reason = types.FinishReason.SAFETY
+
+     return LlmResponse(content=content, usage_metadata=usage_metadata, finish_reason=finish_reason)
+
+
+ class BaseOpenAI(BaseLlm):
+     """Base class for OpenAI-compatible models."""
+
+     model: str
+     base_url: Optional[str] = None
+     api_key: Optional[str] = Field(default=None, exclude=True)
+     max_tokens: Optional[int] = None
+     temperature: Optional[float] = None
+
+     @classmethod
+     def supported_models(cls) -> list[str]:
+         """Returns a list of supported models in regex for LlmRegistry."""
+         return [r"gpt-.*", r"o1-.*"]
+
+     @cached_property
+     def _client(self) -> AsyncOpenAI:
+         """Get the OpenAI client."""
+         kwargs = {}
+         if self.base_url:
+             kwargs["base_url"] = self.base_url
+         if self.api_key:
+             kwargs["api_key"] = self.api_key
+
+         return AsyncOpenAI(**kwargs)
+
+     async def generate_content_async(
+         self, llm_request: LlmRequest, stream: bool = False
+     ) -> AsyncGenerator[LlmResponse, None]:
+         """Generate content using OpenAI API."""
+
+         # Convert messages
+         system_instruction = None
+         if llm_request.config and llm_request.config.system_instruction:
+             if isinstance(llm_request.config.system_instruction, str):
+                 system_instruction = llm_request.config.system_instruction
+             elif hasattr(llm_request.config.system_instruction, "parts"):
+                 # Handle Content type system instruction
+                 text_parts = []
+                 parts = getattr(llm_request.config.system_instruction, "parts", [])
+                 if parts:
+                     for part in parts:
+                         if hasattr(part, "text") and part.text:
+                             text_parts.append(part.text)
+                 system_instruction = "\n".join(text_parts)
+
+         messages = _convert_content_to_openai_messages(llm_request.contents, system_instruction)
+
+         # Prepare request parameters
+         kwargs = {
+             "model": llm_request.model or self.model,
+             "messages": messages,
+         }
+
+         if self.max_tokens:
+             kwargs["max_tokens"] = self.max_tokens
+         if self.temperature is not None:
+             kwargs["temperature"] = self.temperature
+
+         # Handle tools
+         if llm_request.config and llm_request.config.tools:
+             # Filter to only google.genai.types.Tool objects
+             genai_tools = []
+             for tool in llm_request.config.tools:
+                 if hasattr(tool, "function_declarations"):
+                     genai_tools.append(tool)
+
+             if genai_tools:
+                 openai_tools = _convert_tools_to_openai(genai_tools)
+                 if openai_tools:
+                     kwargs["tools"] = openai_tools
+                     kwargs["tool_choice"] = "auto"
+
+         try:
+             if stream:
+                 # Handle streaming
+                 async for chunk in await self._client.chat.completions.create(stream=True, **kwargs):
+                     if chunk.choices and chunk.choices[0].delta:
+                         delta = chunk.choices[0].delta
+                         if delta.content:
+                             content = types.Content(role="model", parts=[types.Part.from_text(text=delta.content)])
+                             yield LlmResponse(
+                                 content=content, partial=True, turn_complete=chunk.choices[0].finish_reason is not None
+                             )
+             else:
+                 # Handle non-streaming
+                 response = await self._client.chat.completions.create(stream=False, **kwargs)
+                 yield _convert_openai_response_to_llm_response(response)
+
+         except Exception as e:
+             yield LlmResponse(error_code="API_ERROR", error_message=str(e))
+
+
+ class OpenAI(BaseOpenAI):
+     """OpenAI model implementation."""
+
+     type: Literal["openai"]
+
+     @cached_property
+     def _client(self) -> AsyncOpenAI:
+         """Get the OpenAI client."""
+         kwargs = {}
+         if self.base_url:
+             kwargs["base_url"] = self.base_url
+         if self.api_key:
+             kwargs["api_key"] = self.api_key
+         elif "OPENAI_API_KEY" in os.environ:
+             kwargs["api_key"] = os.environ["OPENAI_API_KEY"]
+
+         return AsyncOpenAI(**kwargs)
+
+
+ class AzureOpenAI(BaseOpenAI):
+     """Azure OpenAI model implementation."""
+
+     type: Literal["azure_openai"]
+     api_version: Optional[str] = None
+     azure_endpoint: Optional[str] = None
+     azure_deployment: Optional[str] = None
+
+     @cached_property
+     def _client(self) -> AsyncAzureOpenAI:
+         """Get the Azure OpenAI client."""
+         api_version = self.api_version or os.environ.get("AZURE_OPENAI_API_VERSION", "2024-02-15-preview")
+         azure_endpoint = self.azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
+         api_key = self.api_key or os.environ.get("AZURE_OPENAI_API_KEY")
+
+         if not azure_endpoint:
+             raise ValueError(
+                 "Azure endpoint must be provided either via azure_endpoint parameter or AZURE_OPENAI_ENDPOINT environment variable"
+             )
+
+         if not api_key:
+             raise ValueError(
+                 "API key must be provided either via api_key parameter or AZURE_OPENAI_API_KEY environment variable"
+             )
+
+         return AsyncAzureOpenAI(api_version=api_version, azure_endpoint=azure_endpoint, api_key=api_key)
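For orientation, a minimal sketch of driving the new class directly. The LlmRequest shape mirrors the fixtures in the new test file below; the model name and prompt are placeholders, and OPENAI_API_KEY is assumed to be set:

import asyncio

from google.adk.models.llm_request import LlmRequest
from google.genai import types

from kagent.adk.models import OpenAI


async def main():
    # Picks up OPENAI_API_KEY from the environment (see OpenAI._client above).
    llm = OpenAI(model="gpt-4o-mini", type="openai")
    request = LlmRequest(
        model="gpt-4o-mini",
        contents=[types.Content(role="user", parts=[types.Part.from_text(text="Hello")])],
        config=types.GenerateContentConfig(system_instruction="You are a helpful assistant"),
    )
    # stream=False yields exactly one LlmResponse converted from the ChatCompletion.
    async for response in llm.generate_content_async(request, stream=False):
        print(response.content)


asyncio.run(main())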
@@ -12,6 +12,9 @@ from google.adk.tools.agent_tool import AgentTool
  from google.adk.tools.mcp_tool import MCPToolset, SseConnectionParams, StreamableHTTPConnectionParams
  from pydantic import BaseModel, Field

+ from .models import AzureOpenAI as OpenAIAzure
+ from .models import OpenAI as OpenAINative
+
  logger = logging.getLogger(__name__)

@@ -97,7 +100,7 @@ class AgentConfig(BaseModel):
              mcp_toolsets.append(AgentTool(agent=remote_agent, skip_summarization=True))

          if self.model.type == "openai":
-             model = LiteLlm(model=f"openai/{self.model.model}", base_url=self.model.base_url)
+             model = OpenAINative(model=self.model.model, base_url=self.model.base_url, type="openai")
          elif self.model.type == "anthropic":
              model = LiteLlm(model=f"anthropic/{self.model.model}", base_url=self.model.base_url)
          elif self.model.type == "gemini_vertex_ai":
@@ -107,7 +110,7 @@ class AgentConfig(BaseModel):
          elif self.model.type == "ollama":
              model = LiteLlm(model=f"ollama_chat/{self.model.model}")
          elif self.model.type == "azure_openai":
-             model = LiteLlm(model=f"azure/{self.model.model}")
+             model = OpenAIAzure(model=self.model.model, type="azure_openai")
          elif self.model.type == "gemini":
              model = self.model.model
          else:
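In effect, the openai and azure_openai routes no longer go through LiteLLM proxy strings; AgentConfig now constructs the native clients directly. A sketch of the objects it builds (model names are placeholders):

from kagent.adk.models import AzureOpenAI, OpenAI

# type == "openai": api_key falls back to OPENAI_API_KEY at client creation.
openai_model = OpenAI(model="gpt-4o", base_url=None, type="openai")

# type == "azure_openai": endpoint, key, and version fall back to
# AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY, and AZURE_OPENAI_API_VERSION;
# a missing endpoint or key raises ValueError when the client is first built.
azure_model = AzureOpenAI(model="gpt-4o", type="azure_openai")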
@@ -0,0 +1,338 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from unittest import mock
+
+ import pytest
+ from google.adk.models.llm_request import LlmRequest
+ from google.adk.models.llm_response import LlmResponse
+ from google.genai import types
+ from google.genai.types import Content, Part
+ from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
+
+ from kagent.adk.models import OpenAI
+ from kagent.adk.models._openai import _convert_tools_to_openai
+
+
+ @pytest.fixture
+ def generate_content_response():
+     # Create a mock response object
+     class MockUsage:
+         def __init__(self):
+             self.completion_tokens = 12
+             self.prompt_tokens = 13
+             self.total_tokens = 25
+
+     class MockMessage:
+         def __init__(self):
+             self.content = "Hi! How can I help you today?"
+             self.role = "assistant"
+
+     class MockChoice:
+         def __init__(self):
+             self.finish_reason = "stop"
+             self.index = 0
+             self.message = MockMessage()
+
+     class MockResponse:
+         def __init__(self):
+             self.id = "chatcmpl-testid"
+             self.choices = [MockChoice()]
+             self.created = 1234567890
+             self.model = "gpt-3.5-turbo"
+             self.object = "chat.completion"
+             self.usage = MockUsage()
+
+     return MockResponse()
+
+
+ @pytest.fixture
+ def generate_llm_response():
+     return LlmResponse.create(
+         types.GenerateContentResponse(
+             candidates=[
+                 types.Candidate(
+                     content=Content(
+                         role="model",
+                         parts=[Part.from_text(text="Hello, how can I help you?")],
+                     ),
+                     finish_reason=types.FinishReason.STOP,
+                 )
+             ]
+         )
+     )
+
+
+ @pytest.fixture
+ def openai_llm():
+     return OpenAI(model="gpt-3.5-turbo", type="openai")
+
+
+ @pytest.fixture
+ def llm_request():
+     return LlmRequest(
+         model="gpt-3.5-turbo",
+         contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
+         config=types.GenerateContentConfig(
+             temperature=0.1,
+             response_modalities=[types.Modality.TEXT],
+             system_instruction="You are a helpful assistant",
+         ),
+     )
+
+
+ def test_supported_models():
+     models = OpenAI.supported_models()
+     assert len(models) == 2
+     assert models[0] == r"gpt-.*"
+     assert models[1] == r"o1-.*"
+
+
+ function_declaration_test_cases = [
+     (
+         "function_with_no_parameters",
+         types.FunctionDeclaration(
+             name="get_current_time",
+             description="Gets the current time.",
+         ),
+         ChatCompletionToolParam(
+             type="function",
+             function={
+                 "name": "get_current_time",
+                 "description": "Gets the current time.",
+                 "parameters": {"type": "object", "properties": {}, "required": []},
+             },
+         ),
+     ),
+     (
+         "function_with_one_optional_parameter",
+         types.FunctionDeclaration(
+             name="get_weather",
+             description="Gets weather information for a given location.",
+             parameters=types.Schema(
+                 type=types.Type.OBJECT,
+                 properties={
+                     "location": types.Schema(
+                         type=types.Type.STRING,
+                         description="City and state, e.g., San Francisco, CA",
+                     )
+                 },
+             ),
+         ),
+         ChatCompletionToolParam(
+             type="function",
+             function={
+                 "name": "get_weather",
+                 "description": "Gets weather information for a given location.",
+                 "parameters": {
+                     "type": "object",
+                     "properties": {
+                         "location": {
+                             "type": "string",
+                             "description": "City and state, e.g., San Francisco, CA",
+                         }
+                     },
+                     "required": [],
+                 },
+             },
+         ),
+     ),
+     (
+         "function_with_one_required_parameter",
+         types.FunctionDeclaration(
+             name="get_stock_price",
+             description="Gets the current price for a stock ticker.",
+             parameters=types.Schema(
+                 type=types.Type.OBJECT,
+                 properties={
+                     "ticker": types.Schema(
+                         type=types.Type.STRING,
+                         description="The stock ticker, e.g., AAPL",
+                     )
+                 },
+                 required=["ticker"],
+             ),
+         ),
+         ChatCompletionToolParam(
+             type="function",
+             function={
+                 "name": "get_stock_price",
+                 "description": "Gets the current price for a stock ticker.",
+                 "parameters": {
+                     "type": "object",
+                     "properties": {
+                         "ticker": {
+                             "type": "string",
+                             "description": "The stock ticker, e.g., AAPL",
+                         }
+                     },
+                     "required": ["ticker"],
+                 },
+             },
+         ),
+     ),
+     (
+         "function_with_multiple_mixed_parameters",
+         types.FunctionDeclaration(
+             name="submit_order",
+             description="Submits a product order.",
+             parameters=types.Schema(
+                 type=types.Type.OBJECT,
+                 properties={
+                     "product_id": types.Schema(type=types.Type.STRING, description="The product ID"),
+                     "quantity": types.Schema(
+                         type=types.Type.INTEGER,
+                         description="The order quantity",
+                     ),
+                     "notes": types.Schema(
+                         type=types.Type.STRING,
+                         description="Optional order notes",
+                     ),
+                 },
+                 required=["product_id", "quantity"],
+             ),
+         ),
+         ChatCompletionToolParam(
+             type="function",
+             function={
+                 "name": "submit_order",
+                 "description": "Submits a product order.",
+                 "parameters": {
+                     "type": "object",
+                     "properties": {
+                         "product_id": {
+                             "type": "string",
+                             "description": "The product ID",
+                         },
+                         "quantity": {
+                             "type": "integer",
+                             "description": "The order quantity",
+                         },
+                         "notes": {
+                             "type": "string",
+                             "description": "Optional order notes",
+                         },
+                     },
+                     "required": ["product_id", "quantity"],
+                 },
+             },
+         ),
+     ),
+     (
+         "function_with_complex_nested_parameter",
+         types.FunctionDeclaration(
+             name="create_playlist",
+             description="Creates a playlist from a list of songs.",
+             parameters=types.Schema(
+                 type=types.Type.OBJECT,
+                 properties={
+                     "playlist_name": types.Schema(
+                         type=types.Type.STRING,
+                         description="The name for the new playlist",
+                     ),
+                     "songs": types.Schema(
+                         type=types.Type.ARRAY,
+                         description="A list of songs to add to the playlist",
+                         items=types.Schema(
+                             type=types.Type.OBJECT,
+                             properties={
+                                 "title": types.Schema(type=types.Type.STRING),
+                                 "artist": types.Schema(type=types.Type.STRING),
+                             },
+                             required=["title", "artist"],
+                         ),
+                     ),
+                 },
+                 required=["playlist_name", "songs"],
+             ),
+         ),
+         ChatCompletionToolParam(
+             type="function",
+             function={
+                 "name": "create_playlist",
+                 "description": "Creates a playlist from a list of songs.",
+                 "parameters": {
+                     "type": "object",
+                     "properties": {
+                         "playlist_name": {
+                             "type": "string",
+                             "description": "The name for the new playlist",
+                         },
+                         "songs": {
+                             "type": "array",
+                             "description": "A list of songs to add to the playlist",
+                             "items": {
+                                 "type": "object",
+                                 "properties": {
+                                     "title": {"type": "string"},
+                                     "artist": {"type": "string"},
+                                 },
+                                 "required": ["title", "artist"],
+                             },
+                         },
+                     },
+                     "required": ["playlist_name", "songs"],
+                 },
+             },
+         ),
+     ),
+ ]
+
+
+ @pytest.mark.parametrize(
+     "_, function_declaration, expected_tool_param",
+     function_declaration_test_cases,
+     ids=[case[0] for case in function_declaration_test_cases],
+ )
+ async def test_function_declaration_to_tool_param(_, function_declaration, expected_tool_param):
+     """Test _convert_tools_to_openai function."""
+     tool = types.Tool(function_declarations=[function_declaration])
+     result = _convert_tools_to_openai([tool])
+     assert len(result) == 1
+     assert result[0] == expected_tool_param
+
+
+ @pytest.mark.asyncio
+ async def test_generate_content_async(openai_llm, llm_request, generate_content_response, generate_llm_response):
+     with mock.patch.object(openai_llm, "_client") as mock_client:
+         # Create a mock coroutine that returns the generate_content_response.
+         async def mock_coro(*args, **kwargs):
+             return generate_content_response
+
+         # Assign the coroutine to the mocked method
+         mock_client.chat.completions.create.return_value = mock_coro()
+
+         responses = [resp async for resp in openai_llm.generate_content_async(llm_request, stream=False)]
+         assert len(responses) == 1
+         assert isinstance(responses[0], LlmResponse)
+         assert responses[0].content is not None
+         assert len(responses[0].content.parts) > 0
+         assert responses[0].content.parts[0].text == "Hi! How can I help you today?"
+
+
+ @pytest.mark.asyncio
+ async def test_generate_content_async_with_max_tokens(llm_request, generate_content_response, generate_llm_response):
+     openai_llm = OpenAI(model="gpt-3.5-turbo", max_tokens=4096, type="openai")
+     with mock.patch.object(openai_llm, "_client") as mock_client:
+         # Create a mock coroutine that returns the generate_content_response.
+         async def mock_coro(*args, **kwargs):
+             return generate_content_response
+
+         # Assign the coroutine to the mocked method
+         mock_client.chat.completions.create.return_value = mock_coro()
+
+         _ = [resp async for resp in openai_llm.generate_content_async(llm_request, stream=False)]
+         mock_client.chat.completions.create.assert_called_once()
+         _, kwargs = mock_client.chat.completions.create.call_args
+         assert kwargs["max_tokens"] == 4096