aiecs 1.7.6__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of aiecs might be problematic.
- aiecs/__init__.py +1 -1
- aiecs/application/knowledge_graph/extractors/llm_entity_extractor.py +5 -1
- aiecs/application/knowledge_graph/retrieval/query_intent_classifier.py +7 -5
- aiecs/config/config.py +3 -0
- aiecs/config/tool_config.py +55 -19
- aiecs/domain/agent/base_agent.py +79 -0
- aiecs/domain/agent/hybrid_agent.py +552 -175
- aiecs/domain/agent/knowledge_aware_agent.py +3 -2
- aiecs/domain/agent/llm_agent.py +2 -0
- aiecs/domain/agent/models.py +10 -0
- aiecs/domain/agent/tools/schema_generator.py +17 -4
- aiecs/llm/callbacks/custom_callbacks.py +9 -4
- aiecs/llm/client_factory.py +20 -7
- aiecs/llm/clients/base_client.py +50 -5
- aiecs/llm/clients/google_function_calling_mixin.py +46 -88
- aiecs/llm/clients/googleai_client.py +183 -9
- aiecs/llm/clients/openai_client.py +12 -0
- aiecs/llm/clients/openai_compatible_mixin.py +42 -2
- aiecs/llm/clients/openrouter_client.py +272 -0
- aiecs/llm/clients/vertex_client.py +385 -22
- aiecs/llm/clients/xai_client.py +41 -3
- aiecs/llm/protocols.py +19 -1
- aiecs/llm/utils/image_utils.py +179 -0
- aiecs/main.py +2 -2
- aiecs/tools/docs/document_creator_tool.py +143 -2
- aiecs/tools/docs/document_parser_tool.py +9 -4
- aiecs/tools/docs/document_writer_tool.py +179 -0
- aiecs/tools/task_tools/image_tool.py +49 -14
- aiecs/tools/task_tools/scraper_tool.py +39 -2
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/METADATA +4 -2
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/RECORD +35 -33
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/WHEEL +0 -0
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/entry_points.txt +0 -0
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/licenses/LICENSE +0 -0
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/top_level.txt +0 -0
aiecs/domain/agent/knowledge_aware_agent.py
CHANGED
@@ -95,7 +95,7 @@ class KnowledgeAwareAgent(HybridAgent):
  graph_store: Optional[GraphStore] = None,
  description: Optional[str] = None,
  version: str = "1.0.0",
- max_iterations: int =
+ max_iterations: Optional[int] = None,
  enable_graph_reasoning: bool = True,
  config_manager: Optional["ConfigManagerProtocol"] = None,
  checkpointer: Optional["CheckpointerProtocol"] = None,
@@ -118,7 +118,7 @@ class KnowledgeAwareAgent(HybridAgent):
  graph_store: Optional knowledge graph store
  description: Optional description
  version: Agent version
- max_iterations: Maximum ReAct iterations
+ max_iterations: Maximum ReAct iterations (if None, uses config.max_iterations)
  enable_graph_reasoning: Whether to enable graph reasoning capabilities
  config_manager: Optional configuration manager for dynamic config
  checkpointer: Optional checkpointer for state persistence
@@ -745,6 +745,7 @@ Use graph reasoning proactively when questions involve:
  model=self._config.llm_model,
  temperature=self._config.temperature,
  max_tokens=self._config.max_tokens,
+ context=context,
  )

  thought = response.content
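Note on the max_iterations change above: the parameter is now Optional and is resolved against the agent configuration at runtime. A minimal sketch of that fallback, assuming only what the updated docstring states (config.max_iterations); the function name here is illustrative, not part of aiecs:

    def resolve_max_iterations(max_iterations, config):
        # None means "use the configured value", per the new docstring
        return max_iterations if max_iterations is not None else config.max_iterations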
aiecs/domain/agent/llm_agent.py
CHANGED
@@ -376,6 +376,7 @@ class LLMAgent(BaseAIAgent):
  model=self._config.llm_model,
  temperature=self._config.temperature,
  max_tokens=self._config.max_tokens,
+ context=context,
  )

  # Extract result
@@ -513,6 +514,7 @@ class LLMAgent(BaseAIAgent):
  model=self._config.llm_model,
  temperature=self._config.temperature,
  max_tokens=self._config.max_tokens,
+ context=context,
  ):
  output_tokens.append(token)
  yield {
aiecs/domain/agent/models.py
CHANGED
@@ -381,6 +381,16 @@ class AgentMetrics(BaseModel):
  p95_operation_time: Optional[float] = Field(None, ge=0, description="95th percentile operation time in seconds")
  p99_operation_time: Optional[float] = Field(None, ge=0, description="99th percentile operation time in seconds")

+ # Prompt cache metrics (for LLM provider-level caching observability)
+ total_llm_requests: int = Field(default=0, ge=0, description="Total number of LLM requests made")
+ cache_hits: int = Field(default=0, ge=0, description="Number of LLM requests with cache hits")
+ cache_misses: int = Field(default=0, ge=0, description="Number of LLM requests without cache hits (cache creation)")
+ cache_hit_rate: float = Field(default=0.0, ge=0.0, le=1.0, description="Prompt cache hit rate (0-1)")
+ total_cache_read_tokens: int = Field(default=0, ge=0, description="Total tokens read from prompt cache")
+ total_cache_creation_tokens: int = Field(default=0, ge=0, description="Total tokens used to create cache entries")
+ estimated_cache_savings_tokens: int = Field(default=0, ge=0, description="Estimated tokens saved from cache (cache_read_tokens * 0.9)")
+ estimated_cache_savings_cost: float = Field(default=0.0, ge=0, description="Estimated cost saved from cache in USD")
+
  # Timestamps
  last_reset_at: Optional[datetime] = Field(None, description="When metrics were last reset")
  updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last metrics update")
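The new AgentMetrics fields above are plain counters; this diff does not show how they are updated. A hypothetical helper (not part of aiecs) that keeps the derived fields consistent with the documented semantics — hit rate in 0-1, savings estimated as cache_read_tokens * 0.9 — might look like:

    def record_llm_request(metrics, cache_read_tokens: int, cache_creation_tokens: int) -> None:
        # Count the request and classify it as a cache hit or miss
        metrics.total_llm_requests += 1
        if cache_read_tokens > 0:
            metrics.cache_hits += 1
        else:
            metrics.cache_misses += 1
        metrics.total_cache_read_tokens += cache_read_tokens
        metrics.total_cache_creation_tokens += cache_creation_tokens
        # Derived fields, matching the field descriptions above
        metrics.cache_hit_rate = metrics.cache_hits / metrics.total_llm_requests
        metrics.estimated_cache_savings_tokens = int(metrics.total_cache_read_tokens * 0.9)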
aiecs/domain/agent/tools/schema_generator.py
CHANGED
@@ -207,6 +207,12 @@ class ToolSchemaGenerator:
  if not hasattr(schema_class, "model_fields"):
      return properties, required

+ # Import PydanticUndefined for v2 compatibility
+ try:
+     from pydantic_core import PydanticUndefined
+ except ImportError:
+     PydanticUndefined = type(None)  # Fallback for Pydantic v1
+
  for field_name, field_info in schema_class.model_fields.items():
      # Build property schema
      prop_schema: Dict[str, Any] = {}
@@ -219,11 +225,18 @@ class ToolSchemaGenerator:
  if hasattr(field_info, "description") and field_info.description:
      prop_schema["description"] = field_info.description

- #
- if field_info
-
+ # Check if required using Pydantic v2 API (preferred)
+ if hasattr(field_info, "is_required") and callable(field_info.is_required):
+     if field_info.is_required():
+         required.append(field_name)
+     elif field_info.default is not None and field_info.default is not PydanticUndefined:
+         prop_schema["default"] = field_info.default
  else:
-
+     # Fallback for Pydantic v1
+     if field_info.default is None or field_info.default == inspect.Parameter.empty:
+         required.append(field_name)
+     else:
+         prop_schema["default"] = field_info.default

  properties[field_name] = prop_schema
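The rewritten branch above relies on Pydantic v2 behaviour: FieldInfo.is_required() decides whether a field goes into required, and PydanticUndefined (rather than None) marks "no default". A small standalone sketch of that behaviour, independent of aiecs:

    from typing import Optional
    from pydantic import BaseModel, Field
    from pydantic_core import PydanticUndefined

    class Params(BaseModel):
        query: str                                    # no default -> is_required() is True
        limit: int = 10                               # real default -> would be copied into the schema
        cursor: Optional[str] = Field(default=None)   # None default -> skipped by the branch above

    for name, info in Params.model_fields.items():
        if info.is_required():
            print(name, "-> required")
        elif info.default is not None and info.default is not PydanticUndefined:
            print(name, "-> default:", info.default)
        else:
            print(name, "-> optional, no schema default")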
aiecs/llm/callbacks/custom_callbacks.py
CHANGED
@@ -33,7 +33,9 @@ class RedisTokenCallbackHandler(CustomAsyncCallbackHandler):
  self.start_time = time.time()
  self.messages = messages

-
+ # Defensive check for None messages
+ message_count = len(messages) if messages is not None else 0
+ logger.info(f"[Callback] LLM call started for user '{self.user_id}' with {message_count} messages")

  async def on_llm_end(self, response: dict, **kwargs: Any) -> None:
      """Triggered when LLM call ends successfully"""
@@ -93,8 +95,8 @@ class DetailedRedisTokenCallbackHandler(CustomAsyncCallbackHandler):
  self.start_time = time.time()
  self.messages = messages

- # Estimate input token count
- self.prompt_tokens = self._estimate_prompt_tokens(messages)
+ # Estimate input token count with None check
+ self.prompt_tokens = self._estimate_prompt_tokens(messages) if messages else 0

  logger.info(f"[DetailedCallback] LLM call started for user '{self.user_id}' with estimated {self.prompt_tokens} prompt tokens")

@@ -144,7 +146,10 @@ class DetailedRedisTokenCallbackHandler(CustomAsyncCallbackHandler):

  def _estimate_prompt_tokens(self, messages: List[dict]) -> int:
      """Estimate token count for input messages"""
-
+     if not messages:
+         return 0
+     # Use `or ""` to handle both missing key AND None value
+     total_chars = sum(len(msg.get("content") or "") for msg in messages)
      # Rough estimation: 4 characters ≈ 1 token
      return total_chars // 4

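The estimator above is a character-count heuristic; msg.get("content") or "" covers both a missing content key and an explicit None value. A quick standalone check of the arithmetic (roughly 4 characters per token):

    messages = [
        {"role": "system", "content": "You are helpful."},  # 16 characters
        {"role": "user", "content": None},                  # None content counts as 0
        {"role": "assistant"},                               # missing key counts as 0
    ]
    total_chars = sum(len(msg.get("content") or "") for msg in messages)
    print(total_chars, total_chars // 4)  # 16 characters -> 4 estimated tokens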
aiecs/llm/client_factory.py
CHANGED
@@ -7,6 +7,8 @@ from .clients.openai_client import OpenAIClient
  from .clients.vertex_client import VertexAIClient
  from .clients.googleai_client import GoogleAIClient
  from .clients.xai_client import XAIClient
+ from .clients.openrouter_client import OpenRouterClient
+ from .clients.openai_compatible_mixin import StreamChunk
  from .callbacks.custom_callbacks import CustomAsyncCallbackHandler

  if TYPE_CHECKING:
@@ -20,6 +22,7 @@ class AIProvider(str, Enum):
  VERTEX = "Vertex"
  GOOGLEAI = "GoogleAI"
  XAI = "xAI"
+ OPENROUTER = "OpenRouter"


  class LLMClientFactory:
@@ -132,6 +135,8 @@ class LLMClientFactory:
      return GoogleAIClient()
  elif provider == AIProvider.XAI:
      return XAIClient()
+ elif provider == AIProvider.OPENROUTER:
+     return OpenRouterClient()
  else:
      raise ValueError(f"Unsupported provider: {provider}")

@@ -261,14 +266,16 @@ class LLMClientManager:
  final_provider = context_provider or provider or AIProvider.OPENAI
  final_model = context_model or model

- # Convert string prompt to messages format
- if
+ # Convert string prompt to messages format and handle None
+ if messages is None:
+     messages = []
+ elif isinstance(messages, str):
      messages = [LLMMessage(role="user", content=messages)]

  # Execute on_llm_start callbacks
  if callbacks:
      # Convert LLMMessage objects to dictionaries for callbacks
-     messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages]
+     messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages] if messages else []
      for callback in callbacks:
          try:
              await callback.on_llm_start(
@@ -371,14 +378,16 @@ class LLMClientManager:
  final_provider = context_provider or provider or AIProvider.OPENAI
  final_model = context_model or model

- # Convert string prompt to messages format
- if
+ # Convert string prompt to messages format and handle None
+ if messages is None:
+     messages = []
+ elif isinstance(messages, str):
      messages = [LLMMessage(role="user", content=messages)]

  # Execute on_llm_start callbacks
  if callbacks:
      # Convert LLMMessage objects to dictionaries for callbacks
-     messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages]
+     messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages] if messages else []
      for callback in callbacks:
          try:
              await callback.on_llm_start(
@@ -407,7 +416,11 @@ class LLMClientManager:
      max_tokens=max_tokens,
      **kwargs,
  ):
-
+     # Handle StreamChunk objects (when return_chunks=True or function calling)
+     if hasattr(chunk, 'content') and chunk.content:
+         collected_content += chunk.content
+     elif isinstance(chunk, str):
+         collected_content += chunk
      yield chunk

  # Create a response object for callbacks (streaming doesn't return LLMResponse directly)
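With the new enum member and factory branch, OpenRouter is selectable like the other providers. A small sketch of what the diff itself guarantees; the method that wraps the elif chain is not visible in this hunk, so it is deliberately left unnamed:

    from aiecs.llm.client_factory import AIProvider

    provider = AIProvider("OpenRouter")   # string value introduced in this release
    assert provider is AIProvider.OPENROUTER
    # The factory branch shown above maps this member to OpenRouterClient().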
aiecs/llm/clients/base_client.py
CHANGED
@@ -1,6 +1,6 @@
  from abc import ABC, abstractmethod
- from typing import Dict, Any, Optional, List, AsyncGenerator
- from dataclasses import dataclass
+ from typing import Dict, Any, Optional, List, AsyncGenerator, Union
+ from dataclasses import dataclass, field
  import logging

  logger = logging.getLogger(__name__)
@@ -35,6 +35,7 @@ class LLMMessage:
  Attributes:
      role: Message role - "system", "user", "assistant", or "tool"
      content: Text content of the message (None when using tool calls)
+     images: List of image sources (URLs, base64 data URIs, or file paths) for vision support
      tool_calls: Tool call information for assistant messages
      tool_call_id: Tool call ID for tool response messages
      cache_control: Cache control marker for prompt caching support
@@ -42,6 +43,7 @@ class LLMMessage:

  role: str  # "system", "user", "assistant", "tool"
  content: Optional[str] = None  # None when using tool calls
+ images: List[Union[str, Dict[str, Any]]] = field(default_factory=list)  # Image sources for vision support
  tool_calls: Optional[List[Dict[str, Any]]] = None  # For assistant messages with tool calls
  tool_call_id: Optional[str] = None  # For tool messages
  cache_control: Optional[CacheControl] = None  # Cache control for prompt caching
@@ -139,7 +141,11 @@ class SafetyBlockError(LLMClientError):
  if self.block_type:
      msg += f" (Block type: {self.block_type})"
  if self.safety_ratings:
-
+     # Safely extract categories, handling potential non-dict elements
+     categories = []
+     for r in self.safety_ratings:
+         if isinstance(r, dict) and r.get("blocked"):
+             categories.append(r.get("category", "UNKNOWN"))
      if categories:
          msg += f" (Categories: {', '.join(categories)})"
  return msg
@@ -159,9 +165,28 @@ class BaseLLMClient(ABC):
  model: Optional[str] = None,
  temperature: float = 0.7,
  max_tokens: Optional[int] = None,
+ context: Optional[Dict[str, Any]] = None,
  **kwargs,
  ) -> LLMResponse:
- """
+ """
+ Generate text using the provider's API.
+
+ Args:
+     messages: List of conversation messages
+     model: Model name (optional, uses default if not provided)
+     temperature: Sampling temperature (0.0 to 1.0)
+     max_tokens: Maximum tokens to generate
+     context: Optional context dictionary containing metadata such as:
+         - user_id: User identifier for tracking/billing
+         - tenant_id: Tenant identifier for multi-tenant setups
+         - request_id: Request identifier for tracing
+         - session_id: Session identifier
+         - Any other custom metadata for observability or middleware
+     **kwargs: Additional provider-specific parameters
+
+ Returns:
+     LLMResponse with generated text and metadata
+ """

  @abstractmethod
  async def stream_text(
@@ -170,9 +195,28 @@ class BaseLLMClient(ABC):
  model: Optional[str] = None,
  temperature: float = 0.7,
  max_tokens: Optional[int] = None,
+ context: Optional[Dict[str, Any]] = None,
  **kwargs,
  ) -> AsyncGenerator[str, None]:
- """
+ """
+ Stream text generation using the provider's API.
+
+ Args:
+     messages: List of conversation messages
+     model: Model name (optional, uses default if not provided)
+     temperature: Sampling temperature (0.0 to 1.0)
+     max_tokens: Maximum tokens to generate
+     context: Optional context dictionary containing metadata such as:
+         - user_id: User identifier for tracking/billing
+         - tenant_id: Tenant identifier for multi-tenant setups
+         - request_id: Request identifier for tracing
+         - session_id: Session identifier
+         - Any other custom metadata for observability or middleware
+     **kwargs: Additional provider-specific parameters
+
+ Yields:
+     Text tokens as they are generated
+ """

  @abstractmethod
  async def close(self):
@@ -217,6 +261,7 @@ class BaseLLMClient(ABC):
  LLMMessage(
      role=msg.role,
      content=msg.content,
+     images=msg.images,
      tool_calls=msg.tool_calls,
      tool_call_id=msg.tool_call_id,
      cache_control=CacheControl(type="ephemeral"),
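Taken together, the changes to base_client.py let callers attach images to a message and thread request metadata through generate_text. A hedged usage sketch built only from the dataclass fields and signature shown above; the concrete client instance and URL are placeholders:

    from aiecs.llm.clients.base_client import LLMMessage

    msg = LLMMessage(
        role="user",
        content="Describe this screenshot.",
        images=["https://example.com/screenshot.png"],  # URL, data URI, or file path per the new field
    )
    context = {"user_id": "u-123", "request_id": "req-42"}  # keys listed in the new docstring
    # response = await client.generate_text([msg], context=context)  # client: any BaseLLMClient subclass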
aiecs/llm/clients/google_function_calling_mixin.py
CHANGED
@@ -5,6 +5,7 @@ Provides shared implementation for Google providers (Vertex AI, Google AI)
  that use FunctionDeclaration format for Function Calling.
  """

+ import json
  import logging
  from typing import Dict, Any, Optional, List, Union, AsyncGenerator
  from dataclasses import dataclass
@@ -12,8 +13,6 @@ from vertexai.generative_models import (
  FunctionDeclaration,
  Tool,
  )
- from google.genai.types import Schema, Type
-
  from .base_client import LLMMessage, LLMResponse

  logger = logging.getLogger(__name__)
@@ -32,13 +31,46 @@ except ImportError:
  tool_calls: Optional[List[Dict[str, Any]]] = None


+ def _serialize_function_args(args) -> str:
+     """
+     Safely serialize function call arguments to JSON string.
+
+     Handles MapComposite/protobuf objects from Vertex AI by converting
+     them to regular dicts before JSON serialization.
+
+     Args:
+         args: Function call arguments (may be MapComposite, dict, or other)
+
+     Returns:
+         JSON string representation of the arguments
+     """
+     if args is None:
+         return "{}"
+
+     # Handle MapComposite/protobuf objects (they have items() method)
+     if hasattr(args, 'items'):
+         # Convert to regular dict
+         args_dict = dict(args)
+     elif isinstance(args, dict):
+         args_dict = args
+     else:
+         # Try to convert to dict if possible
+         try:
+             args_dict = dict(args)
+         except (TypeError, ValueError):
+             # Last resort: use str() but this should rarely happen
+             return str(args)
+
+     return json.dumps(args_dict, ensure_ascii=False)
+
+
  class GoogleFunctionCallingMixin:
      """
      Mixin class providing Google Function Calling implementation.
-
+
      This mixin can be used by Google providers (Vertex AI, Google AI)
      that use FunctionDeclaration format for Function Calling.
-
+
      Usage:
          class VertexAIClient(BaseLLMClient, GoogleFunctionCallingMixin):
              async def generate_text(self, messages, tools=None, ...):
@@ -71,15 +103,13 @@ class GoogleFunctionCallingMixin:
  if not func_name:
      logger.warning(f"Skipping tool without name: {tool}")
      continue
-
- #
-
-
- # Create FunctionDeclaration
+
+ # Create FunctionDeclaration with raw dict parameters
+ # Let Vertex SDK handle the schema conversion internally
  function_declaration = FunctionDeclaration(
      name=func_name,
      description=func_description,
-     parameters=
+     parameters=func_parameters,
  )

  function_declarations.append(function_declaration)
@@ -91,78 +121,6 @@ class GoogleFunctionCallingMixin:
  return [Tool(function_declarations=function_declarations)]
  return []

- def _convert_json_schema_to_google_schema(
-     self, json_schema: Dict[str, Any]
- ) -> Schema:
-     """
-     Convert JSON Schema to Google Schema format.
-
-     Args:
-         json_schema: JSON Schema dictionary
-
-     Returns:
-         Google Schema object
-     """
-     schema_type = json_schema.get("type", "object")
-     properties = json_schema.get("properties", {})
-     required = json_schema.get("required", [])
-
-     # Convert type
-     google_type = self._convert_json_type_to_google_type(schema_type)
-
-     # Convert properties (only for object types)
-     google_properties = None
-     if schema_type == "object" and properties:
-         google_properties = {}
-         for prop_name, prop_schema in properties.items():
-             google_properties[prop_name] = self._convert_json_schema_to_google_schema(
-                 prop_schema
-             )
-
-     # Handle array items
-     items = None
-     if schema_type == "array" and "items" in json_schema:
-         items = self._convert_json_schema_to_google_schema(json_schema["items"])
-
-     # Create Schema
-     schema_kwargs = {
-         "type": google_type,
-     }
-
-     if google_properties is not None:
-         schema_kwargs["properties"] = google_properties
-
-     if required:
-         schema_kwargs["required"] = required
-
-     if items is not None:
-         schema_kwargs["items"] = items
-
-     schema = Schema(**schema_kwargs)
-
-     return schema
-
- def _convert_json_type_to_google_type(self, json_type: str) -> Type:
-     """
-     Convert JSON Schema type to Google Type enum.
-
-     Args:
-         json_type: JSON Schema type string
-
-     Returns:
-         Google Type enum value
-     """
-     type_mapping = {
-         "string": Type.STRING,
-         "number": Type.NUMBER,
-         "integer": Type.NUMBER,  # Google uses NUMBER for both
-         "boolean": Type.BOOLEAN,
-         "array": Type.ARRAY,
-         "object": Type.OBJECT,
-     }
-
-     return type_mapping.get(json_type.lower(), Type.OBJECT)
-
  def _extract_function_calls_from_google_response(
      self, response: Any
  ) -> Optional[List[Dict[str, Any]]]:
@@ -191,10 +149,10 @@ class GoogleFunctionCallingMixin:
  "type": "function",
  "function": {
      "name": func_call.name,
-     "arguments":
+     "arguments": _serialize_function_args(func_call.args) if hasattr(func_call, "args") else "{}",
  },
  })
-
+
  # Check for content.parts with function_call (newer API)
  elif hasattr(candidate, "content") and hasattr(candidate.content, "parts"):
      for part in candidate.content.parts:
@@ -205,7 +163,7 @@ class GoogleFunctionCallingMixin:
  "type": "function",
  "function": {
      "name": func_call.name,
-     "arguments":
+     "arguments": _serialize_function_args(func_call.args) if hasattr(func_call, "args") else "{}",
  },
  })

@@ -300,10 +258,10 @@ class GoogleFunctionCallingMixin:
  "type": "function",
  "function": {
      "name": func_call.name,
-     "arguments":
+     "arguments": _serialize_function_args(func_call.args) if hasattr(func_call, "args") else "{}",
  },
  })
-
+
  # Check for function_call attribute directly on candidate
  elif hasattr(candidate, "function_call") and candidate.function_call:
      func_call = candidate.function_call
@@ -312,7 +270,7 @@ class GoogleFunctionCallingMixin:
  "type": "function",
  "function": {
      "name": func_call.name,
-     "arguments":
+     "arguments": _serialize_function_args(func_call.args) if hasattr(func_call, "args") else "{}",
  },
  })
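The point of _serialize_function_args is that Vertex AI's MapComposite argument objects behave like mappings but are not directly JSON-serializable. A small check of the helper with a plain dict and a mapping-like stand-in; the stand-in class below is purely illustrative, not a real protobuf type, and it assumes _serialize_function_args from the hunk above is in scope:

    class MapCompositeLike:
        """Illustrative stand-in: mapping-like object that is not a dict."""
        def __init__(self, data):
            self._data = data
        def items(self):
            return self._data.items()
        def keys(self):              # dict() uses keys() and __getitem__ for mappings
            return self._data.keys()
        def __getitem__(self, key):
            return self._data[key]

    print(_serialize_function_args({"city": "Paris", "days": 3}))         # {"city": "Paris", "days": 3}
    print(_serialize_function_args(MapCompositeLike({"city": "Paris"})))  # {"city": "Paris"}
    print(_serialize_function_args(None))                                 # {}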