kagent-adk 0.7.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,564 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import json
5
+ import os
6
+ from functools import cached_property
7
+ from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, Literal, Optional
8
+
9
+ import httpx
10
+ from google.adk.models import BaseLlm
11
+ from google.adk.models.llm_response import LlmResponse
12
+ from google.genai import types
13
+ from google.genai.types import FunctionCall, FunctionResponse
14
+ from openai import AsyncAzureOpenAI, AsyncOpenAI, DefaultAsyncHttpxClient
15
+ from openai.types.chat import (
16
+ ChatCompletion,
17
+ ChatCompletionAssistantMessageParam,
18
+ ChatCompletionContentPartImageParam,
19
+ ChatCompletionContentPartTextParam,
20
+ ChatCompletionMessageParam,
21
+ ChatCompletionSystemMessageParam,
22
+ ChatCompletionToolMessageParam,
23
+ ChatCompletionToolParam,
24
+ ChatCompletionUserMessageParam,
25
+ )
26
+ from openai.types.chat.chat_completion_message_tool_call_param import (
27
+ ChatCompletionMessageToolCallParam,
28
+ )
29
+ from openai.types.chat.chat_completion_message_tool_call_param import (
30
+ Function as ToolCallFunction,
31
+ )
32
+ from openai.types.shared_params import FunctionDefinition, FunctionParameters
33
+ from pydantic import Field
34
+
35
+ from ._ssl import create_ssl_context
36
+
37
+ if TYPE_CHECKING:
38
+ from google.adk.models.llm_request import LlmRequest
39
+
40
+
41
+ def _convert_role_to_openai(role: Optional[str]) -> str:
42
+ """Convert google.genai role to OpenAI role."""
43
+ if role in ["model", "assistant"]:
44
+ return "assistant"
45
+ elif role == "system":
46
+ return "system"
47
+ else:
48
+ return "user"
49
+
50
+
51
+ def _convert_content_to_openai_messages(
52
+ contents: list[types.Content], system_instruction: Optional[str] = None
53
+ ) -> list[ChatCompletionMessageParam]:
54
+ """Convert google.genai Content list to OpenAI messages format."""
55
+ messages: list[ChatCompletionMessageParam] = []
56
+
57
+ # Add system message if provided
58
+ if system_instruction:
59
+ system_message: ChatCompletionSystemMessageParam = {"role": "system", "content": system_instruction}
60
+ messages.append(system_message)
61
+
62
+ # First pass: collect all function responses to match with tool calls
63
+ all_function_responses: dict[str, FunctionResponse] = {}
64
+ for content in contents:
65
+ for part in content.parts or []:
66
+ if part.function_response:
67
+ tool_call_id = part.function_response.id or "call_1"
68
+ all_function_responses[tool_call_id] = part.function_response
69
+
70
+ for content in contents:
71
+ role = _convert_role_to_openai(content.role)
72
+
73
+ # Separate different types of parts
74
+ text_parts: list[str] = []
75
+ function_calls: list[FunctionCall] = []
76
+ function_responses: list[FunctionResponse] = []
77
+ image_parts = []
78
+
79
+ for part in content.parts or []:
80
+ if part.text:
81
+ text_parts.append(part.text)
82
+ elif part.function_call:
83
+ function_calls.append(part.function_call)
84
+ elif part.function_response:
85
+ function_responses.append(part.function_response)
86
+ elif part.inline_data and part.inline_data.mime_type and part.inline_data.mime_type.startswith("image"):
87
+ if part.inline_data.data:
88
+ image_data = base64.b64encode(part.inline_data.data).decode()
89
+ image_part: ChatCompletionContentPartImageParam = {
90
+ "type": "image_url",
91
+ "image_url": {"url": f"data:{part.inline_data.mime_type};base64,{image_data}"},
92
+ }
93
+ image_parts.append(image_part)
94
+
95
+ # Function responses are now handled together with function calls
96
+ # This ensures proper pairing and prevents orphaned tool messages
97
+
98
+ # Handle function calls (assistant messages with tool_calls)
99
+ if function_calls:
100
+ tool_calls = []
101
+ tool_response_messages = []
102
+
103
+ for func_call in function_calls:
104
+ tool_call_function: ToolCallFunction = {
105
+ "name": func_call.name or "",
106
+ "arguments": json.dumps(func_call.args) if func_call.args else "{}",
107
+ }
108
+ tool_call_id = func_call.id or "call_1"
109
+ tool_call = ChatCompletionMessageToolCallParam(
110
+ id=tool_call_id,
111
+ type="function",
112
+ function=tool_call_function,
113
+ )
114
+ tool_calls.append(tool_call)
115
+
116
+ # Check if we have a response for this tool call
117
+ if tool_call_id in all_function_responses:
118
+ func_response = all_function_responses[tool_call_id]
119
+ content = ""
120
+ if isinstance(func_response.response, str):
121
+ content = func_response.response
122
+ elif func_response.response and "content" in func_response.response:
123
+ content_list = func_response.response["content"]
124
+ if len(content_list) > 0:
125
+ content = content_list[0]["text"]
126
+ elif func_response.response and "result" in func_response.response:
127
+ content = func_response.response["result"]
128
+
129
+ tool_message = ChatCompletionToolMessageParam(
130
+ role="tool",
131
+ tool_call_id=tool_call_id,
132
+ content=content,
133
+ )
134
+ tool_response_messages.append(tool_message)
135
+ else:
136
+ # If no response is available, create a placeholder response
137
+ # This prevents the OpenAI API error
138
+ tool_message = ChatCompletionToolMessageParam(
139
+ role="tool",
140
+ tool_call_id=tool_call_id,
141
+ content="No response available for this function call.",
142
+ )
143
+ tool_response_messages.append(tool_message)
144
+
145
+ # Create assistant message with tool calls
146
+ text_content = "\n".join(text_parts) if text_parts else None
147
+ assistant_message = ChatCompletionAssistantMessageParam(
148
+ role="assistant",
149
+ content=text_content,
150
+ tool_calls=tool_calls,
151
+ )
152
+ messages.append(assistant_message)
153
+
154
+ # Add all tool response messages immediately after the assistant message
155
+ messages.extend(tool_response_messages)
156
+
157
+ # Handle regular text/image messages (only if no function calls)
158
+ elif text_parts or image_parts:
159
+ if role == "user":
160
+ if image_parts and text_parts:
161
+ # Multi-modal content
162
+ text_part = ChatCompletionContentPartTextParam(type="text", text="\n".join(text_parts))
163
+ content_parts = [text_part] + image_parts
164
+ user_message = ChatCompletionUserMessageParam(role="user", content=content_parts)
165
+ elif image_parts:
166
+ # Image only
167
+ user_message = ChatCompletionUserMessageParam(role="user", content=image_parts)
168
+ else:
169
+ # Text only
170
+ user_message = ChatCompletionUserMessageParam(role="user", content="\n".join(text_parts))
171
+ messages.append(user_message)
172
+ elif role == "assistant":
173
+ # Assistant messages with text (no tool calls)
174
+ assistant_message = ChatCompletionAssistantMessageParam(
175
+ role="assistant",
176
+ content="\n".join(text_parts),
177
+ )
178
+ messages.append(assistant_message)
179
+
180
+ return messages
181
+
182
+
183
+ def _update_type_string(value_dict: dict[str, Any]):
184
+ """Updates 'type' field to expected JSON schema format."""
185
+ if "type" in value_dict:
186
+ value_dict["type"] = value_dict["type"].lower()
187
+
188
+ if "items" in value_dict:
189
+ # 'type' field could exist for items as well, this would be the case if
190
+ # items represent primitive types.
191
+ _update_type_string(value_dict["items"])
192
+
193
+ if "properties" in value_dict["items"]:
194
+ # There could be properties as well on the items, especially if the items
195
+ # are complex object themselves. We recursively traverse each individual
196
+ # property as well and fix the "type" value.
197
+ for _, value in value_dict["items"]["properties"].items():
198
+ _update_type_string(value)
199
+
200
+ if "properties" in value_dict:
201
+ # Handle nested properties
202
+ for _, value in value_dict["properties"].items():
203
+ _update_type_string(value)
204
+
205
+
206
+ def _convert_tools_to_openai(tools: list[types.Tool]) -> list[ChatCompletionToolParam]:
207
+ """Convert google.genai Tools to OpenAI tools format."""
208
+ openai_tools: list[ChatCompletionToolParam] = []
209
+
210
+ for tool in tools:
211
+ if tool.function_declarations:
212
+ for func_decl in tool.function_declarations:
213
+ # Build function definition
214
+ function_def = FunctionDefinition(
215
+ name=func_decl.name or "",
216
+ description=func_decl.description or "",
217
+ )
218
+
219
+ # Always include parameters field, even if empty
220
+ properties = {}
221
+ required = []
222
+
223
+ if func_decl.parameters:
224
+ if func_decl.parameters.properties:
225
+ for prop_name, prop_schema in func_decl.parameters.properties.items():
226
+ value_dict = prop_schema.model_dump(exclude_none=True)
227
+ _update_type_string(value_dict)
228
+ properties[prop_name] = value_dict
229
+
230
+ if func_decl.parameters.required:
231
+ required = func_decl.parameters.required
232
+
233
+ function_def["parameters"] = {"type": "object", "properties": properties, "required": required}
234
+
235
+ # Create the tool param
236
+ openai_tool = ChatCompletionToolParam(type="function", function=function_def)
237
+ openai_tools.append(openai_tool)
238
+
239
+ return openai_tools
240
+
241
+
242
def _convert_openai_response_to_llm_response(response: ChatCompletion) -> LlmResponse:
    """Turn a non-streaming OpenAI ``ChatCompletion`` into an ADK ``LlmResponse``.

    Only the first choice is converted:

    * its text (if any) becomes a text part;
    * each ``function``-type tool call becomes a function-call part whose id is
      the OpenAI tool-call id, with arguments parsed from JSON (invalid or empty
      arguments become ``{}``).

    Token usage is copied when the response carries it. The OpenAI finish
    reason ``"length"`` maps to ``MAX_TOKENS``, ``"content_filter"`` to
    ``SAFETY``, and anything else to ``STOP``.
    """
    first_choice = response.choices[0]
    msg = first_choice.message

    parts = [types.Part.from_text(text=msg.content)] if msg.content else []

    for call in getattr(msg, "tool_calls", None) or []:
        if call.type != "function":
            continue
        raw_args = call.function.arguments
        try:
            parsed_args = json.loads(raw_args) if raw_args else {}
        except json.JSONDecodeError:
            parsed_args = {}
        call_part = types.Part.from_function_call(name=call.function.name, args=parsed_args)
        if call_part.function_call:
            call_part.function_call.id = call.id
        parts.append(call_part)

    model_content = types.Content(role="model", parts=parts)

    usage = getattr(response, "usage", None)
    usage_metadata = (
        types.GenerateContentResponseUsageMetadata(
            prompt_token_count=usage.prompt_tokens,
            candidates_token_count=usage.completion_tokens,
            total_token_count=usage.total_tokens,
        )
        if usage
        else None
    )

    mapped_reason = {
        "length": types.FinishReason.MAX_TOKENS,
        "content_filter": types.FinishReason.SAFETY,
    }.get(first_choice.finish_reason, types.FinishReason.STOP)

    return LlmResponse(content=model_content, usage_metadata=usage_metadata, finish_reason=mapped_reason)
286
+
287
+
288
class BaseOpenAI(BaseLlm):
    """Base class for OpenAI-compatible chat-completion models.

    Converts ADK ``LlmRequest`` objects into Chat Completions requests and the
    results back into ``LlmResponse`` objects, in streaming or non-streaming
    mode. Optional sampling settings are only sent when configured, and any
    TLS setting switches the SDK to a dedicated HTTP client.
    """

    model: str
    api_key: Optional[str] = Field(default=None, exclude=True)
    base_url: Optional[str] = None
    frequency_penalty: Optional[float] = None
    default_headers: Optional[dict[str, str]] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[float] = None
    reasoning_effort: Optional[str] = None
    seed: Optional[int] = None
    temperature: Optional[float] = None
    # Request timeout in seconds; None keeps the OpenAI SDK's default timeout.
    timeout: Optional[int] = None
    top_p: Optional[float] = None

    # TLS/SSL configuration fields
    tls_disable_verify: Optional[bool] = None
    tls_ca_cert_path: Optional[str] = None
    tls_disable_system_cas: Optional[bool] = None

    @classmethod
    def supported_models(cls) -> list[str]:
        """Returns a list of supported models in regex for LlmRegistry."""
        return [r"gpt-.*", r"o1-.*"]

    def _get_tls_config(self) -> tuple[bool, Optional[str], bool]:
        """Read TLS configuration from instance fields.

        Returns:
            Tuple of (disable_verify, ca_cert_path, disable_system_cas)
        """
        # Read from instance fields only (config-based approach)
        # Environment variables are no longer supported for TLS configuration
        disable_verify = self.tls_disable_verify or False
        ca_cert_path = self.tls_ca_cert_path
        disable_system_cas = self.tls_disable_system_cas or False

        return disable_verify, ca_cert_path, disable_system_cas

    def _create_http_client(self) -> Optional[httpx.AsyncClient]:
        """Create HTTP client with custom SSL context using OpenAI SDK defaults.

        Uses DefaultAsyncHttpxClient to preserve OpenAI's default settings for
        timeout, connection pooling, and redirect behavior while applying custom
        SSL configuration.

        Returns:
            DefaultAsyncHttpxClient with SSL configuration, or None if no TLS config
        """
        disable_verify, ca_cert_path, disable_system_cas = self._get_tls_config()

        # Only create custom http client if TLS configuration is present
        if disable_verify or ca_cert_path or disable_system_cas:
            ssl_context = create_ssl_context(
                disable_verify=disable_verify,
                ca_cert_path=ca_cert_path,
                disable_system_cas=disable_system_cas,
            )

            # ssl_context is either False (verification disabled) or SSLContext
            # Use DefaultAsyncHttpxClient to preserve OpenAI's defaults
            return DefaultAsyncHttpxClient(verify=ssl_context)

        # No TLS configuration, return None to use OpenAI SDK default
        return None

    @cached_property
    def _client(self) -> AsyncOpenAI:
        """Get the OpenAI client (built once per instance) with optional custom SSL configuration."""
        client_kwargs: dict[str, Any] = {
            "api_key": self.api_key,
            "base_url": self.base_url or None,
            "default_headers": self.default_headers,
            "http_client": self._create_http_client(),
        }
        # The SDK treats an explicit timeout=None as "no timeout at all", so
        # only forward a value that was actually configured.
        if self.timeout is not None:
            client_kwargs["timeout"] = self.timeout
        return AsyncOpenAI(**client_kwargs)

    @staticmethod
    def _extract_system_instruction(config: Any) -> Optional[str]:
        """Return the request's system instruction as plain text, or None.

        Accepts the instruction as a string, a Content-like object with
        ``parts``, a single Part-like object with ``text``, or a list of strings
        and/or Part-like objects. Text fragments are joined with newlines;
        fragments without text are skipped.
        """
        instruction = getattr(config, "system_instruction", None) if config else None
        if not instruction:
            return None
        if isinstance(instruction, str):
            return instruction

        if hasattr(instruction, "parts"):
            fragments = instruction.parts or []
        elif isinstance(instruction, (list, tuple)):
            fragments = instruction
        else:
            fragments = [instruction]

        texts: list[str] = []
        for fragment in fragments:
            text = fragment if isinstance(fragment, str) else getattr(fragment, "text", None)
            if text:
                texts.append(text)
        return "\n".join(texts)

    def _build_request_kwargs(self, llm_request: LlmRequest) -> dict[str, Any]:
        """Assemble the keyword arguments for ``chat.completions.create``."""
        system_instruction = self._extract_system_instruction(llm_request.config)
        kwargs: dict[str, Any] = {
            "model": llm_request.model or self.model,
            "messages": _convert_content_to_openai_messages(llm_request.contents, system_instruction),
        }

        # Optional sampling settings are only sent when explicitly configured.
        for name in (
            "frequency_penalty",
            "n",
            "presence_penalty",
            "reasoning_effort",
            "seed",
            "temperature",
            "top_p",
        ):
            value = getattr(self, name)
            if value is not None:
                kwargs[name] = value
        # max_tokens is only sent when truthy (None and 0 both mean "not set").
        if self.max_tokens:
            kwargs["max_tokens"] = self.max_tokens

        config = llm_request.config
        if config and config.tools:
            # Only google.genai Tool objects (those with function declarations) are convertible.
            genai_tools = [tool for tool in config.tools if hasattr(tool, "function_declarations")]
            openai_tools = _convert_tools_to_openai(genai_tools) if genai_tools else []
            if openai_tools:
                kwargs["tools"] = openai_tools
                kwargs["tool_choice"] = "auto"

        return kwargs

    @staticmethod
    def _accumulate_tool_call_deltas(tool_calls_acc: dict[int, dict[str, str]], tool_call_deltas: Any) -> None:
        """Merge streamed tool-call fragments into ``tool_calls_acc`` in place.

        Fragments are grouped by their ``index``; the latest non-empty id and
        name win, and argument fragments are concatenated in arrival order.
        """
        for fragment in tool_call_deltas or []:
            entry = tool_calls_acc.setdefault(fragment.index, {"id": "", "name": "", "arguments": ""})
            if fragment.id:
                entry["id"] = fragment.id
            function = fragment.function
            if function:
                if function.name:
                    entry["name"] = function.name
                if function.arguments:
                    entry["arguments"] += function.arguments

    @staticmethod
    def _tool_call_parts(tool_calls_acc: dict[int, dict[str, str]]) -> list[types.Part]:
        """Build function-call parts from accumulated tool calls, ordered by index.

        Arguments that are empty or not valid JSON become ``{}``.
        """
        parts: list[types.Part] = []
        for idx in sorted(tool_calls_acc):
            accumulated = tool_calls_acc[idx]
            try:
                args = json.loads(accumulated["arguments"]) if accumulated["arguments"] else {}
            except json.JSONDecodeError:
                args = {}
            part = types.Part.from_function_call(name=accumulated["name"], args=args)
            if part.function_call:
                part.function_call.id = accumulated["id"]
            parts.append(part)
        return parts

    @staticmethod
    def _map_finish_reason(finish_reason: Optional[str]) -> types.FinishReason:
        """Map an OpenAI finish reason onto google.genai's ``FinishReason``.

        ``"length"`` -> ``MAX_TOKENS``, ``"content_filter"`` -> ``SAFETY``;
        everything else (``"stop"``, ``"tool_calls"``, ``None``, ...) -> ``STOP``.
        """
        if finish_reason == "length":
            return types.FinishReason.MAX_TOKENS
        if finish_reason == "content_filter":
            return types.FinishReason.SAFETY
        return types.FinishReason.STOP

    @staticmethod
    def _usage_metadata(usage: Any) -> Optional[types.GenerateContentResponseUsageMetadata]:
        """Convert an OpenAI ``usage`` object into google.genai usage metadata (None if absent)."""
        if not usage:
            return None
        return types.GenerateContentResponseUsageMetadata(
            prompt_token_count=usage.prompt_tokens,
            candidates_token_count=usage.completion_tokens,
            total_token_count=usage.total_tokens,
        )

    async def generate_content_async(
        self, llm_request: LlmRequest, stream: bool = False
    ) -> AsyncGenerator[LlmResponse, None]:
        """Generate content using the OpenAI Chat Completions API.

        Args:
            llm_request: The ADK request (history, config, tools, optional model override).
            stream: When True, yield one partial response per streamed text
                delta, then one final aggregated response containing all text
                and tool calls. Otherwise yield a single complete response.

        Yields:
            ``LlmResponse`` objects. Exceptions raised while creating the client
            or calling the API are reported as a single response with
            ``error_code="API_ERROR"`` instead of propagating.
        """
        # Built outside the try block: request-conversion errors propagate to the
        # caller instead of being turned into API_ERROR responses.
        kwargs = self._build_request_kwargs(llm_request)

        try:
            if not stream:
                response = await self._client.chat.completions.create(stream=False, **kwargs)
                yield _convert_openai_response_to_llm_response(response)
                return

            text_chunks: list[str] = []
            finish_reason: Optional[str] = None
            usage_metadata = None
            # Tool calls arrive in fragments, keyed by their index within the choice.
            tool_calls_acc: dict[int, dict[str, str]] = {}

            # The context manager closes the HTTP stream even if the consumer
            # stops iterating early or an error occurs mid-stream.
            async with await self._client.chat.completions.create(stream=True, **kwargs) as response_stream:
                async for chunk in response_stream:
                    choice = chunk.choices[0] if chunk.choices else None
                    delta = choice.delta if choice is not None else None

                    if delta:
                        if delta.content:
                            text_chunks.append(delta.content)
                            yield LlmResponse(
                                content=types.Content(role="model", parts=[types.Part.from_text(text=delta.content)]),
                                partial=True,
                                turn_complete=choice.finish_reason is not None,
                            )
                        self._accumulate_tool_call_deltas(tool_calls_acc, getattr(delta, "tool_calls", None))

                    if choice is not None and choice.finish_reason:
                        finish_reason = choice.finish_reason

                    # NOTE(review): OpenAI only streams usage when the request sets
                    # stream_options={"include_usage": True}, which is not sent here,
                    # so usage_metadata may stay None for some servers.
                    chunk_usage = self._usage_metadata(getattr(chunk, "usage", None))
                    if chunk_usage is not None:
                        usage_metadata = chunk_usage

            final_parts: list[types.Part] = []
            if text_chunks:
                final_parts.append(types.Part.from_text(text="".join(text_chunks)))
            final_parts.extend(self._tool_call_parts(tool_calls_acc))

            # Always yield a final, non-partial response to signal completion and carry metadata.
            yield LlmResponse(
                content=types.Content(role="model", parts=final_parts),
                partial=False,
                finish_reason=self._map_finish_reason(finish_reason),
                usage_metadata=usage_metadata,
                turn_complete=True,
            )

        except Exception as e:
            yield LlmResponse(error_code="API_ERROR", error_message=str(e))
523
+
524
+
525
class OpenAI(BaseOpenAI):
    """Model that talks to the OpenAI API, or to a server reachable via ``base_url``.

    All request/response handling is inherited from ``BaseOpenAI``; this
    subclass only fixes the ``type`` tag to ``"openai"``.
    """

    # Fixed literal tag; presumably used to select this class from serialized config.
    type: Literal["openai"]
529
+
530
+
531
class AzureOpenAI(BaseOpenAI):
    """Azure OpenAI model implementation.

    Unset connection settings fall back to environment variables:
    ``api_version`` -> ``OPENAI_API_VERSION`` (default ``"2024-02-15-preview"``),
    ``azure_endpoint`` -> ``AZURE_OPENAI_ENDPOINT``,
    ``api_key`` -> ``AZURE_OPENAI_API_KEY``.
    """

    type: Literal["azure_openai"]
    api_version: Optional[str] = None
    azure_endpoint: Optional[str] = None
    # Optional Azure deployment name, forwarded to the SDK client as-is.
    azure_deployment: Optional[str] = None

    @cached_property
    def _client(self) -> AsyncAzureOpenAI:
        """Get the Azure OpenAI client (built once per instance) with optional custom SSL configuration.

        Raises:
            ValueError: If no endpoint or no API key is configured, either on
                the instance or via the environment.
        """
        api_version = self.api_version or os.environ.get("OPENAI_API_VERSION", "2024-02-15-preview")
        azure_endpoint = self.azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
        api_key = self.api_key or os.environ.get("AZURE_OPENAI_API_KEY")

        if not azure_endpoint:
            raise ValueError(
                "Azure endpoint must be provided either via azure_endpoint parameter or AZURE_OPENAI_ENDPOINT environment variable"
            )

        if not api_key:
            raise ValueError(
                "API key must be provided either via api_key parameter or AZURE_OPENAI_API_KEY environment variable"
            )

        client_kwargs: dict[str, Any] = {
            "api_key": api_key,
            "api_version": api_version,
            "azure_endpoint": azure_endpoint,
            # Previously declared but never forwarded; None is the SDK default.
            "azure_deployment": self.azure_deployment,
            "default_headers": self.default_headers,
            "http_client": self._create_http_client(),
        }
        # Honour the configured timeout like the base OpenAI client does. Only
        # forward it when set, since the SDK treats timeout=None as "no timeout".
        if self.timeout is not None:
            client_kwargs["timeout"] = self.timeout

        return AsyncAzureOpenAI(**client_kwargs)