aiecs 1.0.1__py3-none-any.whl → 1.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiecs might be problematic.
- aiecs/__init__.py +13 -16
- aiecs/__main__.py +7 -7
- aiecs/aiecs_client.py +269 -75
- aiecs/application/executors/operation_executor.py +79 -54
- aiecs/application/knowledge_graph/__init__.py +7 -0
- aiecs/application/knowledge_graph/builder/__init__.py +37 -0
- aiecs/application/knowledge_graph/builder/data_quality.py +302 -0
- aiecs/application/knowledge_graph/builder/data_reshaping.py +293 -0
- aiecs/application/knowledge_graph/builder/document_builder.py +369 -0
- aiecs/application/knowledge_graph/builder/graph_builder.py +490 -0
- aiecs/application/knowledge_graph/builder/import_optimizer.py +396 -0
- aiecs/application/knowledge_graph/builder/schema_inference.py +462 -0
- aiecs/application/knowledge_graph/builder/schema_mapping.py +563 -0
- aiecs/application/knowledge_graph/builder/structured_pipeline.py +1384 -0
- aiecs/application/knowledge_graph/builder/text_chunker.py +317 -0
- aiecs/application/knowledge_graph/extractors/__init__.py +27 -0
- aiecs/application/knowledge_graph/extractors/base.py +98 -0
- aiecs/application/knowledge_graph/extractors/llm_entity_extractor.py +422 -0
- aiecs/application/knowledge_graph/extractors/llm_relation_extractor.py +347 -0
- aiecs/application/knowledge_graph/extractors/ner_entity_extractor.py +241 -0
- aiecs/application/knowledge_graph/fusion/__init__.py +78 -0
- aiecs/application/knowledge_graph/fusion/ab_testing.py +395 -0
- aiecs/application/knowledge_graph/fusion/abbreviation_expander.py +327 -0
- aiecs/application/knowledge_graph/fusion/alias_index.py +597 -0
- aiecs/application/knowledge_graph/fusion/alias_matcher.py +384 -0
- aiecs/application/knowledge_graph/fusion/cache_coordinator.py +343 -0
- aiecs/application/knowledge_graph/fusion/entity_deduplicator.py +433 -0
- aiecs/application/knowledge_graph/fusion/entity_linker.py +511 -0
- aiecs/application/knowledge_graph/fusion/evaluation_dataset.py +240 -0
- aiecs/application/knowledge_graph/fusion/knowledge_fusion.py +632 -0
- aiecs/application/knowledge_graph/fusion/matching_config.py +489 -0
- aiecs/application/knowledge_graph/fusion/name_normalizer.py +352 -0
- aiecs/application/knowledge_graph/fusion/relation_deduplicator.py +183 -0
- aiecs/application/knowledge_graph/fusion/semantic_name_matcher.py +464 -0
- aiecs/application/knowledge_graph/fusion/similarity_pipeline.py +534 -0
- aiecs/application/knowledge_graph/pattern_matching/__init__.py +21 -0
- aiecs/application/knowledge_graph/pattern_matching/pattern_matcher.py +342 -0
- aiecs/application/knowledge_graph/pattern_matching/query_executor.py +366 -0
- aiecs/application/knowledge_graph/profiling/__init__.py +12 -0
- aiecs/application/knowledge_graph/profiling/query_plan_visualizer.py +195 -0
- aiecs/application/knowledge_graph/profiling/query_profiler.py +223 -0
- aiecs/application/knowledge_graph/reasoning/__init__.py +27 -0
- aiecs/application/knowledge_graph/reasoning/evidence_synthesis.py +341 -0
- aiecs/application/knowledge_graph/reasoning/inference_engine.py +500 -0
- aiecs/application/knowledge_graph/reasoning/logic_form_parser.py +163 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/__init__.py +79 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_builder.py +513 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_nodes.py +913 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_validator.py +866 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/error_handler.py +475 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/parser.py +396 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/query_context.py +208 -0
- aiecs/application/knowledge_graph/reasoning/logic_query_integration.py +170 -0
- aiecs/application/knowledge_graph/reasoning/query_planner.py +855 -0
- aiecs/application/knowledge_graph/reasoning/reasoning_engine.py +518 -0
- aiecs/application/knowledge_graph/retrieval/__init__.py +27 -0
- aiecs/application/knowledge_graph/retrieval/query_intent_classifier.py +211 -0
- aiecs/application/knowledge_graph/retrieval/retrieval_strategies.py +592 -0
- aiecs/application/knowledge_graph/retrieval/strategy_types.py +23 -0
- aiecs/application/knowledge_graph/search/__init__.py +59 -0
- aiecs/application/knowledge_graph/search/hybrid_search.py +457 -0
- aiecs/application/knowledge_graph/search/reranker.py +293 -0
- aiecs/application/knowledge_graph/search/reranker_strategies.py +535 -0
- aiecs/application/knowledge_graph/search/text_similarity.py +392 -0
- aiecs/application/knowledge_graph/traversal/__init__.py +15 -0
- aiecs/application/knowledge_graph/traversal/enhanced_traversal.py +305 -0
- aiecs/application/knowledge_graph/traversal/path_scorer.py +271 -0
- aiecs/application/knowledge_graph/validators/__init__.py +13 -0
- aiecs/application/knowledge_graph/validators/relation_validator.py +239 -0
- aiecs/application/knowledge_graph/visualization/__init__.py +11 -0
- aiecs/application/knowledge_graph/visualization/graph_visualizer.py +313 -0
- aiecs/common/__init__.py +9 -0
- aiecs/common/knowledge_graph/__init__.py +17 -0
- aiecs/common/knowledge_graph/runnable.py +471 -0
- aiecs/config/__init__.py +20 -5
- aiecs/config/config.py +762 -31
- aiecs/config/graph_config.py +131 -0
- aiecs/config/tool_config.py +399 -0
- aiecs/core/__init__.py +29 -13
- aiecs/core/interface/__init__.py +2 -2
- aiecs/core/interface/execution_interface.py +22 -22
- aiecs/core/interface/storage_interface.py +37 -88
- aiecs/core/registry/__init__.py +31 -0
- aiecs/core/registry/service_registry.py +92 -0
- aiecs/domain/__init__.py +270 -1
- aiecs/domain/agent/__init__.py +191 -0
- aiecs/domain/agent/base_agent.py +3870 -0
- aiecs/domain/agent/exceptions.py +99 -0
- aiecs/domain/agent/graph_aware_mixin.py +569 -0
- aiecs/domain/agent/hybrid_agent.py +1435 -0
- aiecs/domain/agent/integration/__init__.py +29 -0
- aiecs/domain/agent/integration/context_compressor.py +216 -0
- aiecs/domain/agent/integration/context_engine_adapter.py +587 -0
- aiecs/domain/agent/integration/protocols.py +281 -0
- aiecs/domain/agent/integration/retry_policy.py +218 -0
- aiecs/domain/agent/integration/role_config.py +213 -0
- aiecs/domain/agent/knowledge_aware_agent.py +1892 -0
- aiecs/domain/agent/lifecycle.py +291 -0
- aiecs/domain/agent/llm_agent.py +692 -0
- aiecs/domain/agent/memory/__init__.py +12 -0
- aiecs/domain/agent/memory/conversation.py +1124 -0
- aiecs/domain/agent/migration/__init__.py +14 -0
- aiecs/domain/agent/migration/conversion.py +163 -0
- aiecs/domain/agent/migration/legacy_wrapper.py +86 -0
- aiecs/domain/agent/models.py +884 -0
- aiecs/domain/agent/observability.py +479 -0
- aiecs/domain/agent/persistence.py +449 -0
- aiecs/domain/agent/prompts/__init__.py +29 -0
- aiecs/domain/agent/prompts/builder.py +159 -0
- aiecs/domain/agent/prompts/formatters.py +187 -0
- aiecs/domain/agent/prompts/template.py +255 -0
- aiecs/domain/agent/registry.py +253 -0
- aiecs/domain/agent/tool_agent.py +444 -0
- aiecs/domain/agent/tools/__init__.py +15 -0
- aiecs/domain/agent/tools/schema_generator.py +364 -0
- aiecs/domain/community/__init__.py +155 -0
- aiecs/domain/community/agent_adapter.py +469 -0
- aiecs/domain/community/analytics.py +432 -0
- aiecs/domain/community/collaborative_workflow.py +648 -0
- aiecs/domain/community/communication_hub.py +634 -0
- aiecs/domain/community/community_builder.py +320 -0
- aiecs/domain/community/community_integration.py +796 -0
- aiecs/domain/community/community_manager.py +803 -0
- aiecs/domain/community/decision_engine.py +849 -0
- aiecs/domain/community/exceptions.py +231 -0
- aiecs/domain/community/models/__init__.py +33 -0
- aiecs/domain/community/models/community_models.py +234 -0
- aiecs/domain/community/resource_manager.py +461 -0
- aiecs/domain/community/shared_context_manager.py +589 -0
- aiecs/domain/context/__init__.py +40 -10
- aiecs/domain/context/context_engine.py +1910 -0
- aiecs/domain/context/conversation_models.py +87 -53
- aiecs/domain/context/graph_memory.py +582 -0
- aiecs/domain/execution/model.py +12 -4
- aiecs/domain/knowledge_graph/__init__.py +19 -0
- aiecs/domain/knowledge_graph/models/__init__.py +52 -0
- aiecs/domain/knowledge_graph/models/entity.py +148 -0
- aiecs/domain/knowledge_graph/models/evidence.py +178 -0
- aiecs/domain/knowledge_graph/models/inference_rule.py +184 -0
- aiecs/domain/knowledge_graph/models/path.py +171 -0
- aiecs/domain/knowledge_graph/models/path_pattern.py +171 -0
- aiecs/domain/knowledge_graph/models/query.py +261 -0
- aiecs/domain/knowledge_graph/models/query_plan.py +181 -0
- aiecs/domain/knowledge_graph/models/relation.py +202 -0
- aiecs/domain/knowledge_graph/schema/__init__.py +23 -0
- aiecs/domain/knowledge_graph/schema/entity_type.py +131 -0
- aiecs/domain/knowledge_graph/schema/graph_schema.py +253 -0
- aiecs/domain/knowledge_graph/schema/property_schema.py +143 -0
- aiecs/domain/knowledge_graph/schema/relation_type.py +163 -0
- aiecs/domain/knowledge_graph/schema/schema_manager.py +691 -0
- aiecs/domain/knowledge_graph/schema/type_enums.py +209 -0
- aiecs/domain/task/dsl_processor.py +172 -56
- aiecs/domain/task/model.py +20 -8
- aiecs/domain/task/task_context.py +27 -24
- aiecs/infrastructure/__init__.py +0 -2
- aiecs/infrastructure/graph_storage/__init__.py +11 -0
- aiecs/infrastructure/graph_storage/base.py +837 -0
- aiecs/infrastructure/graph_storage/batch_operations.py +458 -0
- aiecs/infrastructure/graph_storage/cache.py +424 -0
- aiecs/infrastructure/graph_storage/distributed.py +223 -0
- aiecs/infrastructure/graph_storage/error_handling.py +380 -0
- aiecs/infrastructure/graph_storage/graceful_degradation.py +294 -0
- aiecs/infrastructure/graph_storage/health_checks.py +378 -0
- aiecs/infrastructure/graph_storage/in_memory.py +1197 -0
- aiecs/infrastructure/graph_storage/index_optimization.py +446 -0
- aiecs/infrastructure/graph_storage/lazy_loading.py +431 -0
- aiecs/infrastructure/graph_storage/metrics.py +344 -0
- aiecs/infrastructure/graph_storage/migration.py +400 -0
- aiecs/infrastructure/graph_storage/pagination.py +483 -0
- aiecs/infrastructure/graph_storage/performance_monitoring.py +456 -0
- aiecs/infrastructure/graph_storage/postgres.py +1563 -0
- aiecs/infrastructure/graph_storage/property_storage.py +353 -0
- aiecs/infrastructure/graph_storage/protocols.py +76 -0
- aiecs/infrastructure/graph_storage/query_optimizer.py +642 -0
- aiecs/infrastructure/graph_storage/schema_cache.py +290 -0
- aiecs/infrastructure/graph_storage/sqlite.py +1373 -0
- aiecs/infrastructure/graph_storage/streaming.py +487 -0
- aiecs/infrastructure/graph_storage/tenant.py +412 -0
- aiecs/infrastructure/messaging/celery_task_manager.py +92 -54
- aiecs/infrastructure/messaging/websocket_manager.py +51 -35
- aiecs/infrastructure/monitoring/__init__.py +22 -0
- aiecs/infrastructure/monitoring/executor_metrics.py +45 -11
- aiecs/infrastructure/monitoring/global_metrics_manager.py +212 -0
- aiecs/infrastructure/monitoring/structured_logger.py +3 -7
- aiecs/infrastructure/monitoring/tracing_manager.py +63 -35
- aiecs/infrastructure/persistence/__init__.py +14 -1
- aiecs/infrastructure/persistence/context_engine_client.py +184 -0
- aiecs/infrastructure/persistence/database_manager.py +67 -43
- aiecs/infrastructure/persistence/file_storage.py +180 -103
- aiecs/infrastructure/persistence/redis_client.py +74 -21
- aiecs/llm/__init__.py +73 -25
- aiecs/llm/callbacks/__init__.py +11 -0
- aiecs/llm/{custom_callbacks.py → callbacks/custom_callbacks.py} +26 -19
- aiecs/llm/client_factory.py +224 -36
- aiecs/llm/client_resolver.py +155 -0
- aiecs/llm/clients/__init__.py +38 -0
- aiecs/llm/clients/base_client.py +324 -0
- aiecs/llm/clients/google_function_calling_mixin.py +457 -0
- aiecs/llm/clients/googleai_client.py +241 -0
- aiecs/llm/clients/openai_client.py +158 -0
- aiecs/llm/clients/openai_compatible_mixin.py +367 -0
- aiecs/llm/clients/vertex_client.py +897 -0
- aiecs/llm/clients/xai_client.py +201 -0
- aiecs/llm/config/__init__.py +51 -0
- aiecs/llm/config/config_loader.py +272 -0
- aiecs/llm/config/config_validator.py +206 -0
- aiecs/llm/config/model_config.py +143 -0
- aiecs/llm/protocols.py +149 -0
- aiecs/llm/utils/__init__.py +10 -0
- aiecs/llm/utils/validate_config.py +89 -0
- aiecs/main.py +140 -121
- aiecs/scripts/aid/VERSION_MANAGEMENT.md +138 -0
- aiecs/scripts/aid/__init__.py +19 -0
- aiecs/scripts/aid/module_checker.py +499 -0
- aiecs/scripts/aid/version_manager.py +235 -0
- aiecs/scripts/{DEPENDENCY_SYSTEM_SUMMARY.md → dependance_check/DEPENDENCY_SYSTEM_SUMMARY.md} +1 -0
- aiecs/scripts/{README_DEPENDENCY_CHECKER.md → dependance_check/README_DEPENDENCY_CHECKER.md} +1 -0
- aiecs/scripts/dependance_check/__init__.py +15 -0
- aiecs/scripts/dependance_check/dependency_checker.py +1835 -0
- aiecs/scripts/{dependency_fixer.py → dependance_check/dependency_fixer.py} +192 -90
- aiecs/scripts/{download_nlp_data.py → dependance_check/download_nlp_data.py} +203 -71
- aiecs/scripts/dependance_patch/__init__.py +7 -0
- aiecs/scripts/dependance_patch/fix_weasel/__init__.py +11 -0
- aiecs/scripts/{fix_weasel_validator.py → dependance_patch/fix_weasel/fix_weasel_validator.py} +21 -14
- aiecs/scripts/{patch_weasel_library.sh → dependance_patch/fix_weasel/patch_weasel_library.sh} +1 -1
- aiecs/scripts/knowledge_graph/__init__.py +3 -0
- aiecs/scripts/knowledge_graph/run_threshold_experiments.py +212 -0
- aiecs/scripts/migrations/multi_tenancy/README.md +142 -0
- aiecs/scripts/tools_develop/README.md +671 -0
- aiecs/scripts/tools_develop/README_CONFIG_CHECKER.md +273 -0
- aiecs/scripts/tools_develop/TOOLS_CONFIG_GUIDE.md +1287 -0
- aiecs/scripts/tools_develop/TOOL_AUTO_DISCOVERY.md +234 -0
- aiecs/scripts/tools_develop/__init__.py +21 -0
- aiecs/scripts/tools_develop/check_all_tools_config.py +548 -0
- aiecs/scripts/tools_develop/check_type_annotations.py +257 -0
- aiecs/scripts/tools_develop/pre-commit-schema-coverage.sh +66 -0
- aiecs/scripts/tools_develop/schema_coverage.py +511 -0
- aiecs/scripts/tools_develop/validate_tool_schemas.py +475 -0
- aiecs/scripts/tools_develop/verify_executor_config_fix.py +98 -0
- aiecs/scripts/tools_develop/verify_tools.py +352 -0
- aiecs/tasks/__init__.py +0 -1
- aiecs/tasks/worker.py +115 -47
- aiecs/tools/__init__.py +194 -72
- aiecs/tools/apisource/__init__.py +99 -0
- aiecs/tools/apisource/intelligence/__init__.py +19 -0
- aiecs/tools/apisource/intelligence/data_fusion.py +632 -0
- aiecs/tools/apisource/intelligence/query_analyzer.py +417 -0
- aiecs/tools/apisource/intelligence/search_enhancer.py +385 -0
- aiecs/tools/apisource/monitoring/__init__.py +9 -0
- aiecs/tools/apisource/monitoring/metrics.py +330 -0
- aiecs/tools/apisource/providers/__init__.py +112 -0
- aiecs/tools/apisource/providers/base.py +671 -0
- aiecs/tools/apisource/providers/census.py +397 -0
- aiecs/tools/apisource/providers/fred.py +535 -0
- aiecs/tools/apisource/providers/newsapi.py +409 -0
- aiecs/tools/apisource/providers/worldbank.py +352 -0
- aiecs/tools/apisource/reliability/__init__.py +12 -0
- aiecs/tools/apisource/reliability/error_handler.py +363 -0
- aiecs/tools/apisource/reliability/fallback_strategy.py +376 -0
- aiecs/tools/apisource/tool.py +832 -0
- aiecs/tools/apisource/utils/__init__.py +9 -0
- aiecs/tools/apisource/utils/validators.py +334 -0
- aiecs/tools/base_tool.py +415 -21
- aiecs/tools/docs/__init__.py +121 -0
- aiecs/tools/docs/ai_document_orchestrator.py +607 -0
- aiecs/tools/docs/ai_document_writer_orchestrator.py +2350 -0
- aiecs/tools/docs/content_insertion_tool.py +1320 -0
- aiecs/tools/docs/document_creator_tool.py +1323 -0
- aiecs/tools/docs/document_layout_tool.py +1160 -0
- aiecs/tools/docs/document_parser_tool.py +1011 -0
- aiecs/tools/docs/document_writer_tool.py +1829 -0
- aiecs/tools/knowledge_graph/__init__.py +17 -0
- aiecs/tools/knowledge_graph/graph_reasoning_tool.py +807 -0
- aiecs/tools/knowledge_graph/graph_search_tool.py +944 -0
- aiecs/tools/knowledge_graph/kg_builder_tool.py +524 -0
- aiecs/tools/langchain_adapter.py +300 -138
- aiecs/tools/schema_generator.py +455 -0
- aiecs/tools/search_tool/__init__.py +100 -0
- aiecs/tools/search_tool/analyzers.py +581 -0
- aiecs/tools/search_tool/cache.py +264 -0
- aiecs/tools/search_tool/constants.py +128 -0
- aiecs/tools/search_tool/context.py +224 -0
- aiecs/tools/search_tool/core.py +778 -0
- aiecs/tools/search_tool/deduplicator.py +119 -0
- aiecs/tools/search_tool/error_handler.py +242 -0
- aiecs/tools/search_tool/metrics.py +343 -0
- aiecs/tools/search_tool/rate_limiter.py +172 -0
- aiecs/tools/search_tool/schemas.py +275 -0
- aiecs/tools/statistics/__init__.py +80 -0
- aiecs/tools/statistics/ai_data_analysis_orchestrator.py +646 -0
- aiecs/tools/statistics/ai_insight_generator_tool.py +508 -0
- aiecs/tools/statistics/ai_report_orchestrator_tool.py +684 -0
- aiecs/tools/statistics/data_loader_tool.py +555 -0
- aiecs/tools/statistics/data_profiler_tool.py +638 -0
- aiecs/tools/statistics/data_transformer_tool.py +580 -0
- aiecs/tools/statistics/data_visualizer_tool.py +498 -0
- aiecs/tools/statistics/model_trainer_tool.py +507 -0
- aiecs/tools/statistics/statistical_analyzer_tool.py +472 -0
- aiecs/tools/task_tools/__init__.py +49 -36
- aiecs/tools/task_tools/chart_tool.py +200 -184
- aiecs/tools/task_tools/classfire_tool.py +268 -267
- aiecs/tools/task_tools/image_tool.py +175 -131
- aiecs/tools/task_tools/office_tool.py +226 -146
- aiecs/tools/task_tools/pandas_tool.py +477 -121
- aiecs/tools/task_tools/report_tool.py +390 -142
- aiecs/tools/task_tools/research_tool.py +149 -79
- aiecs/tools/task_tools/scraper_tool.py +339 -145
- aiecs/tools/task_tools/stats_tool.py +448 -209
- aiecs/tools/temp_file_manager.py +26 -24
- aiecs/tools/tool_executor/__init__.py +18 -16
- aiecs/tools/tool_executor/tool_executor.py +364 -52
- aiecs/utils/LLM_output_structor.py +74 -48
- aiecs/utils/__init__.py +14 -3
- aiecs/utils/base_callback.py +0 -3
- aiecs/utils/cache_provider.py +696 -0
- aiecs/utils/execution_utils.py +50 -31
- aiecs/utils/prompt_loader.py +1 -0
- aiecs/utils/token_usage_repository.py +37 -11
- aiecs/ws/socket_server.py +14 -4
- {aiecs-1.0.1.dist-info → aiecs-1.7.6.dist-info}/METADATA +52 -15
- aiecs-1.7.6.dist-info/RECORD +337 -0
- aiecs-1.7.6.dist-info/entry_points.txt +13 -0
- aiecs/config/registry.py +0 -19
- aiecs/domain/context/content_engine.py +0 -982
- aiecs/llm/base_client.py +0 -99
- aiecs/llm/openai_client.py +0 -125
- aiecs/llm/vertex_client.py +0 -186
- aiecs/llm/xai_client.py +0 -184
- aiecs/scripts/dependency_checker.py +0 -857
- aiecs/scripts/quick_dependency_check.py +0 -269
- aiecs/tools/task_tools/search_api.py +0 -7
- aiecs-1.0.1.dist-info/RECORD +0 -90
- aiecs-1.0.1.dist-info/entry_points.txt +0 -7
- /aiecs/scripts/{setup_nlp_data.sh → dependance_check/setup_nlp_data.sh} +0 -0
- /aiecs/scripts/{README_WEASEL_PATCH.md → dependance_patch/fix_weasel/README_WEASEL_PATCH.md} +0 -0
- /aiecs/scripts/{fix_weasel_validator.sh → dependance_patch/fix_weasel/fix_weasel_validator.sh} +0 -0
- /aiecs/scripts/{run_weasel_patch.sh → dependance_patch/fix_weasel/run_weasel_patch.sh} +0 -0
- {aiecs-1.0.1.dist-info → aiecs-1.7.6.dist-info}/WHEEL +0 -0
- {aiecs-1.0.1.dist-info → aiecs-1.7.6.dist-info}/licenses/LICENSE +0 -0
- {aiecs-1.0.1.dist-info → aiecs-1.7.6.dist-info}/top_level.txt +0 -0
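Note on the restructured llm package: the flat client modules from 1.0.1 (aiecs/llm/base_client.py, openai_client.py, vertex_client.py, xai_client.py, all removed above) now live under aiecs/llm/clients/. A minimal, hypothetical migration sketch for downstream imports, assuming the exported class names are unchanged:

# 1.0.1 layout (modules deleted in this release):
#   from aiecs.llm.base_client import BaseLLMClient, LLMMessage
#   from aiecs.llm.openai_client import OpenAIClient

# 1.7.6 layout (the base_client path is confirmed by the new
# googleai_client.py shown below; OpenAIClient's name is assumed):
from aiecs.llm.clients.base_client import BaseLLMClient, LLMMessage
from aiecs.llm.clients.openai_client import OpenAIClient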
aiecs/llm/clients/google_function_calling_mixin.py (new file)
@@ -0,0 +1,457 @@
"""
Google Function Calling Mixin

Provides shared implementation for Google providers (Vertex AI, Google AI)
that use FunctionDeclaration format for Function Calling.
"""

import logging
from typing import Dict, Any, Optional, List, Union, AsyncGenerator
from dataclasses import dataclass
from vertexai.generative_models import (
    FunctionDeclaration,
    Tool,
)
from google.genai.types import Schema, Type

from .base_client import LLMMessage, LLMResponse

logger = logging.getLogger(__name__)

# Import StreamChunk from OpenAI mixin for compatibility
try:
    from .openai_compatible_mixin import StreamChunk
except ImportError:
    # Fallback if not available
    @dataclass
    class StreamChunk:
        """Fallback StreamChunk definition"""
        type: str
        content: Optional[str] = None
        tool_call: Optional[Dict[str, Any]] = None
        tool_calls: Optional[List[Dict[str, Any]]] = None


class GoogleFunctionCallingMixin:
    """
    Mixin class providing Google Function Calling implementation.

    This mixin can be used by Google providers (Vertex AI, Google AI)
    that use FunctionDeclaration format for Function Calling.

    Usage:
        class VertexAIClient(BaseLLMClient, GoogleFunctionCallingMixin):
            async def generate_text(self, messages, tools=None, ...):
                if tools:
                    vertex_tools = self._convert_openai_to_google_format(tools)
                    # Use in API call
    """

    def _convert_openai_to_google_format(
        self, tools: List[Dict[str, Any]]
    ) -> List[Tool]:
        """
        Convert OpenAI tools format to Google FunctionDeclaration format.

        Args:
            tools: List of OpenAI-format tool dictionaries

        Returns:
            List of Google Tool objects containing FunctionDeclaration
        """
        function_declarations = []

        for tool in tools:
            if tool.get("type") == "function":
                func = tool.get("function", {})
                func_name = func.get("name", "")
                func_description = func.get("description", "")
                func_parameters = func.get("parameters", {})

                if not func_name:
                    logger.warning(f"Skipping tool without name: {tool}")
                    continue

                # Convert parameters schema
                google_schema = self._convert_json_schema_to_google_schema(func_parameters)

                # Create FunctionDeclaration
                function_declaration = FunctionDeclaration(
                    name=func_name,
                    description=func_description,
                    parameters=google_schema,
                )

                function_declarations.append(function_declaration)
            else:
                logger.warning(f"Unsupported tool type: {tool.get('type')}")

        # Wrap in Tool objects (Google format requires tools to be wrapped)
        if function_declarations:
            return [Tool(function_declarations=function_declarations)]
        return []

    def _convert_json_schema_to_google_schema(
        self, json_schema: Dict[str, Any]
    ) -> Schema:
        """
        Convert JSON Schema to Google Schema format.

        Args:
            json_schema: JSON Schema dictionary

        Returns:
            Google Schema object
        """
        schema_type = json_schema.get("type", "object")
        properties = json_schema.get("properties", {})
        required = json_schema.get("required", [])

        # Convert type
        google_type = self._convert_json_type_to_google_type(schema_type)

        # Convert properties (only for object types)
        google_properties = None
        if schema_type == "object" and properties:
            google_properties = {}
            for prop_name, prop_schema in properties.items():
                google_properties[prop_name] = self._convert_json_schema_to_google_schema(
                    prop_schema
                )

        # Handle array items
        items = None
        if schema_type == "array" and "items" in json_schema:
            items = self._convert_json_schema_to_google_schema(json_schema["items"])

        # Create Schema
        schema_kwargs = {
            "type": google_type,
        }

        if google_properties is not None:
            schema_kwargs["properties"] = google_properties

        if required:
            schema_kwargs["required"] = required

        if items is not None:
            schema_kwargs["items"] = items

        schema = Schema(**schema_kwargs)

        return schema

    def _convert_json_type_to_google_type(self, json_type: str) -> Type:
        """
        Convert JSON Schema type to Google Type enum.

        Args:
            json_type: JSON Schema type string

        Returns:
            Google Type enum value
        """
        type_mapping = {
            "string": Type.STRING,
            "number": Type.NUMBER,
            "integer": Type.NUMBER,  # Google uses NUMBER for both
            "boolean": Type.BOOLEAN,
            "array": Type.ARRAY,
            "object": Type.OBJECT,
        }

        return type_mapping.get(json_type.lower(), Type.OBJECT)

    def _extract_function_calls_from_google_response(
        self, response: Any
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Extract function calls from Google Vertex AI response.

        Args:
            response: Response object from Google Vertex AI API

        Returns:
            List of function call dictionaries in OpenAI-compatible format,
            or None if no function calls found
        """
        function_calls = []

        # Check for function calls in response
        # Google Vertex AI returns function calls in different places depending on API version
        if hasattr(response, "candidates") and response.candidates:
            candidate = response.candidates[0]

            # Check for function_call attribute (older API)
            if hasattr(candidate, "function_call") and candidate.function_call:
                func_call = candidate.function_call
                function_calls.append({
                    "id": f"call_{len(function_calls)}",
                    "type": "function",
                    "function": {
                        "name": func_call.name,
                        "arguments": str(func_call.args) if hasattr(func_call, "args") else "{}",
                    },
                })

            # Check for content.parts with function_call (newer API)
            elif hasattr(candidate, "content") and hasattr(candidate.content, "parts"):
                for part in candidate.content.parts:
                    if hasattr(part, "function_call") and part.function_call:
                        func_call = part.function_call
                        function_calls.append({
                            "id": f"call_{len(function_calls)}",
                            "type": "function",
                            "function": {
                                "name": func_call.name,
                                "arguments": str(func_call.args) if hasattr(func_call, "args") else "{}",
                            },
                        })

        return function_calls if function_calls else None

    def _attach_function_calls_to_response(
        self,
        response: LLMResponse,
        function_calls: Optional[List[Dict[str, Any]]] = None,
    ) -> LLMResponse:
        """
        Attach function call information to LLMResponse.

        Args:
            response: LLMResponse object
            function_calls: List of function call dictionaries

        Returns:
            LLMResponse with function call info attached
        """
        if function_calls:
            setattr(response, "tool_calls", function_calls)
        return response

    def _convert_messages_to_google_format(
        self, messages: List[LLMMessage]
    ) -> List[Dict[str, Any]]:
        """
        Convert LLMMessage list to Google message format.

        Args:
            messages: List of LLMMessage objects

        Returns:
            List of Google-format message dictionaries
        """
        google_messages = []

        for msg in messages:
            # Google format uses "role" and "parts" structure
            parts = []

            if msg.content:
                parts.append({"text": msg.content})

            # Handle tool responses (role="tool")
            if msg.role == "tool" and msg.tool_call_id:
                # Google format uses function_response
                # Note: This may need adjustment based on actual API format
                if msg.content:
                    parts.append({
                        "function_response": {
                            "name": msg.tool_call_id,  # May need mapping
                            "response": {"result": msg.content},
                        }
                    })

            if parts:
                google_messages.append({
                    "role": msg.role if msg.role != "tool" else "model",  # Adjust role mapping
                    "parts": parts,
                })

        return google_messages

    def _extract_function_calls_from_google_chunk(
        self, chunk: Any
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Extract function calls from Google Vertex AI streaming chunk.

        Args:
            chunk: Streaming chunk object from Google Vertex AI API

        Returns:
            List of function call dictionaries in OpenAI-compatible format,
            or None if no function calls found
        """
        function_calls = []

        # Check for function calls in chunk
        if hasattr(chunk, "candidates") and chunk.candidates:
            candidate = chunk.candidates[0]

            # Check for content.parts with function_call
            if hasattr(candidate, "content") and hasattr(candidate.content, "parts"):
                for part in candidate.content.parts:
                    if hasattr(part, "function_call") and part.function_call:
                        func_call = part.function_call
                        function_calls.append({
                            "id": f"call_{len(function_calls)}",
                            "type": "function",
                            "function": {
                                "name": func_call.name,
                                "arguments": str(func_call.args) if hasattr(func_call, "args") else "{}",
                            },
                        })

            # Check for function_call attribute directly on candidate
            elif hasattr(candidate, "function_call") and candidate.function_call:
                func_call = candidate.function_call
                function_calls.append({
                    "id": f"call_{len(function_calls)}",
                    "type": "function",
                    "function": {
                        "name": func_call.name,
                        "arguments": str(func_call.args) if hasattr(func_call, "args") else "{}",
                    },
                })

        return function_calls if function_calls else None

    async def _stream_text_with_function_calling(
        self,
        model_instance: Any,
        contents: Any,
        generation_config: Any,
        safety_settings: List[Any],
        tools: Optional[List[Tool]] = None,
        return_chunks: bool = False,
        **kwargs,
    ) -> AsyncGenerator[Union[str, StreamChunk], None]:
        """
        Stream text with Function Calling support (Google Vertex AI format).

        Args:
            model_instance: GenerativeModel instance (should include system_instruction)
            contents: Input contents (string or list of Content objects)
            generation_config: GenerationConfig object
            safety_settings: List of SafetySetting objects
            tools: List of Tool objects (Google format)
            return_chunks: If True, returns StreamChunk objects; if False, returns str tokens only
            **kwargs: Additional arguments

        Yields:
            str or StreamChunk: Text tokens or StreamChunk objects
        """
        # Build API call parameters
        api_params = {
            "contents": contents,
            "generation_config": generation_config,
            "safety_settings": safety_settings,
            "stream": True,
        }

        # Add tools if available
        if tools:
            api_params["tools"] = tools

        # Add any additional kwargs
        api_params.update(kwargs)

        # Get streaming response
        stream_response = model_instance.generate_content(**api_params)

        # Accumulator for tool calls
        tool_calls_accumulator: Dict[str, Dict[str, Any]] = {}

        # Stream chunks
        import asyncio
        first_chunk_checked = False

        for chunk in stream_response:
            # Yield control to event loop
            await asyncio.sleep(0)

            # Check for prompt-level safety blocks
            if not first_chunk_checked and hasattr(chunk, "prompt_feedback"):
                pf = chunk.prompt_feedback
                if hasattr(pf, "block_reason") and pf.block_reason:
                    block_reason = str(pf.block_reason)
                    if block_reason not in ["BLOCKED_REASON_UNSPECIFIED", "OTHER"]:
                        from .base_client import SafetyBlockError
                        raise SafetyBlockError(
                            "Prompt blocked by safety filters",
                            block_reason=block_reason,
                            block_type="prompt",
                        )
                first_chunk_checked = True

            # Extract text content and function calls
            if hasattr(chunk, "candidates") and chunk.candidates:
                candidate = chunk.candidates[0]

                # Check for safety blocks in response
                if hasattr(candidate, "finish_reason"):
                    finish_reason = candidate.finish_reason
                    if finish_reason in ["SAFETY", "RECITATION"]:
                        from .base_client import SafetyBlockError
                        raise SafetyBlockError(
                            "Response blocked by safety filters",
                            block_type="response",
                        )

                # Extract text from chunk parts
                if hasattr(candidate, "content") and hasattr(candidate.content, "parts"):
                    for part in candidate.content.parts:
                        if hasattr(part, "text") and part.text:
                            text_content = part.text
                            if return_chunks:
                                yield StreamChunk(type="token", content=text_content)
                            else:
                                yield text_content

                # Also check if text is directly available
                elif hasattr(candidate, "text") and candidate.text:
                    text_content = candidate.text
                    if return_chunks:
                        yield StreamChunk(type="token", content=text_content)
                    else:
                        yield text_content

            # Extract and accumulate function calls
            function_calls = self._extract_function_calls_from_google_chunk(chunk)
            if function_calls:
                for func_call in function_calls:
                    call_id = func_call["id"]

                    # Initialize accumulator if needed
                    if call_id not in tool_calls_accumulator:
                        tool_calls_accumulator[call_id] = func_call.copy()
                    else:
                        # Update accumulator (merge arguments if needed)
                        existing_call = tool_calls_accumulator[call_id]
                        if func_call["function"]["name"]:
                            existing_call["function"]["name"] = func_call["function"]["name"]
                        if func_call["function"]["arguments"]:
                            # Merge arguments (Google may send partial arguments)
                            existing_args = existing_call["function"].get("arguments", "{}")
                            new_args = func_call["function"]["arguments"]
                            # Simple merge: append new args (may need JSON parsing for proper merge)
                            if new_args and new_args != "{}":
                                existing_call["function"]["arguments"] = new_args

                    # Yield tool call update if return_chunks=True
                    if return_chunks:
                        yield StreamChunk(
                            type="tool_call",
                            tool_call=tool_calls_accumulator[call_id].copy(),
                        )

        # At the end of stream, yield complete tool_calls if any
        if tool_calls_accumulator and return_chunks:
            complete_tool_calls = list(tool_calls_accumulator.values())
            yield StreamChunk(
                type="tool_calls",
                tool_calls=complete_tool_calls,
            )
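For orientation, a short hypothetical sketch of the mixin's tool-format conversion. The weather tool is an illustrative OpenAI-format dictionary (not from aiecs), and _Demo exists only to host the mixin:

# Hypothetical usage sketch of _convert_openai_to_google_format.
class _Demo(GoogleFunctionCallingMixin):
    pass

openai_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

# Returns a single vertexai Tool wrapping one FunctionDeclaration,
# per the wrapping logic at the end of the converter above.
google_tools = _Demo()._convert_openai_to_google_format([openai_tool])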
aiecs/llm/clients/googleai_client.py (new file)
@@ -0,0 +1,241 @@
import logging
import os
from typing import Optional, List, AsyncGenerator

from google import genai
from google.genai import types

from aiecs.llm.clients.base_client import (
    BaseLLMClient,
    LLMMessage,
    LLMResponse,
    ProviderNotAvailableError,
    RateLimitError,
)
from aiecs.config.config import get_settings

logger = logging.getLogger(__name__)


class GoogleAIClient(BaseLLMClient):
    """Google AI (Gemini) provider client using the new google.genai SDK"""

    def __init__(self):
        super().__init__("GoogleAI")
        self.settings = get_settings()
        self._initialized = False
        self._client: Optional[genai.Client] = None

    def _init_google_ai(self) -> genai.Client:
        """Lazy initialization of Google AI SDK and return the client"""
        if not self._initialized or self._client is None:
            api_key = self.settings.googleai_api_key or os.environ.get("GOOGLEAI_API_KEY")
            if not api_key:
                raise ProviderNotAvailableError("Google AI API key not configured. Set GOOGLEAI_API_KEY.")

            try:
                self._client = genai.Client(api_key=api_key)
                self._initialized = True
                self.logger.info("Google AI SDK (google.genai) initialized successfully.")
            except Exception as e:
                raise ProviderNotAvailableError(f"Failed to initialize Google AI SDK: {str(e)}")

        return self._client

    def _convert_messages_to_contents(
        self, messages: List[LLMMessage]
    ) -> List[types.Content]:
        """Convert LLMMessage list to Google GenAI Content objects."""
        contents = []
        for msg in messages:
            # Map 'assistant' role to 'model' for Google AI
            role = "model" if msg.role == "assistant" else msg.role
            contents.append(
                types.Content(role=role, parts=[types.Part(text=msg.content)])
            )
        return contents

    async def generate_text(
        self,
        messages: List[LLMMessage],
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        system_instruction: Optional[str] = None,
        **kwargs,
    ) -> LLMResponse:
        """Generate text using Google AI (google.genai SDK)"""
        client = self._init_google_ai()

        # Get model name from config if not provided
        model_name = model or self._get_default_model() or "gemini-2.5-pro"

        # Get model config for default parameters
        model_config = self._get_model_config(model_name)
        if model_config and max_tokens is None:
            max_tokens = model_config.default_params.max_tokens

        try:
            # Extract system message from messages if not provided
            system_msg = None
            user_messages = []
            for msg in messages:
                if msg.role == "system":
                    system_msg = msg.content
                else:
                    user_messages.append(msg)

            # Use provided system_instruction or extracted system message
            final_system_instruction = system_instruction or system_msg

            # Convert messages to Content objects
            contents = self._convert_messages_to_contents(user_messages)

            # Create GenerateContentConfig with all settings
            config = types.GenerateContentConfig(
                system_instruction=final_system_instruction,
                temperature=temperature,
                max_output_tokens=max_tokens or 8192,
                top_p=kwargs.get("top_p", 0.95),
                top_k=kwargs.get("top_k", 40),
                safety_settings=[
                    types.SafetySetting(
                        category="HARM_CATEGORY_HARASSMENT",
                        threshold="OFF",
                    ),
                    types.SafetySetting(
                        category="HARM_CATEGORY_HATE_SPEECH",
                        threshold="OFF",
                    ),
                    types.SafetySetting(
                        category="HARM_CATEGORY_SEXUALLY_EXPLICIT",
                        threshold="OFF",
                    ),
                    types.SafetySetting(
                        category="HARM_CATEGORY_DANGEROUS_CONTENT",
                        threshold="OFF",
                    ),
                ],
            )

            # Use async client for async operations
            response = await client.aio.models.generate_content(
                model=model_name,
                contents=contents,
                config=config,
            )

            content = response.text
            prompt_tokens = response.usage_metadata.prompt_token_count
            completion_tokens = response.usage_metadata.candidates_token_count
            total_tokens = response.usage_metadata.total_token_count

            # Extract cache metadata from Google AI response
            cache_read_tokens = None
            cache_hit = None
            if hasattr(response.usage_metadata, "cached_content_token_count"):
                cache_read_tokens = response.usage_metadata.cached_content_token_count
                cache_hit = cache_read_tokens is not None and cache_read_tokens > 0

            # Use config-based cost estimation
            cost = self._estimate_cost_from_config(model_name, prompt_tokens, completion_tokens)

            return LLMResponse(
                content=content,
                provider=self.provider_name,
                model=model_name,
                tokens_used=total_tokens,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                cost_estimate=cost,
                cache_read_tokens=cache_read_tokens,
                cache_hit=cache_hit,
            )

        except Exception as e:
            if "quota" in str(e).lower():
                raise RateLimitError(f"Google AI quota exceeded: {str(e)}")
            self.logger.error(f"Error generating text with Google AI: {e}")
            raise

    async def stream_text(  # type: ignore[override]
        self,
        messages: List[LLMMessage],
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        system_instruction: Optional[str] = None,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Stream text generation using Google AI (google.genai SDK)"""
        client = self._init_google_ai()

        # Get model name from config if not provided
        model_name = model or self._get_default_model() or "gemini-2.5-pro"

        # Get model config for default parameters
        model_config = self._get_model_config(model_name)
        if model_config and max_tokens is None:
            max_tokens = model_config.default_params.max_tokens

        try:
            # Extract system message from messages if not provided
            system_msg = None
            user_messages = []
            for msg in messages:
                if msg.role == "system":
                    system_msg = msg.content
                else:
                    user_messages.append(msg)

            # Use provided system_instruction or extracted system message
            final_system_instruction = system_instruction or system_msg

            # Convert messages to Content objects
            contents = self._convert_messages_to_contents(user_messages)

            # Create GenerateContentConfig with all settings
            config = types.GenerateContentConfig(
                system_instruction=final_system_instruction,
                temperature=temperature,
                max_output_tokens=max_tokens or 8192,
                top_p=kwargs.get("top_p", 0.95),
                top_k=kwargs.get("top_k", 40),
                safety_settings=[
                    types.SafetySetting(
                        category="HARM_CATEGORY_HARASSMENT",
                        threshold="OFF",
                    ),
                    types.SafetySetting(
                        category="HARM_CATEGORY_HATE_SPEECH",
                        threshold="OFF",
                    ),
                    types.SafetySetting(
                        category="HARM_CATEGORY_SEXUALLY_EXPLICIT",
                        threshold="OFF",
                    ),
                    types.SafetySetting(
                        category="HARM_CATEGORY_DANGEROUS_CONTENT",
                        threshold="OFF",
                    ),
                ],
            )

            # Use async streaming with the new SDK
            async for chunk in client.aio.models.generate_content_stream(
                model=model_name,
                contents=contents,
                config=config,
            ):
                if chunk.text:
                    yield chunk.text

        except Exception as e:
            self.logger.error(f"Error streaming text with Google AI: {e}")
            raise

    async def close(self):
        """Clean up resources"""
        # Google GenAI SDK does not require explicit closing of a client
        self._initialized = False
        self._client = None
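A hypothetical end-to-end sketch of the client above; it assumes GOOGLEAI_API_KEY is set, that LLMMessage accepts role/content keywords, and that BaseLLMClient's config helpers resolve a default model:

import asyncio

from aiecs.llm.clients.base_client import LLMMessage
from aiecs.llm.clients.googleai_client import GoogleAIClient

async def main():
    client = GoogleAIClient()
    # System messages are lifted into system_instruction by generate_text.
    resp = await client.generate_text(
        [
            LLMMessage(role="system", content="Answer in one sentence."),
            LLMMessage(role="user", content="What is a mixin?"),
        ],
        temperature=0.2,
    )
    print(resp.content, resp.tokens_used)
    await client.close()

asyncio.run(main())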