aiecs 1.7.6__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Potentially problematic release: this version of aiecs has been flagged as potentially problematic; see the registry advisory for details.
- aiecs/__init__.py +1 -1
- aiecs/application/knowledge_graph/extractors/llm_entity_extractor.py +5 -1
- aiecs/application/knowledge_graph/retrieval/query_intent_classifier.py +7 -5
- aiecs/config/config.py +3 -0
- aiecs/config/tool_config.py +55 -19
- aiecs/domain/agent/base_agent.py +79 -0
- aiecs/domain/agent/hybrid_agent.py +552 -175
- aiecs/domain/agent/knowledge_aware_agent.py +3 -2
- aiecs/domain/agent/llm_agent.py +2 -0
- aiecs/domain/agent/models.py +10 -0
- aiecs/domain/agent/tools/schema_generator.py +17 -4
- aiecs/llm/callbacks/custom_callbacks.py +9 -4
- aiecs/llm/client_factory.py +20 -7
- aiecs/llm/clients/base_client.py +50 -5
- aiecs/llm/clients/google_function_calling_mixin.py +46 -88
- aiecs/llm/clients/googleai_client.py +183 -9
- aiecs/llm/clients/openai_client.py +12 -0
- aiecs/llm/clients/openai_compatible_mixin.py +42 -2
- aiecs/llm/clients/openrouter_client.py +272 -0
- aiecs/llm/clients/vertex_client.py +385 -22
- aiecs/llm/clients/xai_client.py +41 -3
- aiecs/llm/protocols.py +19 -1
- aiecs/llm/utils/image_utils.py +179 -0
- aiecs/main.py +2 -2
- aiecs/tools/docs/document_creator_tool.py +143 -2
- aiecs/tools/docs/document_parser_tool.py +9 -4
- aiecs/tools/docs/document_writer_tool.py +179 -0
- aiecs/tools/task_tools/image_tool.py +49 -14
- aiecs/tools/task_tools/scraper_tool.py +39 -2
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/METADATA +4 -2
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/RECORD +35 -33
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/WHEEL +0 -0
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/entry_points.txt +0 -0
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/licenses/LICENSE +0 -0
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/top_level.txt +0 -0
aiecs/llm/clients/vertex_client.py
CHANGED

@@ -1,8 +1,11 @@
 import asyncio
+import json
 import logging
 import os
 import warnings
-
+import hashlib
+import base64
+from typing import Dict, Any, Optional, List, AsyncGenerator, Union
 import vertexai
 from vertexai.generative_models import (
     GenerativeModel,
@@ -14,6 +17,45 @@ from vertexai.generative_models import (
     Part,
 )

+from aiecs.llm.utils.image_utils import parse_image_source, ImageContent
+
+logger = logging.getLogger(__name__)
+
+# Try to import CachedContent for prompt caching support
+# CachedContent API requires google-cloud-aiplatform >= 1.38.0
+# Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/cached-content
+CACHED_CONTENT_AVAILABLE = False
+CACHED_CONTENT_IMPORT_PATH = None
+CACHED_CONTENT_SDK_VERSION = None
+
+# Check SDK version
+try:
+    import google.cloud.aiplatform as aiplatform
+    CACHED_CONTENT_SDK_VERSION = getattr(aiplatform, '__version__', None)
+except ImportError:
+    pass
+
+# Try to import CachedContent for prompt caching support
+try:
+    from vertexai.preview import caching
+    if hasattr(caching, 'CachedContent'):
+        CACHED_CONTENT_AVAILABLE = True
+        CACHED_CONTENT_IMPORT_PATH = 'vertexai.preview.caching'
+    else:
+        # Module exists but CachedContent class not found
+        CACHED_CONTENT_AVAILABLE = False
+except ImportError:
+    try:
+        # Alternative import path for different SDK versions
+        from vertexai import caching
+        if hasattr(caching, 'CachedContent'):
+            CACHED_CONTENT_AVAILABLE = True
+            CACHED_CONTENT_IMPORT_PATH = 'vertexai.caching'
+        else:
+            CACHED_CONTENT_AVAILABLE = False
+    except ImportError:
+        CACHED_CONTENT_AVAILABLE = False
+
 from aiecs.llm.clients.base_client import (
     BaseLLMClient,
     LLMMessage,
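Note on the block above: the module-level probe records three flags that the client later uses to decide whether prompt caching can be attempted. A quick way to see what the probe detected in a given environment (the module path comes from the file list above; the printed values are examples only):

    from aiecs.llm.clients import vertex_client

    print(vertex_client.CACHED_CONTENT_AVAILABLE)    # True if a usable CachedContent class was found
    print(vertex_client.CACHED_CONTENT_IMPORT_PATH)  # e.g. 'vertexai.preview.caching', or None
    print(vertex_client.CACHED_CONTENT_SDK_VERSION)  # detected google-cloud-aiplatform version, if any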
@@ -147,17 +189,20 @@ def _build_safety_block_error(
     error_parts = [default_message]
     if block_reason:
         error_parts.append(f"Block reason: {block_reason}")
-
-    blocked_categories = [
-        r.get("category", "UNKNOWN")
-        for r in safety_ratings
-        if r.get("blocked", False)
-    ]
+
+    # Safely extract blocked categories, handling potential non-dict elements
+    blocked_categories = []
+    for r in safety_ratings:
+        if isinstance(r, dict) and r.get("blocked", False):
+            blocked_categories.append(r.get("category", "UNKNOWN"))
     if blocked_categories:
         error_parts.append(f"Blocked categories: {', '.join(blocked_categories)}")
-
+
     # Add severity/probability information
     for rating in safety_ratings:
+        # Skip non-dict elements
+        if not isinstance(rating, dict):
+            continue
         if rating.get("blocked"):
             if "severity" in rating:
                 error_parts.append(
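The rewritten extraction above tolerates malformed entries in safety_ratings. A standalone illustration of the guarded loop (the ratings list here is made up):

    safety_ratings = [
        {"category": "HARM_CATEGORY_HATE_SPEECH", "blocked": True},
        None,  # e.g. a malformed or unexpected entry
        {"category": "HARM_CATEGORY_HARASSMENT", "blocked": False},
    ]

    blocked_categories = []
    for r in safety_ratings:
        if isinstance(r, dict) and r.get("blocked", False):
            blocked_categories.append(r.get("category", "UNKNOWN"))

    print(blocked_categories)  # ['HARM_CATEGORY_HATE_SPEECH']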
@@ -193,6 +238,8 @@ class VertexAIClient(BaseLLMClient, GoogleFunctionCallingMixin):
             "part_counts": {},  # {part_count: frequency}
             "last_part_count": None,
         }
+        # Cache for CachedContent objects (key: content hash, value: cached_content_id)
+        self._cached_content_cache: Dict[str, str] = {}

     def _init_vertex_ai(self):
         """Lazy initialization of Vertex AI with proper authentication"""
@@ -230,14 +277,140 @@ class VertexAIClient(BaseLLMClient, GoogleFunctionCallingMixin):
         except Exception as e:
             raise ProviderNotAvailableError(f"Failed to initialize Vertex AI: {str(e)}")

+    def _generate_content_hash(self, content: str) -> str:
+        """Generate a hash for content to use as cache key."""
+        return hashlib.md5(content.encode('utf-8')).hexdigest()
+
+    async def _create_or_get_cached_content(
+        self,
+        content: str,
+        model_name: str,
+        ttl_seconds: Optional[int] = None,
+    ) -> Optional[str]:
+        """
+        Create or get a CachedContent for the given content.
+
+        This method implements Gemini's CachedContent API for prompt caching.
+        It preserves the existing cache_control mechanism for developer convenience.
+
+        The method supports multiple Vertex AI SDK versions and gracefully falls back
+        to regular system_instruction if CachedContent API is unavailable.
+
+        Args:
+            content: Content to cache (typically system instruction)
+            model_name: Model name to use for caching
+            ttl_seconds: Time to live in seconds (optional, defaults to 3600)
+
+        Returns:
+            CachedContent resource name (e.g., "projects/.../cachedContents/...") or None if caching unavailable
+        """
+        if not CACHED_CONTENT_AVAILABLE:
+            # Provide version info if available
+            version_info = ""
+            if CACHED_CONTENT_SDK_VERSION:
+                version_info = f" (SDK version: {CACHED_CONTENT_SDK_VERSION})"
+            elif CACHED_CONTENT_IMPORT_PATH:
+                version_info = f" (import path '{CACHED_CONTENT_IMPORT_PATH}' available but CachedContent class not found)"
+
+            self.logger.debug(
+                f"CachedContent API not available{version_info}, skipping cache creation. "
+                f"Requires google-cloud-aiplatform >=1.38.0"
+            )
+            return None
+
+        if not content or not content.strip():
+            return None
+
+        # Generate cache key
+        cache_key = self._generate_content_hash(content)
+
+        # Check if we already have this cached
+        if cache_key in self._cached_content_cache:
+            cached_content_id = self._cached_content_cache[cache_key]
+            self.logger.debug(f"Using existing CachedContent: {cached_content_id}")
+            return cached_content_id
+
+        try:
+            self._init_vertex_ai()
+
+            # Build the content to cache (system instruction as Content)
+            # For CachedContent, we typically cache the system instruction
+            cached_content_obj = Content(
+                role="user",
+                parts=[Part.from_text(content)]
+            )
+
+            # Try different API patterns based on SDK version
+            cached_content_id = None
+
+            # Pattern 1: caching.CachedContent.create() (most common)
+            if hasattr(caching, 'CachedContent'):
+                try:
+                    cached_content = await asyncio.get_event_loop().run_in_executor(
+                        None,
+                        lambda: caching.CachedContent.create(
+                            model=model_name,
+                            contents=[cached_content_obj],
+                            ttl_seconds=ttl_seconds or 3600,  # Default 1 hour
+                        )
+                    )
+
+                    # Extract the resource name
+                    if hasattr(cached_content, 'name'):
+                        cached_content_id = cached_content.name
+                    elif hasattr(cached_content, 'resource_name'):
+                        cached_content_id = cached_content.resource_name
+                    else:
+                        cached_content_id = str(cached_content)
+
+                    if cached_content_id:
+                        # Store in cache
+                        self._cached_content_cache[cache_key] = cached_content_id
+                        self.logger.info(f"Created CachedContent for prompt caching: {cached_content_id}")
+                        return cached_content_id
+
+                except AttributeError as e:
+                    self.logger.debug(f"CachedContent.create() signature may differ: {str(e)}")
+                except Exception as e:
+                    self.logger.debug(f"Failed to create CachedContent using pattern 1: {str(e)}")
+
+            # Pattern 2: Try alternative API patterns if Pattern 1 fails
+            # Note: Different SDK versions may have different APIs
+            # This is a fallback that allows graceful degradation
+
+            # Build informative warning message with version info
+            version_info = ""
+            if CACHED_CONTENT_SDK_VERSION:
+                version_info = f" Current SDK version: {CACHED_CONTENT_SDK_VERSION}."
+            else:
+                version_info = " Unable to detect SDK version."
+
+            required_version = ">=1.38.0"
+            upgrade_command = "pip install --upgrade 'google-cloud-aiplatform>=1.38.0'"
+
+            self.logger.warning(
+                f"CachedContent API not available or incompatible with current SDK version.{version_info} "
+                f"Falling back to system_instruction (prompt caching disabled). "
+                f"To enable prompt caching, upgrade to google-cloud-aiplatform {required_version} or later: "
+                f"{upgrade_command}"
+            )
+            return None
+
+        except Exception as e:
+            self.logger.warning(
+                f"Failed to create CachedContent (prompt caching disabled, using system_instruction): {str(e)}"
+            )
+            # Don't raise - allow fallback to regular generation without caching
+            return None
+
     def _convert_messages_to_contents(
         self, messages: List[LLMMessage]
     ) -> List[Content]:
         """
         Convert LLMMessage list to Vertex AI Content objects.

-        This properly handles multi-turn conversations
-
+        This properly handles multi-turn conversations including
+        function/tool responses for Vertex AI Function Calling.

         Args:
             messages: List of LLMMessage objects (system messages should be filtered out)
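_create_or_get_cached_content keys its per-client lookup on an MD5 digest of the cached text, so a repeated system prompt reuses one CachedContent resource instead of creating a new one per request. A minimal sketch of that keying idea in isolation (create_resource is a hypothetical stand-in for the SDK call):

    import hashlib
    from typing import Callable, Dict

    _cache: Dict[str, str] = {}

    def get_or_create(content: str, create_resource: Callable[[str], str]) -> str:
        """Return the cached resource name for this content, creating it only once."""
        key = hashlib.md5(content.encode("utf-8")).hexdigest()
        if key not in _cache:
            _cache[key] = create_resource(content)  # e.g. CachedContent.create(...)
        return _cache[key]

    first = get_or_create("You are a helpful assistant.", lambda c: "cachedContents/demo-1")
    second = get_or_create("You are a helpful assistant.", lambda c: "cachedContents/demo-2")
    assert first == second == "cachedContents/demo-1"  # second call hits the cache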
@@ -246,13 +419,118 @@ class VertexAIClient(BaseLLMClient, GoogleFunctionCallingMixin):
             List of Content objects for Vertex AI API
         """
         contents = []
+
         for msg in messages:
-            #
-
-
-
-
-
+            # Handle tool/function responses (role="tool")
+            if msg.role == "tool":
+                # Vertex AI expects function responses as user messages with FunctionResponse parts
+                # The tool_call_id maps to the function name
+                func_name = msg.tool_call_id or "unknown_function"
+
+                # Parse content as the function response
+                try:
+                    # Try to parse as JSON if it looks like JSON
+                    if msg.content and msg.content.strip().startswith('{'):
+                        response_data = json.loads(msg.content)
+                    else:
+                        response_data = {"result": msg.content}
+                except json.JSONDecodeError:
+                    response_data = {"result": msg.content}
+
+                # Create FunctionResponse part using Part.from_function_response
+                func_response_part = Part.from_function_response(
+                    name=func_name,
+                    response=response_data
+                )
+
+                contents.append(Content(
+                    role="user",  # Function responses are sent as "user" role in Vertex AI
+                    parts=[func_response_part]
+                ))
+
+            # Handle assistant messages with tool calls
+            elif msg.role == "assistant" and msg.tool_calls:
+                parts = []
+                if msg.content:
+                    parts.append(Part.from_text(msg.content))
+
+                # Add images if present
+                if msg.images:
+                    for image_source in msg.images:
+                        image_content = parse_image_source(image_source)
+
+                        if image_content.is_url():
+                            parts.append(Part.from_uri(
+                                uri=image_content.get_url(),
+                                mime_type=image_content.mime_type
+                            ))
+                        else:
+                            base64_data = image_content.get_base64_data()
+                            image_bytes = base64.b64decode(base64_data)
+                            parts.append(Part.from_bytes(
+                                data=image_bytes,
+                                mime_type=image_content.mime_type
+                            ))
+
+                for tool_call in msg.tool_calls:
+                    func = tool_call.get("function", {})
+                    func_name = func.get("name", "")
+                    func_args = func.get("arguments", "{}")
+
+                    # Parse arguments
+                    try:
+                        args_dict = json.loads(func_args) if isinstance(func_args, str) else func_args
+                    except json.JSONDecodeError:
+                        args_dict = {}
+
+                    # Create FunctionCall part using Part.from_dict
+                    # Note: Part.from_function_call() does NOT exist in Vertex AI SDK
+                    # Must use from_dict with function_call structure
+                    function_call_part = Part.from_dict({
+                        "function_call": {
+                            "name": func_name,
+                            "args": args_dict
+                        }
+                    })
+                    parts.append(function_call_part)
+
+                contents.append(Content(
+                    role="model",
+                    parts=parts
+                ))
+
+            # Handle regular messages (user, assistant without tool_calls)
+            else:
+                role = "model" if msg.role == "assistant" else msg.role
+                parts = []
+
+                # Add text content if present
+                if msg.content:
+                    parts.append(Part.from_text(msg.content))
+
+                # Add images if present
+                if msg.images:
+                    for image_source in msg.images:
+                        image_content = parse_image_source(image_source)
+
+                        if image_content.is_url():
+                            # Use Part.from_uri for URLs
+                            parts.append(Part.from_uri(
+                                uri=image_content.get_url(),
+                                mime_type=image_content.mime_type
+                            ))
+                        else:
+                            # Convert to bytes for inline_data
+                            base64_data = image_content.get_base64_data()
+                            image_bytes = base64.b64decode(base64_data)
+                            parts.append(Part.from_bytes(
+                                data=image_bytes,
+                                mime_type=image_content.mime_type
+                            ))
+
+                if parts:
+                    contents.append(Content(role=role, parts=parts))
+
         return contents

     async def generate_text(
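The branching above maps chat roles onto Vertex AI Content roles: tool results become "user" turns carrying a function response, assistant turns with tool_calls become "model" turns carrying function calls, and other messages pass through with "assistant" renamed to "model". A simplified, SDK-free sketch of just that role mapping (plain dicts stand in for the Content/Part objects; the message attribute names are taken from the diff):

    import json
    from typing import Any, Dict, List

    def map_roles(messages: List[Any]) -> List[Dict[str, Any]]:
        """Illustrative only: mirror the role mapping with plain dicts."""
        contents: List[Dict[str, Any]] = []
        for msg in messages:
            if msg.role == "tool":
                # Tool output travels back as a "user" turn with a function_response part
                contents.append({
                    "role": "user",
                    "parts": [{"function_response": {
                        "name": msg.tool_call_id or "unknown_function",
                        "response": {"result": msg.content},
                    }}],
                })
            elif msg.role == "assistant" and getattr(msg, "tool_calls", None):
                parts = []
                for tc in msg.tool_calls:
                    args = tc.get("function", {}).get("arguments", "{}")
                    parts.append({"function_call": {
                        "name": tc.get("function", {}).get("name", ""),
                        "args": json.loads(args) if isinstance(args, str) else args,
                    }})
                contents.append({"role": "model", "parts": parts})
            else:
                role = "model" if msg.role == "assistant" else msg.role
                contents.append({"role": role, "parts": [{"text": msg.content}]})
        return contents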
@@ -261,13 +539,36 @@ class VertexAIClient(BaseLLMClient, GoogleFunctionCallingMixin):
         model: Optional[str] = None,
         temperature: float = 0.7,
         max_tokens: Optional[int] = None,
+        context: Optional[Dict[str, Any]] = None,
         functions: Optional[List[Dict[str, Any]]] = None,
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Any] = None,
         system_instruction: Optional[str] = None,
         **kwargs,
     ) -> LLMResponse:
-        """
+        """
+        Generate text using Vertex AI.
+
+        Args:
+            messages: List of conversation messages
+            model: Model name (optional, uses default if not provided)
+            temperature: Sampling temperature (0.0 to 1.0)
+            max_tokens: Maximum tokens to generate
+            context: Optional context dictionary containing metadata such as:
+                - user_id: User identifier for tracking/billing
+                - tenant_id: Tenant identifier for multi-tenant setups
+                - request_id: Request identifier for tracing
+                - session_id: Session identifier
+                - Any other custom metadata for observability or middleware
+            functions: List of function schemas (legacy format)
+            tools: List of tool schemas (new format, recommended)
+            tool_choice: Tool choice strategy
+            system_instruction: System instruction for the model
+            **kwargs: Additional provider-specific parameters
+
+        Returns:
+            LLMResponse with generated text and metadata
+        """
         self._init_vertex_ai()

         # Get model name from config if not provided
@@ -281,17 +582,37 @@ class VertexAIClient(BaseLLMClient, GoogleFunctionCallingMixin):
         try:
             # Extract system message from messages if present
             system_msg = None
+            system_cache_control = None
             user_messages = []
             for msg in messages:
                 if msg.role == "system":
                     system_msg = msg.content
+                    system_cache_control = msg.cache_control
                 else:
                     user_messages.append(msg)

             # Use explicit system_instruction parameter if provided, else use extracted system message
             final_system_instruction = system_instruction or system_msg

+            # Check if we should use CachedContent API for prompt caching
+            cached_content_id = None
+            if final_system_instruction and system_cache_control:
+                # Create or get CachedContent for the system instruction
+                # Extract TTL from cache_control if available (defaults to 3600 seconds)
+                ttl_seconds = getattr(system_cache_control, 'ttl_seconds', None) or 3600
+                cached_content_id = await self._create_or_get_cached_content(
+                    content=final_system_instruction,
+                    model_name=model_name,
+                    ttl_seconds=ttl_seconds,
+                )
+                if cached_content_id:
+                    self.logger.debug(f"Using CachedContent for prompt caching: {cached_content_id}")
+                    # When using CachedContent, we don't pass system_instruction to GenerativeModel
+                    # Instead, we'll pass cached_content_id to generate_content
+                    final_system_instruction = None
+
             # Initialize model WITH system instruction for prompt caching support
+            # Note: If using CachedContent, system_instruction will be None
             model_instance = GenerativeModel(
                 model_name,
                 system_instruction=final_system_instruction
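This caching path only activates when the system message carries a cache_control marker; the TTL is read with getattr, so any object exposing ttl_seconds works. A hedged usage sketch (the keyword form of the LLMMessage constructor and the SimpleNamespace stand-in for cache_control are assumptions; the field names themselves come from the diff):

    from types import SimpleNamespace
    from aiecs.llm.clients.base_client import LLMMessage

    system_msg = LLMMessage(
        role="system",
        content="You are a meticulous research assistant.",
        cache_control=SimpleNamespace(ttl_seconds=1800),  # read via getattr(..., 'ttl_seconds')
    )
    user_msg = LLMMessage(role="user", content="Summarize the attached report.")

    async def ask(client):
        # `client` is assumed to be a configured VertexAIClient instance
        response = await client.generate_text(messages=[system_msg, user_msg])
        return response.content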
@@ -362,13 +683,18 @@ class VertexAIClient(BaseLLMClient, GoogleFunctionCallingMixin):
                "safety_settings": safety_settings,
            }

+            # Add cached_content if using CachedContent API for prompt caching
+            if cached_content_id:
+                api_params["cached_content"] = cached_content_id
+                self.logger.debug(f"Added cached_content to API params: {cached_content_id}")
+
            # Add tools if available
            if tools_for_api:
                api_params["tools"] = tools_for_api

-            # Add any additional kwargs (but exclude tools/safety_settings to avoid conflicts)
+            # Add any additional kwargs (but exclude tools/safety_settings/cached_content to avoid conflicts)
            for key, value in kwargs.items():
-                if key not in ["tools", "safety_settings"]:
+                if key not in ["tools", "safety_settings", "cached_content"]:
                    api_params[key] = value

            response = await asyncio.get_event_loop().run_in_executor(
@@ -553,7 +879,9 @@ class VertexAIClient(BaseLLMClient, GoogleFunctionCallingMixin):

            # Vertex AI doesn't provide detailed token usage in the response
            # Use estimation method as fallback
-
+            # Estimate input tokens from messages content
+            prompt_text = " ".join(msg.content for msg in messages if msg.content)
+            input_tokens = self._count_tokens_estimate(prompt_text)
            output_tokens = self._count_tokens_estimate(content)
            tokens_used = input_tokens + output_tokens

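Both token counts here are estimates derived from raw text because the response carries no usage metadata. The diff does not show _count_tokens_estimate itself; a common heuristic, and purely an assumption about its behavior, is roughly four characters per token:

    def count_tokens_estimate(text: str) -> int:
        """Assumed heuristic: ~4 characters per token; not necessarily the package's exact rule."""
        return max(1, len(text) // 4) if text else 0

    prompt_text = " ".join(m for m in ["You are helpful.", "What is prompt caching?"] if m)
    print(count_tokens_estimate(prompt_text))  # 10 under this heuristic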
@@ -608,7 +936,9 @@ class VertexAIClient(BaseLLMClient, GoogleFunctionCallingMixin):
            ):
                self.logger.warning(f"Vertex AI response issue: {str(e)}")
                # Return a response indicating the issue
-
+                # Estimate prompt tokens from messages content
+                prompt_text = " ".join(msg.content for msg in messages if msg.content)
+                estimated_prompt_tokens = self._count_tokens_estimate(prompt_text)
                return LLMResponse(
                    content="[Response unavailable due to content processing issues or safety filters]",
                    provider=self.provider_name,
@@ -626,6 +956,7 @@ class VertexAIClient(BaseLLMClient, GoogleFunctionCallingMixin):
         model: Optional[str] = None,
         temperature: float = 0.7,
         max_tokens: Optional[int] = None,
+        context: Optional[Dict[str, Any]] = None,
         functions: Optional[List[Dict[str, Any]]] = None,
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Any] = None,
@@ -641,6 +972,12 @@ class VertexAIClient(BaseLLMClient, GoogleFunctionCallingMixin):
             model: Model name (optional)
             temperature: Temperature for generation
             max_tokens: Maximum tokens to generate
+            context: Optional context dictionary containing metadata such as:
+                - user_id: User identifier for tracking/billing
+                - tenant_id: Tenant identifier for multi-tenant setups
+                - request_id: Request identifier for tracing
+                - session_id: Session identifier
+                - Any other custom metadata for observability or middleware
             functions: List of function schemas (legacy format)
             tools: List of tool schemas (new format)
             tool_choice: Tool choice strategy (not used for Google Vertex AI)
@@ -664,17 +1001,37 @@ class VertexAIClient(BaseLLMClient, GoogleFunctionCallingMixin):
         try:
             # Extract system message from messages if present
             system_msg = None
+            system_cache_control = None
             user_messages = []
             for msg in messages:
                 if msg.role == "system":
                     system_msg = msg.content
+                    system_cache_control = msg.cache_control
                 else:
                     user_messages.append(msg)

             # Use explicit system_instruction parameter if provided, else use extracted system message
             final_system_instruction = system_instruction or system_msg

+            # Check if we should use CachedContent API for prompt caching
+            cached_content_id = None
+            if final_system_instruction and system_cache_control:
+                # Create or get CachedContent for the system instruction
+                # Extract TTL from cache_control if available (defaults to 3600 seconds)
+                ttl_seconds = getattr(system_cache_control, 'ttl_seconds', None) or 3600
+                cached_content_id = await self._create_or_get_cached_content(
+                    content=final_system_instruction,
+                    model_name=model_name,
+                    ttl_seconds=ttl_seconds,
+                )
+                if cached_content_id:
+                    self.logger.debug(f"Using CachedContent for prompt caching in streaming: {cached_content_id}")
+                    # When using CachedContent, we don't pass system_instruction to GenerativeModel
+                    # Instead, we'll pass cached_content_id to generate_content
+                    final_system_instruction = None
+
             # Initialize model WITH system instruction for prompt caching support
+            # Note: If using CachedContent, system_instruction will be None
             model_instance = GenerativeModel(
                 model_name,
                 system_instruction=final_system_instruction
@@ -738,6 +1095,12 @@ class VertexAIClient(BaseLLMClient, GoogleFunctionCallingMixin):
            # Use mixin method for Function Calling support
            from aiecs.llm.clients.openai_compatible_mixin import StreamChunk

+            # Add cached_content to kwargs if using CachedContent API
+            stream_kwargs = kwargs.copy()
+            if cached_content_id:
+                stream_kwargs["cached_content"] = cached_content_id
+                self.logger.debug(f"Added cached_content to streaming API params: {cached_content_id}")
+
            async for chunk in self._stream_text_with_function_calling(
                model_instance=model_instance,
                contents=contents,
@@ -745,7 +1108,7 @@ class VertexAIClient(BaseLLMClient, GoogleFunctionCallingMixin):
                safety_settings=safety_settings,
                tools=tools_for_api,
                return_chunks=return_chunks,
-                **kwargs,
+                **stream_kwargs,
            ):
                # Yield chunk (can be str or StreamChunk)
                yield chunk
aiecs/llm/clients/xai_client.py
CHANGED
@@ -87,6 +87,7 @@ class XAIClient(BaseLLMClient, OpenAICompatibleFunctionCallingMixin):
         model: Optional[str] = None,
         temperature: float = 0.7,
         max_tokens: Optional[int] = None,
+        context: Optional[Dict[str, Any]] = None,
         functions: Optional[List[Dict[str, Any]]] = None,
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Any] = None,
@@ -94,8 +95,27 @@ class XAIClient(BaseLLMClient, OpenAICompatibleFunctionCallingMixin):
     ) -> LLMResponse:
         """
         Generate text using xAI API via OpenAI library (supports all Grok models).
-
+
         xAI API is OpenAI-compatible, so it supports Function Calling.
+
+        Args:
+            messages: List of conversation messages
+            model: Model name (optional, uses default if not provided)
+            temperature: Sampling temperature (0.0 to 1.0)
+            max_tokens: Maximum tokens to generate
+            context: Optional context dictionary containing metadata such as:
+                - user_id: User identifier for tracking/billing
+                - tenant_id: Tenant identifier for multi-tenant setups
+                - request_id: Request identifier for tracing
+                - session_id: Session identifier
+                - Any other custom metadata for observability or middleware
+            functions: List of function schemas (legacy format)
+            tools: List of tool schemas (new format, recommended)
+            tool_choice: Tool choice strategy
+            **kwargs: Additional provider-specific parameters
+
+        Returns:
+            LLMResponse with generated text and metadata
         """
         # Check API key availability
         api_key = self._get_api_key()
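The new context parameter is pass-through metadata for tracking and middleware rather than model input. A hedged usage sketch for the xAI client (client construction is omitted and the model name is hypothetical; the field names mirror the docstring above):

    from aiecs.llm.clients.base_client import LLMMessage

    async def ask_grok(client) -> str:
        # `client` is assumed to be an already-configured XAIClient instance
        response = await client.generate_text(
            messages=[LLMMessage(role="user", content="One sentence on Mars, please.")],
            model="grok-2-latest",              # hypothetical model name
            context={                           # observability metadata, not prompt content
                "user_id": "user-123",
                "tenant_id": "tenant-a",
                "request_id": "req-42",
            },
        )
        return response.content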
@@ -144,6 +164,7 @@ class XAIClient(BaseLLMClient, OpenAICompatibleFunctionCallingMixin):
         model: Optional[str] = None,
         temperature: float = 0.7,
         max_tokens: Optional[int] = None,
+        context: Optional[Dict[str, Any]] = None,
         functions: Optional[List[Dict[str, Any]]] = None,
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Any] = None,
@@ -152,11 +173,28 @@ class XAIClient(BaseLLMClient, OpenAICompatibleFunctionCallingMixin):
     ) -> AsyncGenerator[Any, None]:
         """
         Stream text using xAI API via OpenAI library (supports all Grok models).
-
+
         xAI API is OpenAI-compatible, so it supports Function Calling.
-
+
         Args:
+            messages: List of conversation messages
+            model: Model name (optional, uses default if not provided)
+            temperature: Sampling temperature (0.0 to 1.0)
+            max_tokens: Maximum tokens to generate
+            context: Optional context dictionary containing metadata such as:
+                - user_id: User identifier for tracking/billing
+                - tenant_id: Tenant identifier for multi-tenant setups
+                - request_id: Request identifier for tracing
+                - session_id: Session identifier
+                - Any other custom metadata for observability or middleware
+            functions: List of function schemas (legacy format)
+            tools: List of tool schemas (new format, recommended)
+            tool_choice: Tool choice strategy
             return_chunks: If True, returns StreamChunk objects with tool_calls info; if False, returns str tokens only
+            **kwargs: Additional provider-specific parameters
+
+        Yields:
+            str or StreamChunk: Text tokens or StreamChunk objects
         """
         # Check API key availability
         api_key = self._get_api_key()
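A matching consumption sketch for the streaming path: with return_chunks=True the generator may yield StreamChunk objects as well as plain strings, so the consumer should branch on the type (the content attribute on StreamChunk is an assumption beyond what the docstring states):

    from aiecs.llm.clients.base_client import LLMMessage

    async def stream_answer(client) -> None:
        async for chunk in client.stream_text(
            messages=[LLMMessage(role="user", content="Stream a haiku about diffs.")],
            return_chunks=True,
            context={"request_id": "req-43"},
        ):
            if isinstance(chunk, str):
                print(chunk, end="", flush=True)
            else:
                # StreamChunk: assumed to expose a text/content field alongside tool_calls
                text = getattr(chunk, "content", None)
                if text:
                    print(text, end="", flush=True)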