aiecs 1.0.1__py3-none-any.whl → 1.7.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiecs might be problematic. Click here for more details.
- aiecs/__init__.py +13 -16
- aiecs/__main__.py +7 -7
- aiecs/aiecs_client.py +269 -75
- aiecs/application/executors/operation_executor.py +79 -54
- aiecs/application/knowledge_graph/__init__.py +7 -0
- aiecs/application/knowledge_graph/builder/__init__.py +37 -0
- aiecs/application/knowledge_graph/builder/data_quality.py +302 -0
- aiecs/application/knowledge_graph/builder/data_reshaping.py +293 -0
- aiecs/application/knowledge_graph/builder/document_builder.py +369 -0
- aiecs/application/knowledge_graph/builder/graph_builder.py +490 -0
- aiecs/application/knowledge_graph/builder/import_optimizer.py +396 -0
- aiecs/application/knowledge_graph/builder/schema_inference.py +462 -0
- aiecs/application/knowledge_graph/builder/schema_mapping.py +563 -0
- aiecs/application/knowledge_graph/builder/structured_pipeline.py +1384 -0
- aiecs/application/knowledge_graph/builder/text_chunker.py +317 -0
- aiecs/application/knowledge_graph/extractors/__init__.py +27 -0
- aiecs/application/knowledge_graph/extractors/base.py +98 -0
- aiecs/application/knowledge_graph/extractors/llm_entity_extractor.py +422 -0
- aiecs/application/knowledge_graph/extractors/llm_relation_extractor.py +347 -0
- aiecs/application/knowledge_graph/extractors/ner_entity_extractor.py +241 -0
- aiecs/application/knowledge_graph/fusion/__init__.py +78 -0
- aiecs/application/knowledge_graph/fusion/ab_testing.py +395 -0
- aiecs/application/knowledge_graph/fusion/abbreviation_expander.py +327 -0
- aiecs/application/knowledge_graph/fusion/alias_index.py +597 -0
- aiecs/application/knowledge_graph/fusion/alias_matcher.py +384 -0
- aiecs/application/knowledge_graph/fusion/cache_coordinator.py +343 -0
- aiecs/application/knowledge_graph/fusion/entity_deduplicator.py +433 -0
- aiecs/application/knowledge_graph/fusion/entity_linker.py +511 -0
- aiecs/application/knowledge_graph/fusion/evaluation_dataset.py +240 -0
- aiecs/application/knowledge_graph/fusion/knowledge_fusion.py +632 -0
- aiecs/application/knowledge_graph/fusion/matching_config.py +489 -0
- aiecs/application/knowledge_graph/fusion/name_normalizer.py +352 -0
- aiecs/application/knowledge_graph/fusion/relation_deduplicator.py +183 -0
- aiecs/application/knowledge_graph/fusion/semantic_name_matcher.py +464 -0
- aiecs/application/knowledge_graph/fusion/similarity_pipeline.py +534 -0
- aiecs/application/knowledge_graph/pattern_matching/__init__.py +21 -0
- aiecs/application/knowledge_graph/pattern_matching/pattern_matcher.py +342 -0
- aiecs/application/knowledge_graph/pattern_matching/query_executor.py +366 -0
- aiecs/application/knowledge_graph/profiling/__init__.py +12 -0
- aiecs/application/knowledge_graph/profiling/query_plan_visualizer.py +195 -0
- aiecs/application/knowledge_graph/profiling/query_profiler.py +223 -0
- aiecs/application/knowledge_graph/reasoning/__init__.py +27 -0
- aiecs/application/knowledge_graph/reasoning/evidence_synthesis.py +341 -0
- aiecs/application/knowledge_graph/reasoning/inference_engine.py +500 -0
- aiecs/application/knowledge_graph/reasoning/logic_form_parser.py +163 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/__init__.py +79 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_builder.py +513 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_nodes.py +913 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_validator.py +866 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/error_handler.py +475 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/parser.py +396 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/query_context.py +208 -0
- aiecs/application/knowledge_graph/reasoning/logic_query_integration.py +170 -0
- aiecs/application/knowledge_graph/reasoning/query_planner.py +855 -0
- aiecs/application/knowledge_graph/reasoning/reasoning_engine.py +518 -0
- aiecs/application/knowledge_graph/retrieval/__init__.py +27 -0
- aiecs/application/knowledge_graph/retrieval/query_intent_classifier.py +211 -0
- aiecs/application/knowledge_graph/retrieval/retrieval_strategies.py +592 -0
- aiecs/application/knowledge_graph/retrieval/strategy_types.py +23 -0
- aiecs/application/knowledge_graph/search/__init__.py +59 -0
- aiecs/application/knowledge_graph/search/hybrid_search.py +457 -0
- aiecs/application/knowledge_graph/search/reranker.py +293 -0
- aiecs/application/knowledge_graph/search/reranker_strategies.py +535 -0
- aiecs/application/knowledge_graph/search/text_similarity.py +392 -0
- aiecs/application/knowledge_graph/traversal/__init__.py +15 -0
- aiecs/application/knowledge_graph/traversal/enhanced_traversal.py +305 -0
- aiecs/application/knowledge_graph/traversal/path_scorer.py +271 -0
- aiecs/application/knowledge_graph/validators/__init__.py +13 -0
- aiecs/application/knowledge_graph/validators/relation_validator.py +239 -0
- aiecs/application/knowledge_graph/visualization/__init__.py +11 -0
- aiecs/application/knowledge_graph/visualization/graph_visualizer.py +313 -0
- aiecs/common/__init__.py +9 -0
- aiecs/common/knowledge_graph/__init__.py +17 -0
- aiecs/common/knowledge_graph/runnable.py +471 -0
- aiecs/config/__init__.py +20 -5
- aiecs/config/config.py +762 -31
- aiecs/config/graph_config.py +131 -0
- aiecs/config/tool_config.py +435 -0
- aiecs/core/__init__.py +29 -13
- aiecs/core/interface/__init__.py +2 -2
- aiecs/core/interface/execution_interface.py +22 -22
- aiecs/core/interface/storage_interface.py +37 -88
- aiecs/core/registry/__init__.py +31 -0
- aiecs/core/registry/service_registry.py +92 -0
- aiecs/domain/__init__.py +270 -1
- aiecs/domain/agent/__init__.py +191 -0
- aiecs/domain/agent/base_agent.py +3949 -0
- aiecs/domain/agent/exceptions.py +99 -0
- aiecs/domain/agent/graph_aware_mixin.py +569 -0
- aiecs/domain/agent/hybrid_agent.py +1731 -0
- aiecs/domain/agent/integration/__init__.py +29 -0
- aiecs/domain/agent/integration/context_compressor.py +216 -0
- aiecs/domain/agent/integration/context_engine_adapter.py +587 -0
- aiecs/domain/agent/integration/protocols.py +281 -0
- aiecs/domain/agent/integration/retry_policy.py +218 -0
- aiecs/domain/agent/integration/role_config.py +213 -0
- aiecs/domain/agent/knowledge_aware_agent.py +1892 -0
- aiecs/domain/agent/lifecycle.py +291 -0
- aiecs/domain/agent/llm_agent.py +692 -0
- aiecs/domain/agent/memory/__init__.py +12 -0
- aiecs/domain/agent/memory/conversation.py +1124 -0
- aiecs/domain/agent/migration/__init__.py +14 -0
- aiecs/domain/agent/migration/conversion.py +163 -0
- aiecs/domain/agent/migration/legacy_wrapper.py +86 -0
- aiecs/domain/agent/models.py +894 -0
- aiecs/domain/agent/observability.py +479 -0
- aiecs/domain/agent/persistence.py +449 -0
- aiecs/domain/agent/prompts/__init__.py +29 -0
- aiecs/domain/agent/prompts/builder.py +159 -0
- aiecs/domain/agent/prompts/formatters.py +187 -0
- aiecs/domain/agent/prompts/template.py +255 -0
- aiecs/domain/agent/registry.py +253 -0
- aiecs/domain/agent/tool_agent.py +444 -0
- aiecs/domain/agent/tools/__init__.py +15 -0
- aiecs/domain/agent/tools/schema_generator.py +377 -0
- aiecs/domain/community/__init__.py +155 -0
- aiecs/domain/community/agent_adapter.py +469 -0
- aiecs/domain/community/analytics.py +432 -0
- aiecs/domain/community/collaborative_workflow.py +648 -0
- aiecs/domain/community/communication_hub.py +634 -0
- aiecs/domain/community/community_builder.py +320 -0
- aiecs/domain/community/community_integration.py +796 -0
- aiecs/domain/community/community_manager.py +803 -0
- aiecs/domain/community/decision_engine.py +849 -0
- aiecs/domain/community/exceptions.py +231 -0
- aiecs/domain/community/models/__init__.py +33 -0
- aiecs/domain/community/models/community_models.py +234 -0
- aiecs/domain/community/resource_manager.py +461 -0
- aiecs/domain/community/shared_context_manager.py +589 -0
- aiecs/domain/context/__init__.py +40 -10
- aiecs/domain/context/context_engine.py +1910 -0
- aiecs/domain/context/conversation_models.py +87 -53
- aiecs/domain/context/graph_memory.py +582 -0
- aiecs/domain/execution/model.py +12 -4
- aiecs/domain/knowledge_graph/__init__.py +19 -0
- aiecs/domain/knowledge_graph/models/__init__.py +52 -0
- aiecs/domain/knowledge_graph/models/entity.py +148 -0
- aiecs/domain/knowledge_graph/models/evidence.py +178 -0
- aiecs/domain/knowledge_graph/models/inference_rule.py +184 -0
- aiecs/domain/knowledge_graph/models/path.py +171 -0
- aiecs/domain/knowledge_graph/models/path_pattern.py +171 -0
- aiecs/domain/knowledge_graph/models/query.py +261 -0
- aiecs/domain/knowledge_graph/models/query_plan.py +181 -0
- aiecs/domain/knowledge_graph/models/relation.py +202 -0
- aiecs/domain/knowledge_graph/schema/__init__.py +23 -0
- aiecs/domain/knowledge_graph/schema/entity_type.py +131 -0
- aiecs/domain/knowledge_graph/schema/graph_schema.py +253 -0
- aiecs/domain/knowledge_graph/schema/property_schema.py +143 -0
- aiecs/domain/knowledge_graph/schema/relation_type.py +163 -0
- aiecs/domain/knowledge_graph/schema/schema_manager.py +691 -0
- aiecs/domain/knowledge_graph/schema/type_enums.py +209 -0
- aiecs/domain/task/dsl_processor.py +172 -56
- aiecs/domain/task/model.py +20 -8
- aiecs/domain/task/task_context.py +27 -24
- aiecs/infrastructure/__init__.py +0 -2
- aiecs/infrastructure/graph_storage/__init__.py +11 -0
- aiecs/infrastructure/graph_storage/base.py +837 -0
- aiecs/infrastructure/graph_storage/batch_operations.py +458 -0
- aiecs/infrastructure/graph_storage/cache.py +424 -0
- aiecs/infrastructure/graph_storage/distributed.py +223 -0
- aiecs/infrastructure/graph_storage/error_handling.py +380 -0
- aiecs/infrastructure/graph_storage/graceful_degradation.py +294 -0
- aiecs/infrastructure/graph_storage/health_checks.py +378 -0
- aiecs/infrastructure/graph_storage/in_memory.py +1197 -0
- aiecs/infrastructure/graph_storage/index_optimization.py +446 -0
- aiecs/infrastructure/graph_storage/lazy_loading.py +431 -0
- aiecs/infrastructure/graph_storage/metrics.py +344 -0
- aiecs/infrastructure/graph_storage/migration.py +400 -0
- aiecs/infrastructure/graph_storage/pagination.py +483 -0
- aiecs/infrastructure/graph_storage/performance_monitoring.py +456 -0
- aiecs/infrastructure/graph_storage/postgres.py +1563 -0
- aiecs/infrastructure/graph_storage/property_storage.py +353 -0
- aiecs/infrastructure/graph_storage/protocols.py +76 -0
- aiecs/infrastructure/graph_storage/query_optimizer.py +642 -0
- aiecs/infrastructure/graph_storage/schema_cache.py +290 -0
- aiecs/infrastructure/graph_storage/sqlite.py +1373 -0
- aiecs/infrastructure/graph_storage/streaming.py +487 -0
- aiecs/infrastructure/graph_storage/tenant.py +412 -0
- aiecs/infrastructure/messaging/celery_task_manager.py +92 -54
- aiecs/infrastructure/messaging/websocket_manager.py +51 -35
- aiecs/infrastructure/monitoring/__init__.py +22 -0
- aiecs/infrastructure/monitoring/executor_metrics.py +45 -11
- aiecs/infrastructure/monitoring/global_metrics_manager.py +212 -0
- aiecs/infrastructure/monitoring/structured_logger.py +3 -7
- aiecs/infrastructure/monitoring/tracing_manager.py +63 -35
- aiecs/infrastructure/persistence/__init__.py +14 -1
- aiecs/infrastructure/persistence/context_engine_client.py +184 -0
- aiecs/infrastructure/persistence/database_manager.py +67 -43
- aiecs/infrastructure/persistence/file_storage.py +180 -103
- aiecs/infrastructure/persistence/redis_client.py +74 -21
- aiecs/llm/__init__.py +73 -25
- aiecs/llm/callbacks/__init__.py +11 -0
- aiecs/llm/{custom_callbacks.py → callbacks/custom_callbacks.py} +26 -19
- aiecs/llm/client_factory.py +230 -37
- aiecs/llm/client_resolver.py +155 -0
- aiecs/llm/clients/__init__.py +38 -0
- aiecs/llm/clients/base_client.py +328 -0
- aiecs/llm/clients/google_function_calling_mixin.py +415 -0
- aiecs/llm/clients/googleai_client.py +314 -0
- aiecs/llm/clients/openai_client.py +158 -0
- aiecs/llm/clients/openai_compatible_mixin.py +367 -0
- aiecs/llm/clients/vertex_client.py +1186 -0
- aiecs/llm/clients/xai_client.py +201 -0
- aiecs/llm/config/__init__.py +51 -0
- aiecs/llm/config/config_loader.py +272 -0
- aiecs/llm/config/config_validator.py +206 -0
- aiecs/llm/config/model_config.py +143 -0
- aiecs/llm/protocols.py +149 -0
- aiecs/llm/utils/__init__.py +10 -0
- aiecs/llm/utils/validate_config.py +89 -0
- aiecs/main.py +140 -121
- aiecs/scripts/aid/VERSION_MANAGEMENT.md +138 -0
- aiecs/scripts/aid/__init__.py +19 -0
- aiecs/scripts/aid/module_checker.py +499 -0
- aiecs/scripts/aid/version_manager.py +235 -0
- aiecs/scripts/{DEPENDENCY_SYSTEM_SUMMARY.md → dependance_check/DEPENDENCY_SYSTEM_SUMMARY.md} +1 -0
- aiecs/scripts/{README_DEPENDENCY_CHECKER.md → dependance_check/README_DEPENDENCY_CHECKER.md} +1 -0
- aiecs/scripts/dependance_check/__init__.py +15 -0
- aiecs/scripts/dependance_check/dependency_checker.py +1835 -0
- aiecs/scripts/{dependency_fixer.py → dependance_check/dependency_fixer.py} +192 -90
- aiecs/scripts/{download_nlp_data.py → dependance_check/download_nlp_data.py} +203 -71
- aiecs/scripts/dependance_patch/__init__.py +7 -0
- aiecs/scripts/dependance_patch/fix_weasel/__init__.py +11 -0
- aiecs/scripts/{fix_weasel_validator.py → dependance_patch/fix_weasel/fix_weasel_validator.py} +21 -14
- aiecs/scripts/{patch_weasel_library.sh → dependance_patch/fix_weasel/patch_weasel_library.sh} +1 -1
- aiecs/scripts/knowledge_graph/__init__.py +3 -0
- aiecs/scripts/knowledge_graph/run_threshold_experiments.py +212 -0
- aiecs/scripts/migrations/multi_tenancy/README.md +142 -0
- aiecs/scripts/tools_develop/README.md +671 -0
- aiecs/scripts/tools_develop/README_CONFIG_CHECKER.md +273 -0
- aiecs/scripts/tools_develop/TOOLS_CONFIG_GUIDE.md +1287 -0
- aiecs/scripts/tools_develop/TOOL_AUTO_DISCOVERY.md +234 -0
- aiecs/scripts/tools_develop/__init__.py +21 -0
- aiecs/scripts/tools_develop/check_all_tools_config.py +548 -0
- aiecs/scripts/tools_develop/check_type_annotations.py +257 -0
- aiecs/scripts/tools_develop/pre-commit-schema-coverage.sh +66 -0
- aiecs/scripts/tools_develop/schema_coverage.py +511 -0
- aiecs/scripts/tools_develop/validate_tool_schemas.py +475 -0
- aiecs/scripts/tools_develop/verify_executor_config_fix.py +98 -0
- aiecs/scripts/tools_develop/verify_tools.py +352 -0
- aiecs/tasks/__init__.py +0 -1
- aiecs/tasks/worker.py +115 -47
- aiecs/tools/__init__.py +194 -72
- aiecs/tools/apisource/__init__.py +99 -0
- aiecs/tools/apisource/intelligence/__init__.py +19 -0
- aiecs/tools/apisource/intelligence/data_fusion.py +632 -0
- aiecs/tools/apisource/intelligence/query_analyzer.py +417 -0
- aiecs/tools/apisource/intelligence/search_enhancer.py +385 -0
- aiecs/tools/apisource/monitoring/__init__.py +9 -0
- aiecs/tools/apisource/monitoring/metrics.py +330 -0
- aiecs/tools/apisource/providers/__init__.py +112 -0
- aiecs/tools/apisource/providers/base.py +671 -0
- aiecs/tools/apisource/providers/census.py +397 -0
- aiecs/tools/apisource/providers/fred.py +535 -0
- aiecs/tools/apisource/providers/newsapi.py +409 -0
- aiecs/tools/apisource/providers/worldbank.py +352 -0
- aiecs/tools/apisource/reliability/__init__.py +12 -0
- aiecs/tools/apisource/reliability/error_handler.py +363 -0
- aiecs/tools/apisource/reliability/fallback_strategy.py +376 -0
- aiecs/tools/apisource/tool.py +832 -0
- aiecs/tools/apisource/utils/__init__.py +9 -0
- aiecs/tools/apisource/utils/validators.py +334 -0
- aiecs/tools/base_tool.py +415 -21
- aiecs/tools/docs/__init__.py +121 -0
- aiecs/tools/docs/ai_document_orchestrator.py +607 -0
- aiecs/tools/docs/ai_document_writer_orchestrator.py +2350 -0
- aiecs/tools/docs/content_insertion_tool.py +1320 -0
- aiecs/tools/docs/document_creator_tool.py +1464 -0
- aiecs/tools/docs/document_layout_tool.py +1160 -0
- aiecs/tools/docs/document_parser_tool.py +1016 -0
- aiecs/tools/docs/document_writer_tool.py +2008 -0
- aiecs/tools/knowledge_graph/__init__.py +17 -0
- aiecs/tools/knowledge_graph/graph_reasoning_tool.py +807 -0
- aiecs/tools/knowledge_graph/graph_search_tool.py +944 -0
- aiecs/tools/knowledge_graph/kg_builder_tool.py +524 -0
- aiecs/tools/langchain_adapter.py +300 -138
- aiecs/tools/schema_generator.py +455 -0
- aiecs/tools/search_tool/__init__.py +100 -0
- aiecs/tools/search_tool/analyzers.py +581 -0
- aiecs/tools/search_tool/cache.py +264 -0
- aiecs/tools/search_tool/constants.py +128 -0
- aiecs/tools/search_tool/context.py +224 -0
- aiecs/tools/search_tool/core.py +778 -0
- aiecs/tools/search_tool/deduplicator.py +119 -0
- aiecs/tools/search_tool/error_handler.py +242 -0
- aiecs/tools/search_tool/metrics.py +343 -0
- aiecs/tools/search_tool/rate_limiter.py +172 -0
- aiecs/tools/search_tool/schemas.py +275 -0
- aiecs/tools/statistics/__init__.py +80 -0
- aiecs/tools/statistics/ai_data_analysis_orchestrator.py +646 -0
- aiecs/tools/statistics/ai_insight_generator_tool.py +508 -0
- aiecs/tools/statistics/ai_report_orchestrator_tool.py +684 -0
- aiecs/tools/statistics/data_loader_tool.py +555 -0
- aiecs/tools/statistics/data_profiler_tool.py +638 -0
- aiecs/tools/statistics/data_transformer_tool.py +580 -0
- aiecs/tools/statistics/data_visualizer_tool.py +498 -0
- aiecs/tools/statistics/model_trainer_tool.py +507 -0
- aiecs/tools/statistics/statistical_analyzer_tool.py +472 -0
- aiecs/tools/task_tools/__init__.py +49 -36
- aiecs/tools/task_tools/chart_tool.py +200 -184
- aiecs/tools/task_tools/classfire_tool.py +268 -267
- aiecs/tools/task_tools/image_tool.py +220 -141
- aiecs/tools/task_tools/office_tool.py +226 -146
- aiecs/tools/task_tools/pandas_tool.py +477 -121
- aiecs/tools/task_tools/report_tool.py +390 -142
- aiecs/tools/task_tools/research_tool.py +149 -79
- aiecs/tools/task_tools/scraper_tool.py +339 -145
- aiecs/tools/task_tools/stats_tool.py +448 -209
- aiecs/tools/temp_file_manager.py +26 -24
- aiecs/tools/tool_executor/__init__.py +18 -16
- aiecs/tools/tool_executor/tool_executor.py +364 -52
- aiecs/utils/LLM_output_structor.py +74 -48
- aiecs/utils/__init__.py +14 -3
- aiecs/utils/base_callback.py +0 -3
- aiecs/utils/cache_provider.py +696 -0
- aiecs/utils/execution_utils.py +50 -31
- aiecs/utils/prompt_loader.py +1 -0
- aiecs/utils/token_usage_repository.py +37 -11
- aiecs/ws/socket_server.py +14 -4
- {aiecs-1.0.1.dist-info → aiecs-1.7.17.dist-info}/METADATA +52 -15
- aiecs-1.7.17.dist-info/RECORD +337 -0
- aiecs-1.7.17.dist-info/entry_points.txt +13 -0
- aiecs/config/registry.py +0 -19
- aiecs/domain/context/content_engine.py +0 -982
- aiecs/llm/base_client.py +0 -99
- aiecs/llm/openai_client.py +0 -125
- aiecs/llm/vertex_client.py +0 -186
- aiecs/llm/xai_client.py +0 -184
- aiecs/scripts/dependency_checker.py +0 -857
- aiecs/scripts/quick_dependency_check.py +0 -269
- aiecs/tools/task_tools/search_api.py +0 -7
- aiecs-1.0.1.dist-info/RECORD +0 -90
- aiecs-1.0.1.dist-info/entry_points.txt +0 -7
- /aiecs/scripts/{setup_nlp_data.sh → dependance_check/setup_nlp_data.sh} +0 -0
- /aiecs/scripts/{README_WEASEL_PATCH.md → dependance_patch/fix_weasel/README_WEASEL_PATCH.md} +0 -0
- /aiecs/scripts/{fix_weasel_validator.sh → dependance_patch/fix_weasel/fix_weasel_validator.sh} +0 -0
- /aiecs/scripts/{run_weasel_patch.sh → dependance_patch/fix_weasel/run_weasel_patch.sh} +0 -0
- {aiecs-1.0.1.dist-info → aiecs-1.7.17.dist-info}/WHEEL +0 -0
- {aiecs-1.0.1.dist-info → aiecs-1.7.17.dist-info}/licenses/LICENSE +0 -0
- {aiecs-1.0.1.dist-info → aiecs-1.7.17.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1186 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import warnings
|
|
6
|
+
import hashlib
|
|
7
|
+
from typing import Dict, Any, Optional, List, AsyncGenerator, Union
|
|
8
|
+
import vertexai
|
|
9
|
+
from vertexai.generative_models import (
|
|
10
|
+
GenerativeModel,
|
|
11
|
+
HarmCategory,
|
|
12
|
+
HarmBlockThreshold,
|
|
13
|
+
GenerationConfig,
|
|
14
|
+
SafetySetting,
|
|
15
|
+
Content,
|
|
16
|
+
Part,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
# Try to import CachedContent for prompt caching support
|
|
22
|
+
# CachedContent API requires google-cloud-aiplatform >= 1.38.0
|
|
23
|
+
# Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/cached-content
|
|
24
|
+
CACHED_CONTENT_AVAILABLE = False
|
|
25
|
+
CACHED_CONTENT_IMPORT_PATH = None
|
|
26
|
+
CACHED_CONTENT_SDK_VERSION = None
|
|
27
|
+
|
|
28
|
+
# Check SDK version
|
|
29
|
+
try:
|
|
30
|
+
import google.cloud.aiplatform as aiplatform
|
|
31
|
+
CACHED_CONTENT_SDK_VERSION = getattr(aiplatform, '__version__', None)
|
|
32
|
+
except ImportError:
|
|
33
|
+
pass
|
|
34
|
+
|
|
35
|
+
# Try to import CachedContent for prompt caching support
|
|
36
|
+
try:
|
|
37
|
+
from vertexai.preview import caching
|
|
38
|
+
if hasattr(caching, 'CachedContent'):
|
|
39
|
+
CACHED_CONTENT_AVAILABLE = True
|
|
40
|
+
CACHED_CONTENT_IMPORT_PATH = 'vertexai.preview.caching'
|
|
41
|
+
else:
|
|
42
|
+
# Module exists but CachedContent class not found
|
|
43
|
+
CACHED_CONTENT_AVAILABLE = False
|
|
44
|
+
except ImportError:
|
|
45
|
+
try:
|
|
46
|
+
# Alternative import path for different SDK versions
|
|
47
|
+
from vertexai import caching
|
|
48
|
+
if hasattr(caching, 'CachedContent'):
|
|
49
|
+
CACHED_CONTENT_AVAILABLE = True
|
|
50
|
+
CACHED_CONTENT_IMPORT_PATH = 'vertexai.caching'
|
|
51
|
+
else:
|
|
52
|
+
CACHED_CONTENT_AVAILABLE = False
|
|
53
|
+
except ImportError:
|
|
54
|
+
CACHED_CONTENT_AVAILABLE = False
|
|
55
|
+
|
|
56
|
+
from aiecs.llm.clients.base_client import (
|
|
57
|
+
BaseLLMClient,
|
|
58
|
+
LLMMessage,
|
|
59
|
+
LLMResponse,
|
|
60
|
+
ProviderNotAvailableError,
|
|
61
|
+
RateLimitError,
|
|
62
|
+
SafetyBlockError,
|
|
63
|
+
)
|
|
64
|
+
from aiecs.llm.clients.google_function_calling_mixin import GoogleFunctionCallingMixin
|
|
65
|
+
from aiecs.config.config import get_settings
|
|
66
|
+
|
|
67
|
+
# Suppress Vertex AI SDK deprecation warnings (deprecated June 2025, removal June 2026)
|
|
68
|
+
# TODO: Migrate to Google Gen AI SDK when official migration guide is available
|
|
69
|
+
warnings.filterwarnings(
|
|
70
|
+
"ignore",
|
|
71
|
+
category=UserWarning,
|
|
72
|
+
module="vertexai.generative_models._generative_models",
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
logger = logging.getLogger(__name__)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _extract_safety_ratings(safety_ratings: Any) -> List[Dict[str, Any]]:
|
|
79
|
+
"""
|
|
80
|
+
Extract safety ratings information from Vertex AI response.
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
safety_ratings: Safety ratings object from Vertex AI response
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
List of dictionaries containing safety rating details
|
|
87
|
+
"""
|
|
88
|
+
ratings_list = []
|
|
89
|
+
if not safety_ratings:
|
|
90
|
+
return ratings_list
|
|
91
|
+
|
|
92
|
+
# Handle both list and single object
|
|
93
|
+
ratings_iter = safety_ratings if isinstance(safety_ratings, list) else [safety_ratings]
|
|
94
|
+
|
|
95
|
+
for rating in ratings_iter:
|
|
96
|
+
rating_dict = {}
|
|
97
|
+
|
|
98
|
+
# Extract category
|
|
99
|
+
if hasattr(rating, "category"):
|
|
100
|
+
rating_dict["category"] = str(rating.category)
|
|
101
|
+
elif hasattr(rating, "get"):
|
|
102
|
+
rating_dict["category"] = rating.get("category", "UNKNOWN")
|
|
103
|
+
|
|
104
|
+
# Extract blocked status
|
|
105
|
+
if hasattr(rating, "blocked"):
|
|
106
|
+
rating_dict["blocked"] = bool(rating.blocked)
|
|
107
|
+
elif hasattr(rating, "get"):
|
|
108
|
+
rating_dict["blocked"] = rating.get("blocked", False)
|
|
109
|
+
|
|
110
|
+
# Extract severity (for HarmBlockMethod.SEVERITY)
|
|
111
|
+
if hasattr(rating, "severity"):
|
|
112
|
+
rating_dict["severity"] = str(rating.severity)
|
|
113
|
+
elif hasattr(rating, "get"):
|
|
114
|
+
rating_dict["severity"] = rating.get("severity")
|
|
115
|
+
|
|
116
|
+
if hasattr(rating, "severity_score"):
|
|
117
|
+
rating_dict["severity_score"] = float(rating.severity_score)
|
|
118
|
+
elif hasattr(rating, "get"):
|
|
119
|
+
rating_dict["severity_score"] = rating.get("severity_score")
|
|
120
|
+
|
|
121
|
+
# Extract probability (for HarmBlockMethod.PROBABILITY)
|
|
122
|
+
if hasattr(rating, "probability"):
|
|
123
|
+
rating_dict["probability"] = str(rating.probability)
|
|
124
|
+
elif hasattr(rating, "get"):
|
|
125
|
+
rating_dict["probability"] = rating.get("probability")
|
|
126
|
+
|
|
127
|
+
if hasattr(rating, "probability_score"):
|
|
128
|
+
rating_dict["probability_score"] = float(rating.probability_score)
|
|
129
|
+
elif hasattr(rating, "get"):
|
|
130
|
+
rating_dict["probability_score"] = rating.get("probability_score")
|
|
131
|
+
|
|
132
|
+
ratings_list.append(rating_dict)
|
|
133
|
+
|
|
134
|
+
return ratings_list
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _build_safety_block_error(
|
|
138
|
+
response: Any,
|
|
139
|
+
block_type: str,
|
|
140
|
+
default_message: str,
|
|
141
|
+
) -> SafetyBlockError:
|
|
142
|
+
"""
|
|
143
|
+
Build a detailed SafetyBlockError from Vertex AI response.
|
|
144
|
+
|
|
145
|
+
Args:
|
|
146
|
+
response: Vertex AI response object
|
|
147
|
+
block_type: "prompt" or "response"
|
|
148
|
+
default_message: Default error message
|
|
149
|
+
|
|
150
|
+
Returns:
|
|
151
|
+
SafetyBlockError with detailed information
|
|
152
|
+
"""
|
|
153
|
+
block_reason = None
|
|
154
|
+
safety_ratings = []
|
|
155
|
+
|
|
156
|
+
if block_type == "prompt":
|
|
157
|
+
# Check prompt_feedback for prompt blocks
|
|
158
|
+
if hasattr(response, "prompt_feedback"):
|
|
159
|
+
pf = response.prompt_feedback
|
|
160
|
+
if hasattr(pf, "block_reason"):
|
|
161
|
+
block_reason = str(pf.block_reason)
|
|
162
|
+
elif hasattr(pf, "get"):
|
|
163
|
+
block_reason = pf.get("block_reason")
|
|
164
|
+
|
|
165
|
+
if hasattr(pf, "safety_ratings"):
|
|
166
|
+
safety_ratings = _extract_safety_ratings(pf.safety_ratings)
|
|
167
|
+
elif hasattr(pf, "get"):
|
|
168
|
+
safety_ratings = _extract_safety_ratings(pf.get("safety_ratings", []))
|
|
169
|
+
|
|
170
|
+
elif block_type == "response":
|
|
171
|
+
# Check candidates for response blocks
|
|
172
|
+
if hasattr(response, "candidates") and response.candidates:
|
|
173
|
+
candidate = response.candidates[0]
|
|
174
|
+
if hasattr(candidate, "safety_ratings"):
|
|
175
|
+
safety_ratings = _extract_safety_ratings(candidate.safety_ratings)
|
|
176
|
+
elif hasattr(candidate, "get"):
|
|
177
|
+
safety_ratings = _extract_safety_ratings(candidate.get("safety_ratings", []))
|
|
178
|
+
|
|
179
|
+
# Check finish_reason
|
|
180
|
+
if hasattr(candidate, "finish_reason"):
|
|
181
|
+
finish_reason = str(candidate.finish_reason)
|
|
182
|
+
if finish_reason in ["SAFETY", "RECITATION"]:
|
|
183
|
+
block_reason = finish_reason
|
|
184
|
+
|
|
185
|
+
# Build detailed error message
|
|
186
|
+
error_parts = [default_message]
|
|
187
|
+
if block_reason:
|
|
188
|
+
error_parts.append(f"Block reason: {block_reason}")
|
|
189
|
+
|
|
190
|
+
# Safely extract blocked categories, handling potential non-dict elements
|
|
191
|
+
blocked_categories = []
|
|
192
|
+
for r in safety_ratings:
|
|
193
|
+
if isinstance(r, dict) and r.get("blocked", False):
|
|
194
|
+
blocked_categories.append(r.get("category", "UNKNOWN"))
|
|
195
|
+
if blocked_categories:
|
|
196
|
+
error_parts.append(f"Blocked categories: {', '.join(blocked_categories)}")
|
|
197
|
+
|
|
198
|
+
# Add severity/probability information
|
|
199
|
+
for rating in safety_ratings:
|
|
200
|
+
# Skip non-dict elements
|
|
201
|
+
if not isinstance(rating, dict):
|
|
202
|
+
continue
|
|
203
|
+
if rating.get("blocked"):
|
|
204
|
+
if "severity" in rating:
|
|
205
|
+
error_parts.append(
|
|
206
|
+
f"{rating.get('category', 'UNKNOWN')}: severity={rating.get('severity')}, "
|
|
207
|
+
f"score={rating.get('severity_score', 'N/A')}"
|
|
208
|
+
)
|
|
209
|
+
elif "probability" in rating:
|
|
210
|
+
error_parts.append(
|
|
211
|
+
f"{rating.get('category', 'UNKNOWN')}: probability={rating.get('probability')}, "
|
|
212
|
+
f"score={rating.get('probability_score', 'N/A')}"
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
error_message = " | ".join(error_parts)
|
|
216
|
+
|
|
217
|
+
return SafetyBlockError(
|
|
218
|
+
message=error_message,
|
|
219
|
+
block_reason=block_reason,
|
|
220
|
+
block_type=block_type,
|
|
221
|
+
safety_ratings=safety_ratings,
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
class VertexAIClient(BaseLLMClient, GoogleFunctionCallingMixin):
|
|
226
|
+
"""Vertex AI provider client"""
|
|
227
|
+
|
|
228
|
+
def __init__(self):
|
|
229
|
+
super().__init__("Vertex")
|
|
230
|
+
self.settings = get_settings()
|
|
231
|
+
self._initialized = False
|
|
232
|
+
# Track part count statistics for monitoring
|
|
233
|
+
self._part_count_stats = {
|
|
234
|
+
"total_responses": 0,
|
|
235
|
+
"part_counts": {}, # {part_count: frequency}
|
|
236
|
+
"last_part_count": None,
|
|
237
|
+
}
|
|
238
|
+
# Cache for CachedContent objects (key: content hash, value: cached_content_id)
|
|
239
|
+
self._cached_content_cache: Dict[str, str] = {}
|
|
240
|
+
|
|
241
|
+
def _init_vertex_ai(self):
|
|
242
|
+
"""Lazy initialization of Vertex AI with proper authentication"""
|
|
243
|
+
if not self._initialized:
|
|
244
|
+
if not self.settings.vertex_project_id:
|
|
245
|
+
raise ProviderNotAvailableError("Vertex AI project ID not configured")
|
|
246
|
+
|
|
247
|
+
try:
|
|
248
|
+
# Set up Google Cloud authentication
|
|
249
|
+
pass
|
|
250
|
+
|
|
251
|
+
# Check if GOOGLE_APPLICATION_CREDENTIALS is configured
|
|
252
|
+
if self.settings.google_application_credentials:
|
|
253
|
+
credentials_path = self.settings.google_application_credentials
|
|
254
|
+
if os.path.exists(credentials_path):
|
|
255
|
+
# Set the environment variable for Google Cloud SDK
|
|
256
|
+
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_path
|
|
257
|
+
self.logger.info(f"Using Google Cloud credentials from: {credentials_path}")
|
|
258
|
+
else:
|
|
259
|
+
self.logger.warning(f"Google Cloud credentials file not found: {credentials_path}")
|
|
260
|
+
raise ProviderNotAvailableError(f"Google Cloud credentials file not found: {credentials_path}")
|
|
261
|
+
elif "GOOGLE_APPLICATION_CREDENTIALS" in os.environ:
|
|
262
|
+
self.logger.info("Using Google Cloud credentials from environment variable")
|
|
263
|
+
else:
|
|
264
|
+
self.logger.warning("No Google Cloud credentials configured. Using default authentication.")
|
|
265
|
+
|
|
266
|
+
# Initialize Vertex AI
|
|
267
|
+
vertexai.init(
|
|
268
|
+
project=self.settings.vertex_project_id,
|
|
269
|
+
location=getattr(self.settings, "vertex_location", "us-central1"),
|
|
270
|
+
)
|
|
271
|
+
self._initialized = True
|
|
272
|
+
self.logger.info(f"Vertex AI initialized for project {self.settings.vertex_project_id}")
|
|
273
|
+
|
|
274
|
+
except Exception as e:
|
|
275
|
+
raise ProviderNotAvailableError(f"Failed to initialize Vertex AI: {str(e)}")
|
|
276
|
+
|
|
277
|
+
def _generate_content_hash(self, content: str) -> str:
|
|
278
|
+
"""Generate a hash for content to use as cache key."""
|
|
279
|
+
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
|
280
|
+
|
|
281
|
+
async def _create_or_get_cached_content(
|
|
282
|
+
self,
|
|
283
|
+
content: str,
|
|
284
|
+
model_name: str,
|
|
285
|
+
ttl_seconds: Optional[int] = None,
|
|
286
|
+
) -> Optional[str]:
|
|
287
|
+
"""
|
|
288
|
+
Create or get a CachedContent for the given content.
|
|
289
|
+
|
|
290
|
+
This method implements Gemini's CachedContent API for prompt caching.
|
|
291
|
+
It preserves the existing cache_control mechanism for developer convenience.
|
|
292
|
+
|
|
293
|
+
The method supports multiple Vertex AI SDK versions and gracefully falls back
|
|
294
|
+
to regular system_instruction if CachedContent API is unavailable.
|
|
295
|
+
|
|
296
|
+
Args:
|
|
297
|
+
content: Content to cache (typically system instruction)
|
|
298
|
+
model_name: Model name to use for caching
|
|
299
|
+
ttl_seconds: Time to live in seconds (optional, defaults to 3600)
|
|
300
|
+
|
|
301
|
+
Returns:
|
|
302
|
+
CachedContent resource name (e.g., "projects/.../cachedContents/...") or None if caching unavailable
|
|
303
|
+
"""
|
|
304
|
+
if not CACHED_CONTENT_AVAILABLE:
|
|
305
|
+
# Provide version info if available
|
|
306
|
+
version_info = ""
|
|
307
|
+
if CACHED_CONTENT_SDK_VERSION:
|
|
308
|
+
version_info = f" (SDK version: {CACHED_CONTENT_SDK_VERSION})"
|
|
309
|
+
elif CACHED_CONTENT_IMPORT_PATH:
|
|
310
|
+
version_info = f" (import path '{CACHED_CONTENT_IMPORT_PATH}' available but CachedContent class not found)"
|
|
311
|
+
|
|
312
|
+
self.logger.debug(
|
|
313
|
+
f"CachedContent API not available{version_info}, skipping cache creation. "
|
|
314
|
+
f"Requires google-cloud-aiplatform >=1.38.0"
|
|
315
|
+
)
|
|
316
|
+
return None
|
|
317
|
+
|
|
318
|
+
if not content or not content.strip():
|
|
319
|
+
return None
|
|
320
|
+
|
|
321
|
+
# Generate cache key
|
|
322
|
+
cache_key = self._generate_content_hash(content)
|
|
323
|
+
|
|
324
|
+
# Check if we already have this cached
|
|
325
|
+
if cache_key in self._cached_content_cache:
|
|
326
|
+
cached_content_id = self._cached_content_cache[cache_key]
|
|
327
|
+
self.logger.debug(f"Using existing CachedContent: {cached_content_id}")
|
|
328
|
+
return cached_content_id
|
|
329
|
+
|
|
330
|
+
try:
|
|
331
|
+
self._init_vertex_ai()
|
|
332
|
+
|
|
333
|
+
# Build the content to cache (system instruction as Content)
|
|
334
|
+
# For CachedContent, we typically cache the system instruction
|
|
335
|
+
cached_content_obj = Content(
|
|
336
|
+
role="user",
|
|
337
|
+
parts=[Part.from_text(content)]
|
|
338
|
+
)
|
|
339
|
+
|
|
340
|
+
# Try different API patterns based on SDK version
|
|
341
|
+
cached_content_id = None
|
|
342
|
+
|
|
343
|
+
# Pattern 1: caching.CachedContent.create() (most common)
|
|
344
|
+
if hasattr(caching, 'CachedContent'):
|
|
345
|
+
try:
|
|
346
|
+
cached_content = await asyncio.get_event_loop().run_in_executor(
|
|
347
|
+
None,
|
|
348
|
+
lambda: caching.CachedContent.create(
|
|
349
|
+
model=model_name,
|
|
350
|
+
contents=[cached_content_obj],
|
|
351
|
+
ttl_seconds=ttl_seconds or 3600, # Default 1 hour
|
|
352
|
+
)
|
|
353
|
+
)
|
|
354
|
+
|
|
355
|
+
# Extract the resource name
|
|
356
|
+
if hasattr(cached_content, 'name'):
|
|
357
|
+
cached_content_id = cached_content.name
|
|
358
|
+
elif hasattr(cached_content, 'resource_name'):
|
|
359
|
+
cached_content_id = cached_content.resource_name
|
|
360
|
+
else:
|
|
361
|
+
cached_content_id = str(cached_content)
|
|
362
|
+
|
|
363
|
+
if cached_content_id:
|
|
364
|
+
# Store in cache
|
|
365
|
+
self._cached_content_cache[cache_key] = cached_content_id
|
|
366
|
+
self.logger.info(f"Created CachedContent for prompt caching: {cached_content_id}")
|
|
367
|
+
return cached_content_id
|
|
368
|
+
|
|
369
|
+
except AttributeError as e:
|
|
370
|
+
self.logger.debug(f"CachedContent.create() signature may differ: {str(e)}")
|
|
371
|
+
except Exception as e:
|
|
372
|
+
self.logger.debug(f"Failed to create CachedContent using pattern 1: {str(e)}")
|
|
373
|
+
|
|
374
|
+
# Pattern 2: Try alternative API patterns if Pattern 1 fails
|
|
375
|
+
# Note: Different SDK versions may have different APIs
|
|
376
|
+
# This is a fallback that allows graceful degradation
|
|
377
|
+
|
|
378
|
+
# Build informative warning message with version info
|
|
379
|
+
version_info = ""
|
|
380
|
+
if CACHED_CONTENT_SDK_VERSION:
|
|
381
|
+
version_info = f" Current SDK version: {CACHED_CONTENT_SDK_VERSION}."
|
|
382
|
+
else:
|
|
383
|
+
version_info = " Unable to detect SDK version."
|
|
384
|
+
|
|
385
|
+
required_version = ">=1.38.0"
|
|
386
|
+
upgrade_command = "pip install --upgrade 'google-cloud-aiplatform>=1.38.0'"
|
|
387
|
+
|
|
388
|
+
self.logger.warning(
|
|
389
|
+
f"CachedContent API not available or incompatible with current SDK version.{version_info} "
|
|
390
|
+
f"Falling back to system_instruction (prompt caching disabled). "
|
|
391
|
+
f"To enable prompt caching, upgrade to google-cloud-aiplatform {required_version} or later: "
|
|
392
|
+
f"{upgrade_command}"
|
|
393
|
+
)
|
|
394
|
+
return None
|
|
395
|
+
|
|
396
|
+
except Exception as e:
|
|
397
|
+
self.logger.warning(
|
|
398
|
+
f"Failed to create CachedContent (prompt caching disabled, using system_instruction): {str(e)}"
|
|
399
|
+
)
|
|
400
|
+
# Don't raise - allow fallback to regular generation without caching
|
|
401
|
+
return None
|
|
402
|
+
|
|
403
|
+
def _convert_messages_to_contents(
|
|
404
|
+
self, messages: List[LLMMessage]
|
|
405
|
+
) -> List[Content]:
|
|
406
|
+
"""
|
|
407
|
+
Convert LLMMessage list to Vertex AI Content objects.
|
|
408
|
+
|
|
409
|
+
This properly handles multi-turn conversations including
|
|
410
|
+
function/tool responses for Vertex AI Function Calling.
|
|
411
|
+
|
|
412
|
+
Args:
|
|
413
|
+
messages: List of LLMMessage objects (system messages should be filtered out)
|
|
414
|
+
|
|
415
|
+
Returns:
|
|
416
|
+
List of Content objects for Vertex AI API
|
|
417
|
+
"""
|
|
418
|
+
contents = []
|
|
419
|
+
|
|
420
|
+
for msg in messages:
|
|
421
|
+
# Handle tool/function responses (role="tool")
|
|
422
|
+
if msg.role == "tool":
|
|
423
|
+
# Vertex AI expects function responses as user messages with FunctionResponse parts
|
|
424
|
+
# The tool_call_id maps to the function name
|
|
425
|
+
func_name = msg.tool_call_id or "unknown_function"
|
|
426
|
+
|
|
427
|
+
# Parse content as the function response
|
|
428
|
+
try:
|
|
429
|
+
# Try to parse as JSON if it looks like JSON
|
|
430
|
+
if msg.content and msg.content.strip().startswith('{'):
|
|
431
|
+
response_data = json.loads(msg.content)
|
|
432
|
+
else:
|
|
433
|
+
response_data = {"result": msg.content}
|
|
434
|
+
except json.JSONDecodeError:
|
|
435
|
+
response_data = {"result": msg.content}
|
|
436
|
+
|
|
437
|
+
# Create FunctionResponse part using Part.from_function_response
|
|
438
|
+
func_response_part = Part.from_function_response(
|
|
439
|
+
name=func_name,
|
|
440
|
+
response=response_data
|
|
441
|
+
)
|
|
442
|
+
|
|
443
|
+
contents.append(Content(
|
|
444
|
+
role="user", # Function responses are sent as "user" role in Vertex AI
|
|
445
|
+
parts=[func_response_part]
|
|
446
|
+
))
|
|
447
|
+
|
|
448
|
+
# Handle assistant messages with tool calls
|
|
449
|
+
elif msg.role == "assistant" and msg.tool_calls:
|
|
450
|
+
parts = []
|
|
451
|
+
if msg.content:
|
|
452
|
+
parts.append(Part.from_text(msg.content))
|
|
453
|
+
|
|
454
|
+
for tool_call in msg.tool_calls:
|
|
455
|
+
func = tool_call.get("function", {})
|
|
456
|
+
func_name = func.get("name", "")
|
|
457
|
+
func_args = func.get("arguments", "{}")
|
|
458
|
+
|
|
459
|
+
# Parse arguments
|
|
460
|
+
try:
|
|
461
|
+
args_dict = json.loads(func_args) if isinstance(func_args, str) else func_args
|
|
462
|
+
except json.JSONDecodeError:
|
|
463
|
+
args_dict = {}
|
|
464
|
+
|
|
465
|
+
# Create FunctionCall part using Part.from_dict
|
|
466
|
+
# Note: Part.from_function_call() does NOT exist in Vertex AI SDK
|
|
467
|
+
# Must use from_dict with function_call structure
|
|
468
|
+
function_call_part = Part.from_dict({
|
|
469
|
+
"function_call": {
|
|
470
|
+
"name": func_name,
|
|
471
|
+
"args": args_dict
|
|
472
|
+
}
|
|
473
|
+
})
|
|
474
|
+
parts.append(function_call_part)
|
|
475
|
+
|
|
476
|
+
contents.append(Content(
|
|
477
|
+
role="model",
|
|
478
|
+
parts=parts
|
|
479
|
+
))
|
|
480
|
+
|
|
481
|
+
# Handle regular messages (user, assistant without tool_calls)
|
|
482
|
+
else:
|
|
483
|
+
role = "model" if msg.role == "assistant" else msg.role
|
|
484
|
+
if msg.content:
|
|
485
|
+
contents.append(Content(
|
|
486
|
+
role=role,
|
|
487
|
+
parts=[Part.from_text(msg.content)]
|
|
488
|
+
))
|
|
489
|
+
|
|
490
|
+
return contents
|
|
491
|
+
|
|
492
|
+
async def generate_text(
|
|
493
|
+
self,
|
|
494
|
+
messages: List[LLMMessage],
|
|
495
|
+
model: Optional[str] = None,
|
|
496
|
+
temperature: float = 0.7,
|
|
497
|
+
max_tokens: Optional[int] = None,
|
|
498
|
+
functions: Optional[List[Dict[str, Any]]] = None,
|
|
499
|
+
tools: Optional[List[Dict[str, Any]]] = None,
|
|
500
|
+
tool_choice: Optional[Any] = None,
|
|
501
|
+
system_instruction: Optional[str] = None,
|
|
502
|
+
**kwargs,
|
|
503
|
+
) -> LLMResponse:
|
|
504
|
+
"""Generate text using Vertex AI"""
|
|
505
|
+
self._init_vertex_ai()
|
|
506
|
+
|
|
507
|
+
# Get model name from config if not provided
|
|
508
|
+
model_name = model or self._get_default_model() or "gemini-2.5-pro"
|
|
509
|
+
|
|
510
|
+
# Get model config for default parameters
|
|
511
|
+
model_config = self._get_model_config(model_name)
|
|
512
|
+
if model_config and max_tokens is None:
|
|
513
|
+
max_tokens = model_config.default_params.max_tokens
|
|
514
|
+
|
|
515
|
+
try:
|
|
516
|
+
# Extract system message from messages if present
|
|
517
|
+
system_msg = None
|
|
518
|
+
system_cache_control = None
|
|
519
|
+
user_messages = []
|
|
520
|
+
for msg in messages:
|
|
521
|
+
if msg.role == "system":
|
|
522
|
+
system_msg = msg.content
|
|
523
|
+
system_cache_control = msg.cache_control
|
|
524
|
+
else:
|
|
525
|
+
user_messages.append(msg)
|
|
526
|
+
|
|
527
|
+
# Use explicit system_instruction parameter if provided, else use extracted system message
|
|
528
|
+
final_system_instruction = system_instruction or system_msg
|
|
529
|
+
|
|
530
|
+
# Check if we should use CachedContent API for prompt caching
|
|
531
|
+
cached_content_id = None
|
|
532
|
+
if final_system_instruction and system_cache_control:
|
|
533
|
+
# Create or get CachedContent for the system instruction
|
|
534
|
+
# Extract TTL from cache_control if available (defaults to 3600 seconds)
|
|
535
|
+
ttl_seconds = getattr(system_cache_control, 'ttl_seconds', None) or 3600
|
|
536
|
+
cached_content_id = await self._create_or_get_cached_content(
|
|
537
|
+
content=final_system_instruction,
|
|
538
|
+
model_name=model_name,
|
|
539
|
+
ttl_seconds=ttl_seconds,
|
|
540
|
+
)
|
|
541
|
+
if cached_content_id:
|
|
542
|
+
self.logger.debug(f"Using CachedContent for prompt caching: {cached_content_id}")
|
|
543
|
+
# When using CachedContent, we don't pass system_instruction to GenerativeModel
|
|
544
|
+
# Instead, we'll pass cached_content_id to generate_content
|
|
545
|
+
final_system_instruction = None
|
|
546
|
+
|
|
547
|
+
# Initialize model WITH system instruction for prompt caching support
|
|
548
|
+
# Note: If using CachedContent, system_instruction will be None
|
|
549
|
+
model_instance = GenerativeModel(
|
|
550
|
+
model_name,
|
|
551
|
+
system_instruction=final_system_instruction
|
|
552
|
+
)
|
|
553
|
+
self.logger.debug(f"Initialized Vertex AI model: {model_name}")
|
|
554
|
+
|
|
555
|
+
# Convert messages to Vertex AI format
|
|
556
|
+
if len(user_messages) == 1 and user_messages[0].role == "user":
|
|
557
|
+
contents = user_messages[0].content
|
|
558
|
+
else:
|
|
559
|
+
# For multi-turn conversations, use proper Content objects
|
|
560
|
+
contents = self._convert_messages_to_contents(user_messages)
|
|
561
|
+
|
|
562
|
+
# Use modern GenerationConfig object
|
|
563
|
+
generation_config = GenerationConfig(
|
|
564
|
+
temperature=temperature,
|
|
565
|
+
# Increased to account for thinking tokens
|
|
566
|
+
max_output_tokens=max_tokens or 8192,
|
|
567
|
+
top_p=0.95,
|
|
568
|
+
top_k=40,
|
|
569
|
+
)
|
|
570
|
+
|
|
571
|
+
# Modern safety settings configuration using SafetySetting objects
|
|
572
|
+
# Allow override via kwargs, otherwise use defaults (BLOCK_NONE for all categories)
|
|
573
|
+
if "safety_settings" in kwargs:
|
|
574
|
+
safety_settings = kwargs["safety_settings"]
|
|
575
|
+
if not isinstance(safety_settings, list):
|
|
576
|
+
raise ValueError("safety_settings must be a list of SafetySetting objects")
|
|
577
|
+
else:
|
|
578
|
+
# Default safety settings - can be configured via environment or config
|
|
579
|
+
# Default to BLOCK_NONE to allow all content (can be overridden)
|
|
580
|
+
safety_settings = [
|
|
581
|
+
SafetySetting(
|
|
582
|
+
category=HarmCategory.HARM_CATEGORY_HARASSMENT,
|
|
583
|
+
threshold=HarmBlockThreshold.BLOCK_NONE,
|
|
584
|
+
),
|
|
585
|
+
SafetySetting(
|
|
586
|
+
category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
|
587
|
+
threshold=HarmBlockThreshold.BLOCK_NONE,
|
|
588
|
+
),
|
|
589
|
+
SafetySetting(
|
|
590
|
+
category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
|
591
|
+
threshold=HarmBlockThreshold.BLOCK_NONE,
|
|
592
|
+
),
|
|
593
|
+
SafetySetting(
|
|
594
|
+
category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
|
595
|
+
threshold=HarmBlockThreshold.BLOCK_NONE,
|
|
596
|
+
),
|
|
597
|
+
]
|
|
598
|
+
|
|
599
|
+
# Prepare tools for Function Calling if provided
|
|
600
|
+
tools_for_api = None
|
|
601
|
+
if tools or functions:
|
|
602
|
+
# Convert OpenAI format to Google format
|
|
603
|
+
tools_list = tools or []
|
|
604
|
+
if functions:
|
|
605
|
+
# Convert legacy functions format to tools format
|
|
606
|
+
tools_list = [{"type": "function", "function": func} for func in functions]
|
|
607
|
+
|
|
608
|
+
google_tools = self._convert_openai_to_google_format(tools_list)
|
|
609
|
+
if google_tools:
|
|
610
|
+
tools_for_api = google_tools
|
|
611
|
+
|
|
612
|
+
# Build API call parameters
|
|
613
|
+
api_params = {
|
|
614
|
+
"contents": contents,
|
|
615
|
+
"generation_config": generation_config,
|
|
616
|
+
"safety_settings": safety_settings,
|
|
617
|
+
}
|
|
618
|
+
|
|
619
|
+
# Add cached_content if using CachedContent API for prompt caching
|
|
620
|
+
if cached_content_id:
|
|
621
|
+
api_params["cached_content"] = cached_content_id
|
|
622
|
+
self.logger.debug(f"Added cached_content to API params: {cached_content_id}")
|
|
623
|
+
|
|
624
|
+
# Add tools if available
|
|
625
|
+
if tools_for_api:
|
|
626
|
+
api_params["tools"] = tools_for_api
|
|
627
|
+
|
|
628
|
+
# Add any additional kwargs (but exclude tools/safety_settings/cached_content to avoid conflicts)
|
|
629
|
+
for key, value in kwargs.items():
|
|
630
|
+
if key not in ["tools", "safety_settings", "cached_content"]:
|
|
631
|
+
api_params[key] = value
|
|
632
|
+
|
|
633
|
+
response = await asyncio.get_event_loop().run_in_executor(
|
|
634
|
+
None,
|
|
635
|
+
lambda: model_instance.generate_content(**api_params),
|
|
636
|
+
)
|
|
637
|
+
|
|
638
|
+
# Check for prompt-level safety blocks first
|
|
639
|
+
if hasattr(response, "prompt_feedback"):
|
|
640
|
+
pf = response.prompt_feedback
|
|
641
|
+
# Check if prompt was blocked
|
|
642
|
+
if hasattr(pf, "block_reason") and pf.block_reason:
|
|
643
|
+
block_reason = str(pf.block_reason)
|
|
644
|
+
if block_reason not in ["BLOCKED_REASON_UNSPECIFIED", "OTHER"]:
|
|
645
|
+
# Prompt was blocked by safety filters
|
|
646
|
+
raise _build_safety_block_error(
|
|
647
|
+
response,
|
|
648
|
+
block_type="prompt",
|
|
649
|
+
default_message="Prompt blocked by safety filters",
|
|
650
|
+
)
|
|
651
|
+
elif hasattr(pf, "get") and pf.get("block_reason"):
|
|
652
|
+
block_reason = pf.get("block_reason")
|
|
653
|
+
if block_reason not in ["BLOCKED_REASON_UNSPECIFIED", "OTHER"]:
|
|
654
|
+
raise _build_safety_block_error(
|
|
655
|
+
response,
|
|
656
|
+
block_type="prompt",
|
|
657
|
+
default_message="Prompt blocked by safety filters",
|
|
658
|
+
)
|
|
659
|
+
|
|
660
|
+
# Handle response content safely - improved multi-part response
|
|
661
|
+
# handling
|
|
662
|
+
content = None
|
|
663
|
+
try:
|
|
664
|
+
# First try to get text directly
|
|
665
|
+
content = response.text
|
|
666
|
+
self.logger.debug(f"Vertex AI response received: {content[:100]}...")
|
|
667
|
+
except (ValueError, AttributeError) as ve:
|
|
668
|
+
# Handle multi-part responses and other issues
|
|
669
|
+
self.logger.warning(f"Cannot get response text directly: {str(ve)}")
|
|
670
|
+
|
|
671
|
+
# Try to extract content from candidates with multi-part
|
|
672
|
+
# support
|
|
673
|
+
if hasattr(response, "candidates") and response.candidates:
|
|
674
|
+
candidate = response.candidates[0]
|
|
675
|
+
self.logger.debug(f"Candidate finish_reason: {getattr(candidate, 'finish_reason', 'unknown')}")
|
|
676
|
+
|
|
677
|
+
# Handle multi-part content
|
|
678
|
+
if hasattr(candidate, "content") and hasattr(candidate.content, "parts"):
|
|
679
|
+
try:
|
|
680
|
+
# Extract text from all parts
|
|
681
|
+
text_parts: List[str] = []
|
|
682
|
+
for part in candidate.content.parts:
|
|
683
|
+
if hasattr(part, "text") and part.text:
|
|
684
|
+
text_parts.append(str(part.text))
|
|
685
|
+
|
|
686
|
+
if text_parts:
|
|
687
|
+
# Log part count for monitoring
|
|
688
|
+
part_count = len(text_parts)
|
|
689
|
+
self.logger.info(f"📊 Vertex AI response: {part_count} parts detected")
|
|
690
|
+
|
|
691
|
+
# Update statistics
|
|
692
|
+
self._part_count_stats["total_responses"] += 1
|
|
693
|
+
self._part_count_stats["part_counts"][part_count] = self._part_count_stats["part_counts"].get(part_count, 0) + 1
|
|
694
|
+
self._part_count_stats["last_part_count"] = part_count
|
|
695
|
+
|
|
696
|
+
# Log statistics if significant variation
|
|
697
|
+
# detected
|
|
698
|
+
if part_count != self._part_count_stats.get("last_part_count", part_count):
|
|
699
|
+
self.logger.warning(f"⚠️ Part count variation detected: {part_count} parts (previous: {self._part_count_stats.get('last_part_count', 'unknown')})")
|
|
700
|
+
|
|
701
|
+
# Handle multi-part response format
|
|
702
|
+
if len(text_parts) > 1:
|
|
703
|
+
# Multi-part response
|
|
704
|
+
# Minimal fix: only fix incomplete <thinking> tags, preserve original order
|
|
705
|
+
# Do NOT reorganize content - let
|
|
706
|
+
# downstream code handle semantics
|
|
707
|
+
|
|
708
|
+
processed_parts = []
|
|
709
|
+
fixed_count = 0
|
|
710
|
+
|
|
711
|
+
for i, part_raw in enumerate(text_parts):
|
|
712
|
+
# Check for thinking content that needs
|
|
713
|
+
# formatting
|
|
714
|
+
needs_thinking_format = False
|
|
715
|
+
# Ensure part is a string (use different name to avoid redefinition)
|
|
716
|
+
part_str: str = str(part_raw) if not isinstance(part_raw, str) else part_raw
|
|
717
|
+
|
|
718
|
+
if "<thinking>" in part_str and "</thinking>" not in part_str: # type: ignore[operator]
|
|
719
|
+
# Incomplete <thinking> tag: add
|
|
720
|
+
# closing tag
|
|
721
|
+
part_str = part_str + "\n</thinking>" # type: ignore[operator]
|
|
722
|
+
needs_thinking_format = True
|
|
723
|
+
self.logger.debug(f" Part {i+1}: Incomplete <thinking> tag fixed")
|
|
724
|
+
elif isinstance(part_str, str) and part_str.startswith("thinking") and "</thinking>" not in part_str: # type: ignore[operator]
|
|
725
|
+
# thinking\n format: convert to
|
|
726
|
+
# <thinking>...</thinking>
|
|
727
|
+
if part_str.startswith("thinking\n"):
|
|
728
|
+
# thinking\n格式:提取内容并包装
|
|
729
|
+
# 跳过 "thinking\n"
|
|
730
|
+
content = part_str[8:]
|
|
731
|
+
else:
|
|
732
|
+
# thinking开头但无换行:提取内容并包装
|
|
733
|
+
# 跳过 "thinking"
|
|
734
|
+
content = part_str[7:]
|
|
735
|
+
|
|
736
|
+
part_str = f"<thinking>\n{content}\n</thinking>"
|
|
737
|
+
needs_thinking_format = True
|
|
738
|
+
self.logger.debug(f" Part {i+1}: thinking\\n format converted to <thinking> tags")
|
|
739
|
+
|
|
740
|
+
if needs_thinking_format:
|
|
741
|
+
fixed_count += 1
|
|
742
|
+
|
|
743
|
+
processed_parts.append(part_str)
|
|
744
|
+
|
|
745
|
+
# Merge in original order
|
|
746
|
+
content = "\n".join(processed_parts)
|
|
747
|
+
|
|
748
|
+
if fixed_count > 0:
|
|
749
|
+
self.logger.info(f"✅ Multi-part response merged: {len(text_parts)} parts, {fixed_count} incomplete tags fixed, order preserved")
|
|
750
|
+
else:
|
|
751
|
+
self.logger.info(f"✅ Multi-part response merged: {len(text_parts)} parts, order preserved")
|
|
752
|
+
else:
|
|
753
|
+
# Single part response - use as is
|
|
754
|
+
content = text_parts[0]
|
|
755
|
+
self.logger.info("Successfully extracted single-part response")
|
|
756
|
+
else:
|
|
757
|
+
self.logger.warning("No text content found in multi-part response")
|
|
758
|
+
except Exception as part_error:
|
|
759
|
+
self.logger.error(f"Failed to extract content from multi-part response: {str(part_error)}")
|
|
760
|
+
|
|
761
|
+
# If still no content, check finish reason
|
|
762
|
+
if not content:
|
|
763
|
+
if hasattr(candidate, "finish_reason"):
|
|
764
|
+
if candidate.finish_reason == "MAX_TOKENS":
|
|
765
|
+
content = "[Response truncated due to token limit - consider increasing max_tokens for Gemini 2.5 models]"
|
|
766
|
+
self.logger.warning("Response truncated due to MAX_TOKENS - Gemini 2.5 uses thinking tokens")
|
|
767
|
+
elif candidate.finish_reason in [
|
|
768
|
+
"SAFETY",
|
|
769
|
+
"RECITATION",
|
|
770
|
+
]:
|
|
771
|
+
# Response was blocked by safety filters
|
|
772
|
+
raise _build_safety_block_error(
|
|
773
|
+
response,
|
|
774
|
+
block_type="response",
|
|
775
|
+
default_message="Response blocked by safety filters",
|
|
776
|
+
)
|
|
777
|
+
else:
|
|
778
|
+
content = f"[Response error: Cannot get response text - {candidate.finish_reason}]"
|
|
779
|
+
else:
|
|
780
|
+
content = "[Response error: Cannot get the response text]"
|
|
781
|
+
else:
|
|
782
|
+
# No candidates found - check if this is due to safety filters
|
|
783
|
+
# Check prompt_feedback for block reason
|
|
784
|
+
if hasattr(response, "prompt_feedback"):
|
|
785
|
+
pf = response.prompt_feedback
|
|
786
|
+
if hasattr(pf, "block_reason") and pf.block_reason:
|
|
787
|
+
block_reason = str(pf.block_reason)
|
|
788
|
+
if block_reason not in ["BLOCKED_REASON_UNSPECIFIED", "OTHER"]:
|
|
789
|
+
raise _build_safety_block_error(
|
|
790
|
+
response,
|
|
791
|
+
block_type="prompt",
|
|
792
|
+
default_message="No candidates found - prompt blocked by safety filters",
|
|
793
|
+
)
|
|
794
|
+
elif hasattr(pf, "get") and pf.get("block_reason"):
|
|
795
|
+
block_reason = pf.get("block_reason")
|
|
796
|
+
if block_reason not in ["BLOCKED_REASON_UNSPECIFIED", "OTHER"]:
|
|
797
|
+
raise _build_safety_block_error(
|
|
798
|
+
response,
|
|
799
|
+
block_type="prompt",
|
|
800
|
+
default_message="No candidates found - prompt blocked by safety filters",
|
|
801
|
+
)
|
|
802
|
+
|
|
803
|
+
# If not a safety block, raise generic error with details
|
|
804
|
+
error_msg = f"Response error: No candidates found - Response has no candidates (and thus no text)."
|
|
805
|
+
if hasattr(response, "prompt_feedback"):
|
|
806
|
+
error_msg += " Check prompt_feedback for details."
|
|
807
|
+
raise ValueError(error_msg)
|
|
808
|
+
|
|
809
|
+
# Final fallback
|
|
810
|
+
if not content:
|
|
811
|
+
content = "[Response error: Cannot get the response text. Multiple content parts are not supported.]"
|
|
812
|
+
|
|
813
|
+
# Vertex AI doesn't provide detailed token usage in the response
|
|
814
|
+
# Use estimation method as fallback
|
|
815
|
+
# Estimate input tokens from messages content
|
|
816
|
+
prompt_text = " ".join(msg.content for msg in messages if msg.content)
|
|
817
|
+
input_tokens = self._count_tokens_estimate(prompt_text)
|
|
818
|
+
output_tokens = self._count_tokens_estimate(content)
|
|
819
|
+
tokens_used = input_tokens + output_tokens
|
|
820
|
+
|
|
821
|
+
# Extract cache metadata from Vertex AI response if available
|
|
822
|
+
cache_read_tokens = None
|
|
823
|
+
cache_hit = None
|
|
824
|
+
if hasattr(response, "usage_metadata"):
|
|
825
|
+
usage = response.usage_metadata
|
|
826
|
+
if hasattr(usage, "cached_content_token_count"):
|
|
827
|
+
cache_read_tokens = usage.cached_content_token_count
|
|
828
|
+
cache_hit = cache_read_tokens is not None and cache_read_tokens > 0
|
|
829
|
+
|
|
830
|
+
# Use config-based cost estimation
|
|
831
|
+
cost = self._estimate_cost_from_config(model_name, input_tokens, output_tokens)
|
|
832
|
+
|
|
833
|
+
# Extract function calls from response if present
|
|
834
|
+
function_calls = self._extract_function_calls_from_google_response(response)
|
|
835
|
+
|
|
836
|
+
llm_response = LLMResponse(
|
|
837
|
+
content=content,
|
|
838
|
+
provider=self.provider_name,
|
|
839
|
+
model=model_name,
|
|
840
|
+
tokens_used=tokens_used,
|
|
841
|
+
prompt_tokens=input_tokens, # Estimated value since Vertex AI doesn't provide detailed usage
|
|
842
|
+
completion_tokens=output_tokens, # Estimated value since Vertex AI doesn't provide detailed usage
|
|
843
|
+
cost_estimate=cost,
|
|
844
|
+
cache_read_tokens=cache_read_tokens,
|
|
845
|
+
cache_hit=cache_hit,
|
|
846
|
+
)
|
|
847
|
+
|
|
848
|
+
# Attach function call info if present
|
|
849
|
+
if function_calls:
|
|
850
|
+
self._attach_function_calls_to_response(llm_response, function_calls)
|
|
851
|
+
|
|
852
|
+
return llm_response
|
|
853
|
+
|
|
854
|
+
except SafetyBlockError:
|
|
855
|
+
# Re-raise safety block errors as-is (they already contain detailed information)
|
|
856
|
+
raise
|
|
857
|
+
except Exception as e:
|
|
858
|
+
if "quota" in str(e).lower() or "limit" in str(e).lower():
|
|
859
|
+
raise RateLimitError(f"Vertex AI quota exceeded: {str(e)}")
|
|
860
|
+
# Handle specific Vertex AI response errors
|
|
861
|
+
if any(
|
|
862
|
+
keyword in str(e).lower()
|
|
863
|
+
for keyword in [
|
|
864
|
+
"cannot get the response text",
|
|
865
|
+
"safety filters",
|
|
866
|
+
"multiple content parts are not supported",
|
|
867
|
+
"cannot get the candidate text",
|
|
868
|
+
]
|
|
869
|
+
):
|
|
870
|
+
self.logger.warning(f"Vertex AI response issue: {str(e)}")
|
|
871
|
+
# Return a response indicating the issue
|
|
872
|
+
# Estimate prompt tokens from messages content
|
|
873
|
+
prompt_text = " ".join(msg.content for msg in messages if msg.content)
|
|
874
|
+
estimated_prompt_tokens = self._count_tokens_estimate(prompt_text)
|
|
875
|
+
return LLMResponse(
|
|
876
|
+
content="[Response unavailable due to content processing issues or safety filters]",
|
|
877
|
+
provider=self.provider_name,
|
|
878
|
+
model=model_name,
|
|
879
|
+
tokens_used=estimated_prompt_tokens,
|
|
880
|
+
prompt_tokens=estimated_prompt_tokens,
|
|
881
|
+
completion_tokens=0,
|
|
882
|
+
cost_estimate=0.0,
|
|
883
|
+
)
|
|
884
|
+
raise
|
|
885
|
+
|
|
886
|
+
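The error handling above distinguishes three outcomes: SafetyBlockError is re-raised with its original detail, quota/limit failures become RateLimitError, and known content-extraction issues are downgraded to a placeholder LLMResponse. A minimal caller-side sketch of how these surface; the client variable, the generate_text method name, and the import paths are assumptions for illustration, not taken from this diff:

    import asyncio

    from aiecs.llm import LLMMessage, SafetyBlockError, RateLimitError  # import path assumed for illustration

    async def ask(client, prompt: str) -> str:
        messages = [LLMMessage(role="user", content=prompt)]
        try:
            resp = await client.generate_text(messages)  # hypothetical method name on the client
            return resp.content
        except SafetyBlockError as exc:
            # Safety blocks already carry block_type and the blocked reason
            return f"[blocked: {exc}]"
        except RateLimitError:
            # Quota/limit errors are retryable; back off briefly and try once more
            await asyncio.sleep(5)
            resp = await client.generate_text(messages)
            return resp.content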
+    async def stream_text(  # type: ignore[override]
+        self,
+        messages: List[LLMMessage],
+        model: Optional[str] = None,
+        temperature: float = 0.7,
+        max_tokens: Optional[int] = None,
+        functions: Optional[List[Dict[str, Any]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        tool_choice: Optional[Any] = None,
+        return_chunks: bool = False,
+        system_instruction: Optional[str] = None,
+        **kwargs,
+    ) -> AsyncGenerator[Any, None]:
+        """
+        Stream text using Vertex AI real streaming API with Function Calling support.
+
+        Args:
+            messages: List of LLM messages
+            model: Model name (optional)
+            temperature: Temperature for generation
+            max_tokens: Maximum tokens to generate
+            functions: List of function schemas (legacy format)
+            tools: List of tool schemas (new format)
+            tool_choice: Tool choice strategy (not used for Google Vertex AI)
+            return_chunks: If True, returns GoogleStreamChunk objects; if False, returns str tokens only
+            system_instruction: System instruction for prompt caching support
+            **kwargs: Additional arguments
+
+        Yields:
+            str or GoogleStreamChunk: Text tokens or StreamChunk objects
+        """
+        self._init_vertex_ai()
+
+        # Get model name from config if not provided
+        model_name = model or self._get_default_model() or "gemini-2.5-pro"
+
+        # Get model config for default parameters
+        model_config = self._get_model_config(model_name)
+        if model_config and max_tokens is None:
+            max_tokens = model_config.default_params.max_tokens
+
+        try:
+            # Extract system message from messages if present
+            system_msg = None
+            system_cache_control = None
+            user_messages = []
+            for msg in messages:
+                if msg.role == "system":
+                    system_msg = msg.content
+                    system_cache_control = msg.cache_control
+                else:
+                    user_messages.append(msg)
+
+            # Use explicit system_instruction parameter if provided, else use extracted system message
+            final_system_instruction = system_instruction or system_msg
+
+            # Check if we should use CachedContent API for prompt caching
+            cached_content_id = None
+            if final_system_instruction and system_cache_control:
+                # Create or get CachedContent for the system instruction
+                # Extract TTL from cache_control if available (defaults to 3600 seconds)
+                ttl_seconds = getattr(system_cache_control, 'ttl_seconds', None) or 3600
+                cached_content_id = await self._create_or_get_cached_content(
+                    content=final_system_instruction,
+                    model_name=model_name,
+                    ttl_seconds=ttl_seconds,
+                )
+                if cached_content_id:
+                    self.logger.debug(f"Using CachedContent for prompt caching in streaming: {cached_content_id}")
+                    # When using CachedContent, we don't pass system_instruction to GenerativeModel
+                    # Instead, we'll pass cached_content_id to generate_content
+                    final_system_instruction = None
+
+            # Initialize model WITH system instruction for prompt caching support
+            # Note: If using CachedContent, system_instruction will be None
+            model_instance = GenerativeModel(
+                model_name,
+                system_instruction=final_system_instruction
+            )
+            self.logger.debug(f"Initialized Vertex AI model for streaming: {model_name}")
+
+            # Convert messages to Vertex AI format
+            if len(user_messages) == 1 and user_messages[0].role == "user":
+                contents = user_messages[0].content
+            else:
+                # For multi-turn conversations, use proper Content objects
+                contents = self._convert_messages_to_contents(user_messages)
+
+            # Use modern GenerationConfig object
+            generation_config = GenerationConfig(
+                temperature=temperature,
+                max_output_tokens=max_tokens or 8192,
+                top_p=0.95,
+                top_k=40,
+            )
+
+            # Get safety settings from kwargs or use defaults
+            if "safety_settings" in kwargs:
+                safety_settings = kwargs["safety_settings"]
+                if not isinstance(safety_settings, list):
+                    raise ValueError("safety_settings must be a list of SafetySetting objects")
+            else:
+                # Default safety settings
+                safety_settings = [
+                    SafetySetting(
+                        category=HarmCategory.HARM_CATEGORY_HARASSMENT,
+                        threshold=HarmBlockThreshold.BLOCK_NONE,
+                    ),
+                    SafetySetting(
+                        category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+                        threshold=HarmBlockThreshold.BLOCK_NONE,
+                    ),
+                    SafetySetting(
+                        category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+                        threshold=HarmBlockThreshold.BLOCK_NONE,
+                    ),
+                    SafetySetting(
+                        category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                        threshold=HarmBlockThreshold.BLOCK_NONE,
+                    ),
+                ]
+
+            # Prepare tools for Function Calling if provided
+            tools_for_api = None
+            if tools or functions:
+                # Convert OpenAI format to Google format
+                tools_list = tools or []
+                if functions:
+                    # Convert legacy functions format to tools format
+                    tools_list = [{"type": "function", "function": func} for func in functions]
+
+                google_tools = self._convert_openai_to_google_format(tools_list)
+                if google_tools:
+                    tools_for_api = google_tools
+
+            # Use mixin method for Function Calling support
+            from aiecs.llm.clients.openai_compatible_mixin import StreamChunk
+
+            # Add cached_content to kwargs if using CachedContent API
+            stream_kwargs = kwargs.copy()
+            if cached_content_id:
+                stream_kwargs["cached_content"] = cached_content_id
+                self.logger.debug(f"Added cached_content to streaming API params: {cached_content_id}")
+
+            async for chunk in self._stream_text_with_function_calling(
+                model_instance=model_instance,
+                contents=contents,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                tools=tools_for_api,
+                return_chunks=return_chunks,
+                **stream_kwargs,
+            ):
+                # Yield chunk (can be str or StreamChunk)
+                yield chunk
+
+        except SafetyBlockError:
+            # Re-raise safety block errors as-is
+            raise
+        except Exception as e:
+            if "quota" in str(e).lower() or "limit" in str(e).lower():
+                raise RateLimitError(f"Vertex AI quota exceeded: {str(e)}")
+            self.logger.error(f"Error in Vertex AI streaming: {str(e)}")
+            raise
+
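stream_text yields plain str tokens by default and chunk objects when return_chunks=True, so a caller chooses between simple text accumulation and richer chunk inspection. A short consumption sketch; the client variable and the LLMMessage import path are assumptions for illustration:

    from aiecs.llm import LLMMessage  # import path assumed for illustration

    async def stream_answer(client, prompt: str) -> str:
        # Default mode: the generator yields plain text tokens
        messages = [LLMMessage(role="user", content=prompt)]
        pieces = []
        async for token in client.stream_text(messages, temperature=0.2):
            pieces.append(token)
        return "".join(pieces)

    async def stream_with_chunks(client, prompt: str) -> None:
        # return_chunks=True: the generator yields chunk objects instead of str
        messages = [LLMMessage(role="user", content=prompt)]
        async for chunk in client.stream_text(messages, return_chunks=True):
            print(type(chunk).__name__, getattr(chunk, "content", chunk))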
+    def get_part_count_stats(self) -> Dict[str, Any]:
+        """
+        Get statistics about part count variations in Vertex AI responses.
+
+        Returns:
+            Dictionary containing part count statistics and analysis
+        """
+        stats = self._part_count_stats.copy()
+
+        if stats["total_responses"] > 0:
+            # Calculate variation metrics
+            part_counts = list(stats["part_counts"].keys())
+            stats["variation_analysis"] = {
+                "unique_part_counts": len(part_counts),
+                "most_common_count": (max(stats["part_counts"].items(), key=lambda x: x[1])[0] if stats["part_counts"] else None),
+                "part_count_range": (f"{min(part_counts)}-{max(part_counts)}" if part_counts else "N/A"),
+                # 0-1, higher is more stable
+                "stability_score": 1.0 - (len(part_counts) - 1) / max(stats["total_responses"], 1),
+            }
+
+            # Generate recommendations
+            if stats["variation_analysis"]["stability_score"] < 0.7:
+                stats["recommendations"] = [
+                    "High part count variation detected",
+                    "Consider optimizing prompt structure",
+                    "Monitor input complexity patterns",
+                    "Review tool calling configuration",
+                ]
+            else:
+                stats["recommendations"] = [
+                    "Part count variation is within acceptable range",
+                    "Continue monitoring for patterns",
+                ]
+
+        return stats
+
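The stability_score above reduces to 1 - (unique_part_counts - 1) / max(total_responses, 1): a single repeated part count scores 1.0 and every additional distinct count lowers the score. A small worked check with made-up numbers:

    # 10 responses observed with part counts {1: 7, 2: 2, 3: 1} -> 3 distinct counts
    total_responses = 10
    unique_part_counts = 3
    stability_score = 1.0 - (unique_part_counts - 1) / max(total_responses, 1)
    assert abs(stability_score - 0.8) < 1e-9  # 0.8 >= 0.7, so the "acceptable range" branch applies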
+    def log_part_count_summary(self):
+        """Log a summary of part count statistics"""
+        stats = self.get_part_count_stats()
+
+        if stats["total_responses"] > 0:
+            self.logger.info("📈 Vertex AI Part Count Summary:")
+            self.logger.info(f" Total responses: {stats['total_responses']}")
+            self.logger.info(f" Part count distribution: {stats['part_counts']}")
+
+            if "variation_analysis" in stats:
+                analysis = stats["variation_analysis"]
+                self.logger.info(f" Stability score: {analysis['stability_score']:.2f}")
+                self.logger.info(f" Most common count: {analysis['most_common_count']}")
+                self.logger.info(f" Count range: {analysis['part_count_range']}")
+
+            if "recommendations" in stats:
+                self.logger.info(" Recommendations:")
+                for rec in stats["recommendations"]:
+                    self.logger.info(f" • {rec}")
+
+    async def get_embeddings(
+        self,
+        texts: List[str],
+        model: Optional[str] = None,
+    ) -> List[List[float]]:
+        """
+        Generate embeddings using Vertex AI embedding model
+
+        Args:
+            texts: List of texts to embed
+            model: Embedding model name (default: gemini-embedding-001)
+
+        Returns:
+            List of embedding vectors (each is a list of floats)
+        """
+        self._init_vertex_ai()
+
+        # Use gemini-embedding-001 as default
+        embedding_model_name = model or "gemini-embedding-001"
+
+        try:
+            from google.cloud import aiplatform
+            from google.cloud.aiplatform.gapic import PredictionServiceClient
+            from google.protobuf import struct_pb2
+
+            # Initialize prediction client
+            location = getattr(self.settings, "vertex_location", "us-central1")
+            endpoint = f"{location}-aiplatform.googleapis.com"
+            client = PredictionServiceClient(client_options={"api_endpoint": endpoint})
+
+            # Model resource name
+            model_resource = f"projects/{self.settings.vertex_project_id}/locations/{location}/publishers/google/models/{embedding_model_name}"
+
+            # Generate embeddings for each text
+            embeddings = []
+            for text in texts:
+                # Prepare instance
+                instance = struct_pb2.Struct()
+                instance.fields["content"].string_value = text
+
+                # Make prediction request
+                response = await asyncio.get_event_loop().run_in_executor(
+                    None,
+                    lambda: client.predict(
+                        endpoint=model_resource,
+                        instances=[instance]
+                    )
+                )
+
+                # Extract embedding
+                if response.predictions and len(response.predictions) > 0:
+                    prediction = response.predictions[0]
+                    if "embeddings" in prediction and "values" in prediction["embeddings"]:
+                        embedding = list(prediction["embeddings"]["values"])
+                        embeddings.append(embedding)
+                    else:
+                        self.logger.warning(f"Unexpected response format for embedding: {prediction}")
+                        # Return zero vector as fallback
+                        embeddings.append([0.0] * 768)
+                else:
+                    self.logger.warning("No predictions returned from embedding model")
+                    embeddings.append([0.0] * 768)
+
+            return embeddings
+
+        except ImportError as e:
+            self.logger.error(f"Required Vertex AI libraries not available: {e}")
+            # Return zero vectors as fallback
+            return [[0.0] * 768 for _ in texts]
+        except Exception as e:
+            self.logger.error(f"Error generating embeddings with Vertex AI: {e}")
+            # Return zero vectors as fallback
+            return [[0.0] * 768 for _ in texts]
+
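get_embeddings returns one vector per input text and falls back to 768-dimensional zero vectors when the call or imports fail, so downstream code should treat all-zero vectors as missing embeddings. A hedged cosine-similarity sketch over the returned vectors; the client variable is an assumption for illustration:

    import math

    async def similarity(client, a: str, b: str) -> float:
        vec_a, vec_b = await client.get_embeddings([a, b])
        # Zero-norm vectors mean the embedding call fell back; avoid division by zero
        norm_a = math.sqrt(sum(x * x for x in vec_a))
        norm_b = math.sqrt(sum(x * x for x in vec_b))
        if norm_a == 0.0 or norm_b == 0.0:
            return 0.0
        dot = sum(x * y for x, y in zip(vec_a, vec_b))
        return dot / (norm_a * norm_b)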
+    async def close(self):
+        """Clean up resources"""
+        # Log final statistics before cleanup
+        self.log_part_count_summary()
+        # Vertex AI doesn't require explicit cleanup
+        self._initialized = False