aiecs 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiecs/__init__.py +72 -0
- aiecs/__main__.py +41 -0
- aiecs/aiecs_client.py +469 -0
- aiecs/application/__init__.py +10 -0
- aiecs/application/executors/__init__.py +10 -0
- aiecs/application/executors/operation_executor.py +363 -0
- aiecs/application/knowledge_graph/__init__.py +7 -0
- aiecs/application/knowledge_graph/builder/__init__.py +37 -0
- aiecs/application/knowledge_graph/builder/document_builder.py +375 -0
- aiecs/application/knowledge_graph/builder/graph_builder.py +356 -0
- aiecs/application/knowledge_graph/builder/schema_mapping.py +531 -0
- aiecs/application/knowledge_graph/builder/structured_pipeline.py +443 -0
- aiecs/application/knowledge_graph/builder/text_chunker.py +319 -0
- aiecs/application/knowledge_graph/extractors/__init__.py +27 -0
- aiecs/application/knowledge_graph/extractors/base.py +100 -0
- aiecs/application/knowledge_graph/extractors/llm_entity_extractor.py +327 -0
- aiecs/application/knowledge_graph/extractors/llm_relation_extractor.py +349 -0
- aiecs/application/knowledge_graph/extractors/ner_entity_extractor.py +244 -0
- aiecs/application/knowledge_graph/fusion/__init__.py +23 -0
- aiecs/application/knowledge_graph/fusion/entity_deduplicator.py +387 -0
- aiecs/application/knowledge_graph/fusion/entity_linker.py +343 -0
- aiecs/application/knowledge_graph/fusion/knowledge_fusion.py +580 -0
- aiecs/application/knowledge_graph/fusion/relation_deduplicator.py +189 -0
- aiecs/application/knowledge_graph/pattern_matching/__init__.py +21 -0
- aiecs/application/knowledge_graph/pattern_matching/pattern_matcher.py +344 -0
- aiecs/application/knowledge_graph/pattern_matching/query_executor.py +378 -0
- aiecs/application/knowledge_graph/profiling/__init__.py +12 -0
- aiecs/application/knowledge_graph/profiling/query_plan_visualizer.py +199 -0
- aiecs/application/knowledge_graph/profiling/query_profiler.py +223 -0
- aiecs/application/knowledge_graph/reasoning/__init__.py +27 -0
- aiecs/application/knowledge_graph/reasoning/evidence_synthesis.py +347 -0
- aiecs/application/knowledge_graph/reasoning/inference_engine.py +504 -0
- aiecs/application/knowledge_graph/reasoning/logic_form_parser.py +167 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/__init__.py +79 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_builder.py +513 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_nodes.py +630 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_validator.py +654 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/error_handler.py +477 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/parser.py +390 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/query_context.py +217 -0
- aiecs/application/knowledge_graph/reasoning/logic_query_integration.py +169 -0
- aiecs/application/knowledge_graph/reasoning/query_planner.py +872 -0
- aiecs/application/knowledge_graph/reasoning/reasoning_engine.py +554 -0
- aiecs/application/knowledge_graph/retrieval/__init__.py +19 -0
- aiecs/application/knowledge_graph/retrieval/retrieval_strategies.py +596 -0
- aiecs/application/knowledge_graph/search/__init__.py +59 -0
- aiecs/application/knowledge_graph/search/hybrid_search.py +423 -0
- aiecs/application/knowledge_graph/search/reranker.py +295 -0
- aiecs/application/knowledge_graph/search/reranker_strategies.py +553 -0
- aiecs/application/knowledge_graph/search/text_similarity.py +398 -0
- aiecs/application/knowledge_graph/traversal/__init__.py +15 -0
- aiecs/application/knowledge_graph/traversal/enhanced_traversal.py +329 -0
- aiecs/application/knowledge_graph/traversal/path_scorer.py +269 -0
- aiecs/application/knowledge_graph/validators/__init__.py +13 -0
- aiecs/application/knowledge_graph/validators/relation_validator.py +189 -0
- aiecs/application/knowledge_graph/visualization/__init__.py +11 -0
- aiecs/application/knowledge_graph/visualization/graph_visualizer.py +321 -0
- aiecs/common/__init__.py +9 -0
- aiecs/common/knowledge_graph/__init__.py +17 -0
- aiecs/common/knowledge_graph/runnable.py +484 -0
- aiecs/config/__init__.py +16 -0
- aiecs/config/config.py +498 -0
- aiecs/config/graph_config.py +137 -0
- aiecs/config/registry.py +23 -0
- aiecs/core/__init__.py +46 -0
- aiecs/core/interface/__init__.py +34 -0
- aiecs/core/interface/execution_interface.py +152 -0
- aiecs/core/interface/storage_interface.py +171 -0
- aiecs/domain/__init__.py +289 -0
- aiecs/domain/agent/__init__.py +189 -0
- aiecs/domain/agent/base_agent.py +697 -0
- aiecs/domain/agent/exceptions.py +103 -0
- aiecs/domain/agent/graph_aware_mixin.py +559 -0
- aiecs/domain/agent/hybrid_agent.py +490 -0
- aiecs/domain/agent/integration/__init__.py +26 -0
- aiecs/domain/agent/integration/context_compressor.py +222 -0
- aiecs/domain/agent/integration/context_engine_adapter.py +252 -0
- aiecs/domain/agent/integration/retry_policy.py +219 -0
- aiecs/domain/agent/integration/role_config.py +213 -0
- aiecs/domain/agent/knowledge_aware_agent.py +646 -0
- aiecs/domain/agent/lifecycle.py +296 -0
- aiecs/domain/agent/llm_agent.py +300 -0
- aiecs/domain/agent/memory/__init__.py +12 -0
- aiecs/domain/agent/memory/conversation.py +197 -0
- aiecs/domain/agent/migration/__init__.py +14 -0
- aiecs/domain/agent/migration/conversion.py +160 -0
- aiecs/domain/agent/migration/legacy_wrapper.py +90 -0
- aiecs/domain/agent/models.py +317 -0
- aiecs/domain/agent/observability.py +407 -0
- aiecs/domain/agent/persistence.py +289 -0
- aiecs/domain/agent/prompts/__init__.py +29 -0
- aiecs/domain/agent/prompts/builder.py +161 -0
- aiecs/domain/agent/prompts/formatters.py +189 -0
- aiecs/domain/agent/prompts/template.py +255 -0
- aiecs/domain/agent/registry.py +260 -0
- aiecs/domain/agent/tool_agent.py +257 -0
- aiecs/domain/agent/tools/__init__.py +12 -0
- aiecs/domain/agent/tools/schema_generator.py +221 -0
- aiecs/domain/community/__init__.py +155 -0
- aiecs/domain/community/agent_adapter.py +477 -0
- aiecs/domain/community/analytics.py +481 -0
- aiecs/domain/community/collaborative_workflow.py +642 -0
- aiecs/domain/community/communication_hub.py +645 -0
- aiecs/domain/community/community_builder.py +320 -0
- aiecs/domain/community/community_integration.py +800 -0
- aiecs/domain/community/community_manager.py +813 -0
- aiecs/domain/community/decision_engine.py +879 -0
- aiecs/domain/community/exceptions.py +225 -0
- aiecs/domain/community/models/__init__.py +33 -0
- aiecs/domain/community/models/community_models.py +268 -0
- aiecs/domain/community/resource_manager.py +457 -0
- aiecs/domain/community/shared_context_manager.py +603 -0
- aiecs/domain/context/__init__.py +58 -0
- aiecs/domain/context/context_engine.py +989 -0
- aiecs/domain/context/conversation_models.py +354 -0
- aiecs/domain/context/graph_memory.py +467 -0
- aiecs/domain/execution/__init__.py +12 -0
- aiecs/domain/execution/model.py +57 -0
- aiecs/domain/knowledge_graph/__init__.py +19 -0
- aiecs/domain/knowledge_graph/models/__init__.py +52 -0
- aiecs/domain/knowledge_graph/models/entity.py +130 -0
- aiecs/domain/knowledge_graph/models/evidence.py +194 -0
- aiecs/domain/knowledge_graph/models/inference_rule.py +186 -0
- aiecs/domain/knowledge_graph/models/path.py +179 -0
- aiecs/domain/knowledge_graph/models/path_pattern.py +173 -0
- aiecs/domain/knowledge_graph/models/query.py +272 -0
- aiecs/domain/knowledge_graph/models/query_plan.py +187 -0
- aiecs/domain/knowledge_graph/models/relation.py +136 -0
- aiecs/domain/knowledge_graph/schema/__init__.py +23 -0
- aiecs/domain/knowledge_graph/schema/entity_type.py +135 -0
- aiecs/domain/knowledge_graph/schema/graph_schema.py +271 -0
- aiecs/domain/knowledge_graph/schema/property_schema.py +155 -0
- aiecs/domain/knowledge_graph/schema/relation_type.py +171 -0
- aiecs/domain/knowledge_graph/schema/schema_manager.py +496 -0
- aiecs/domain/knowledge_graph/schema/type_enums.py +205 -0
- aiecs/domain/task/__init__.py +13 -0
- aiecs/domain/task/dsl_processor.py +613 -0
- aiecs/domain/task/model.py +62 -0
- aiecs/domain/task/task_context.py +268 -0
- aiecs/infrastructure/__init__.py +24 -0
- aiecs/infrastructure/graph_storage/__init__.py +11 -0
- aiecs/infrastructure/graph_storage/base.py +601 -0
- aiecs/infrastructure/graph_storage/batch_operations.py +449 -0
- aiecs/infrastructure/graph_storage/cache.py +429 -0
- aiecs/infrastructure/graph_storage/distributed.py +226 -0
- aiecs/infrastructure/graph_storage/error_handling.py +390 -0
- aiecs/infrastructure/graph_storage/graceful_degradation.py +306 -0
- aiecs/infrastructure/graph_storage/health_checks.py +378 -0
- aiecs/infrastructure/graph_storage/in_memory.py +514 -0
- aiecs/infrastructure/graph_storage/index_optimization.py +483 -0
- aiecs/infrastructure/graph_storage/lazy_loading.py +410 -0
- aiecs/infrastructure/graph_storage/metrics.py +357 -0
- aiecs/infrastructure/graph_storage/migration.py +413 -0
- aiecs/infrastructure/graph_storage/pagination.py +471 -0
- aiecs/infrastructure/graph_storage/performance_monitoring.py +466 -0
- aiecs/infrastructure/graph_storage/postgres.py +871 -0
- aiecs/infrastructure/graph_storage/query_optimizer.py +635 -0
- aiecs/infrastructure/graph_storage/schema_cache.py +290 -0
- aiecs/infrastructure/graph_storage/sqlite.py +623 -0
- aiecs/infrastructure/graph_storage/streaming.py +495 -0
- aiecs/infrastructure/messaging/__init__.py +13 -0
- aiecs/infrastructure/messaging/celery_task_manager.py +383 -0
- aiecs/infrastructure/messaging/websocket_manager.py +298 -0
- aiecs/infrastructure/monitoring/__init__.py +34 -0
- aiecs/infrastructure/monitoring/executor_metrics.py +174 -0
- aiecs/infrastructure/monitoring/global_metrics_manager.py +213 -0
- aiecs/infrastructure/monitoring/structured_logger.py +48 -0
- aiecs/infrastructure/monitoring/tracing_manager.py +410 -0
- aiecs/infrastructure/persistence/__init__.py +24 -0
- aiecs/infrastructure/persistence/context_engine_client.py +187 -0
- aiecs/infrastructure/persistence/database_manager.py +333 -0
- aiecs/infrastructure/persistence/file_storage.py +754 -0
- aiecs/infrastructure/persistence/redis_client.py +220 -0
- aiecs/llm/__init__.py +86 -0
- aiecs/llm/callbacks/__init__.py +11 -0
- aiecs/llm/callbacks/custom_callbacks.py +264 -0
- aiecs/llm/client_factory.py +420 -0
- aiecs/llm/clients/__init__.py +33 -0
- aiecs/llm/clients/base_client.py +193 -0
- aiecs/llm/clients/googleai_client.py +181 -0
- aiecs/llm/clients/openai_client.py +131 -0
- aiecs/llm/clients/vertex_client.py +437 -0
- aiecs/llm/clients/xai_client.py +184 -0
- aiecs/llm/config/__init__.py +51 -0
- aiecs/llm/config/config_loader.py +275 -0
- aiecs/llm/config/config_validator.py +236 -0
- aiecs/llm/config/model_config.py +151 -0
- aiecs/llm/utils/__init__.py +10 -0
- aiecs/llm/utils/validate_config.py +91 -0
- aiecs/main.py +363 -0
- aiecs/scripts/__init__.py +3 -0
- aiecs/scripts/aid/VERSION_MANAGEMENT.md +97 -0
- aiecs/scripts/aid/__init__.py +19 -0
- aiecs/scripts/aid/version_manager.py +215 -0
- aiecs/scripts/dependance_check/DEPENDENCY_SYSTEM_SUMMARY.md +242 -0
- aiecs/scripts/dependance_check/README_DEPENDENCY_CHECKER.md +310 -0
- aiecs/scripts/dependance_check/__init__.py +17 -0
- aiecs/scripts/dependance_check/dependency_checker.py +938 -0
- aiecs/scripts/dependance_check/dependency_fixer.py +391 -0
- aiecs/scripts/dependance_check/download_nlp_data.py +396 -0
- aiecs/scripts/dependance_check/quick_dependency_check.py +270 -0
- aiecs/scripts/dependance_check/setup_nlp_data.sh +217 -0
- aiecs/scripts/dependance_patch/__init__.py +7 -0
- aiecs/scripts/dependance_patch/fix_weasel/README_WEASEL_PATCH.md +126 -0
- aiecs/scripts/dependance_patch/fix_weasel/__init__.py +11 -0
- aiecs/scripts/dependance_patch/fix_weasel/fix_weasel_validator.py +128 -0
- aiecs/scripts/dependance_patch/fix_weasel/fix_weasel_validator.sh +82 -0
- aiecs/scripts/dependance_patch/fix_weasel/patch_weasel_library.sh +188 -0
- aiecs/scripts/dependance_patch/fix_weasel/run_weasel_patch.sh +41 -0
- aiecs/scripts/tools_develop/README.md +449 -0
- aiecs/scripts/tools_develop/TOOL_AUTO_DISCOVERY.md +234 -0
- aiecs/scripts/tools_develop/__init__.py +21 -0
- aiecs/scripts/tools_develop/check_type_annotations.py +259 -0
- aiecs/scripts/tools_develop/validate_tool_schemas.py +422 -0
- aiecs/scripts/tools_develop/verify_tools.py +356 -0
- aiecs/tasks/__init__.py +1 -0
- aiecs/tasks/worker.py +172 -0
- aiecs/tools/__init__.py +299 -0
- aiecs/tools/apisource/__init__.py +99 -0
- aiecs/tools/apisource/intelligence/__init__.py +19 -0
- aiecs/tools/apisource/intelligence/data_fusion.py +381 -0
- aiecs/tools/apisource/intelligence/query_analyzer.py +413 -0
- aiecs/tools/apisource/intelligence/search_enhancer.py +388 -0
- aiecs/tools/apisource/monitoring/__init__.py +9 -0
- aiecs/tools/apisource/monitoring/metrics.py +303 -0
- aiecs/tools/apisource/providers/__init__.py +115 -0
- aiecs/tools/apisource/providers/base.py +664 -0
- aiecs/tools/apisource/providers/census.py +401 -0
- aiecs/tools/apisource/providers/fred.py +564 -0
- aiecs/tools/apisource/providers/newsapi.py +412 -0
- aiecs/tools/apisource/providers/worldbank.py +357 -0
- aiecs/tools/apisource/reliability/__init__.py +12 -0
- aiecs/tools/apisource/reliability/error_handler.py +375 -0
- aiecs/tools/apisource/reliability/fallback_strategy.py +391 -0
- aiecs/tools/apisource/tool.py +850 -0
- aiecs/tools/apisource/utils/__init__.py +9 -0
- aiecs/tools/apisource/utils/validators.py +338 -0
- aiecs/tools/base_tool.py +201 -0
- aiecs/tools/docs/__init__.py +121 -0
- aiecs/tools/docs/ai_document_orchestrator.py +599 -0
- aiecs/tools/docs/ai_document_writer_orchestrator.py +2403 -0
- aiecs/tools/docs/content_insertion_tool.py +1333 -0
- aiecs/tools/docs/document_creator_tool.py +1317 -0
- aiecs/tools/docs/document_layout_tool.py +1166 -0
- aiecs/tools/docs/document_parser_tool.py +994 -0
- aiecs/tools/docs/document_writer_tool.py +1818 -0
- aiecs/tools/knowledge_graph/__init__.py +17 -0
- aiecs/tools/knowledge_graph/graph_reasoning_tool.py +734 -0
- aiecs/tools/knowledge_graph/graph_search_tool.py +923 -0
- aiecs/tools/knowledge_graph/kg_builder_tool.py +476 -0
- aiecs/tools/langchain_adapter.py +542 -0
- aiecs/tools/schema_generator.py +275 -0
- aiecs/tools/search_tool/__init__.py +100 -0
- aiecs/tools/search_tool/analyzers.py +589 -0
- aiecs/tools/search_tool/cache.py +260 -0
- aiecs/tools/search_tool/constants.py +128 -0
- aiecs/tools/search_tool/context.py +216 -0
- aiecs/tools/search_tool/core.py +749 -0
- aiecs/tools/search_tool/deduplicator.py +123 -0
- aiecs/tools/search_tool/error_handler.py +271 -0
- aiecs/tools/search_tool/metrics.py +371 -0
- aiecs/tools/search_tool/rate_limiter.py +178 -0
- aiecs/tools/search_tool/schemas.py +277 -0
- aiecs/tools/statistics/__init__.py +80 -0
- aiecs/tools/statistics/ai_data_analysis_orchestrator.py +643 -0
- aiecs/tools/statistics/ai_insight_generator_tool.py +505 -0
- aiecs/tools/statistics/ai_report_orchestrator_tool.py +694 -0
- aiecs/tools/statistics/data_loader_tool.py +564 -0
- aiecs/tools/statistics/data_profiler_tool.py +658 -0
- aiecs/tools/statistics/data_transformer_tool.py +573 -0
- aiecs/tools/statistics/data_visualizer_tool.py +495 -0
- aiecs/tools/statistics/model_trainer_tool.py +487 -0
- aiecs/tools/statistics/statistical_analyzer_tool.py +459 -0
- aiecs/tools/task_tools/__init__.py +86 -0
- aiecs/tools/task_tools/chart_tool.py +732 -0
- aiecs/tools/task_tools/classfire_tool.py +922 -0
- aiecs/tools/task_tools/image_tool.py +447 -0
- aiecs/tools/task_tools/office_tool.py +684 -0
- aiecs/tools/task_tools/pandas_tool.py +635 -0
- aiecs/tools/task_tools/report_tool.py +635 -0
- aiecs/tools/task_tools/research_tool.py +392 -0
- aiecs/tools/task_tools/scraper_tool.py +715 -0
- aiecs/tools/task_tools/stats_tool.py +688 -0
- aiecs/tools/temp_file_manager.py +130 -0
- aiecs/tools/tool_executor/__init__.py +37 -0
- aiecs/tools/tool_executor/tool_executor.py +881 -0
- aiecs/utils/LLM_output_structor.py +445 -0
- aiecs/utils/__init__.py +34 -0
- aiecs/utils/base_callback.py +47 -0
- aiecs/utils/cache_provider.py +695 -0
- aiecs/utils/execution_utils.py +184 -0
- aiecs/utils/logging.py +1 -0
- aiecs/utils/prompt_loader.py +14 -0
- aiecs/utils/token_usage_repository.py +323 -0
- aiecs/ws/__init__.py +0 -0
- aiecs/ws/socket_server.py +52 -0
- aiecs-1.5.1.dist-info/METADATA +608 -0
- aiecs-1.5.1.dist-info/RECORD +302 -0
- aiecs-1.5.1.dist-info/WHEEL +5 -0
- aiecs-1.5.1.dist-info/entry_points.txt +10 -0
- aiecs-1.5.1.dist-info/licenses/LICENSE +225 -0
- aiecs-1.5.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,437 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
import warnings
|
|
5
|
+
from typing import Dict, Any, Optional, List, AsyncGenerator
|
|
6
|
+
import vertexai
|
|
7
|
+
from vertexai.generative_models import (
|
|
8
|
+
GenerativeModel,
|
|
9
|
+
HarmCategory,
|
|
10
|
+
HarmBlockThreshold,
|
|
11
|
+
GenerationConfig,
|
|
12
|
+
SafetySetting,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from aiecs.llm.clients.base_client import (
|
|
16
|
+
BaseLLMClient,
|
|
17
|
+
LLMMessage,
|
|
18
|
+
LLMResponse,
|
|
19
|
+
ProviderNotAvailableError,
|
|
20
|
+
RateLimitError,
|
|
21
|
+
)
|
|
22
|
+
from aiecs.config.config import get_settings
|
|
23
|
+
|
|
24
|
+
# Suppress Vertex AI SDK deprecation warnings (deprecated June 2025, removal June 2026)
|
|
25
|
+
# TODO: Migrate to Google Gen AI SDK when official migration guide is available
|
|
26
|
+
warnings.filterwarnings(
|
|
27
|
+
"ignore",
|
|
28
|
+
category=UserWarning,
|
|
29
|
+
module="vertexai.generative_models._generative_models",
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class VertexAIClient(BaseLLMClient):
|
|
36
|
+
"""Vertex AI provider client"""
|
|
37
|
+
|
|
38
|
+
def __init__(self):
    """Create the Vertex AI client and prepare part-count bookkeeping."""
    super().__init__("Vertex")
    self.settings = get_settings()
    self._initialized = False
    # Monitoring data: how many content parts each Vertex AI response carried.
    self._part_count_stats = {
        "total_responses": 0,
        "part_counts": {},  # maps part_count -> occurrence frequency
        "last_part_count": None,
    }
|
|
48
|
+
|
|
49
|
+
def _init_vertex_ai(self):
    """Lazily initialize the Vertex AI SDK with proper authentication.

    Idempotent: returns immediately once initialization has succeeded.

    Raises:
        ProviderNotAvailableError: if the project ID is not configured,
            the configured credentials file does not exist, or the SDK
            fails to initialize.
    """
    if self._initialized:
        return

    if not self.settings.vertex_project_id:
        raise ProviderNotAvailableError("Vertex AI project ID not configured")

    # Resolve Google Cloud credentials before touching the SDK. This is
    # kept outside the try/except below so our own configuration errors
    # are not re-wrapped as generic initialization failures.
    if self.settings.google_application_credentials:
        credentials_path = self.settings.google_application_credentials
        if not os.path.exists(credentials_path):
            self.logger.warning(
                f"Google Cloud credentials file not found: {credentials_path}"
            )
            raise ProviderNotAvailableError(
                f"Google Cloud credentials file not found: {credentials_path}"
            )
        # Export so the Google Cloud SDK's default credential discovery
        # picks up the configured service-account file.
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_path
        self.logger.info(f"Using Google Cloud credentials from: {credentials_path}")
    elif "GOOGLE_APPLICATION_CREDENTIALS" in os.environ:
        self.logger.info("Using Google Cloud credentials from environment variable")
    else:
        self.logger.warning(
            "No Google Cloud credentials configured. Using default authentication."
        )

    try:
        vertexai.init(
            project=self.settings.vertex_project_id,
            location=getattr(self.settings, "vertex_location", "us-central1"),
        )
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise ProviderNotAvailableError(
            f"Failed to initialize Vertex AI: {str(e)}"
        ) from e

    self._initialized = True
    self.logger.info(
        f"Vertex AI initialized for project {self.settings.vertex_project_id}"
    )
|
|
92
|
+
|
|
93
|
+
async def generate_text(
    self,
    messages: List[LLMMessage],
    model: Optional[str] = None,
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
    **kwargs,
) -> LLMResponse:
    """Generate text using Vertex AI.

    Args:
        messages: Conversation turns to send. A single user message is
            sent verbatim; multi-turn input is flattened to "role: content"
            lines (the SDK call used here takes a single prompt string).
        model: Model name; falls back to the configured default, then
            to "gemini-2.5-pro".
        temperature: Sampling temperature.
        max_tokens: Output token cap; falls back to the model config's
            default, then to 8192.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        LLMResponse with the generated content and estimated token/cost
        figures (Vertex AI does not report exact usage in this path).

    Raises:
        ProviderNotAvailableError: if initialization fails.
        RateLimitError: if the provider reports a quota/limit error.
    """
    self._init_vertex_ai()

    # Resolve model name and default parameters from configuration.
    model_name = model or self._get_default_model() or "gemini-2.5-pro"
    model_config = self._get_model_config(model_name)
    if model_config and max_tokens is None:
        max_tokens = model_config.default_params.max_tokens

    try:
        # Use the stable Vertex AI API.
        model_instance = GenerativeModel(model_name)
        self.logger.debug(f"Initialized Vertex AI model: {model_name}")

        # Convert messages to Vertex AI format.
        if len(messages) == 1 and messages[0].role == "user":
            prompt = messages[0].content
        else:
            # For multi-turn conversations, combine messages.
            prompt = "\n".join([f"{msg.role}: {msg.content}" for msg in messages])

        generation_config = GenerationConfig(
            temperature=temperature,
            # Default raised to 8192 to account for thinking tokens.
            max_output_tokens=max_tokens or 8192,
            top_p=0.95,
            top_k=40,
        )

        # Disable blocking for all harm categories; downstream callers
        # are expected to apply their own content policy.
        safety_settings = [
            SafetySetting(category=category, threshold=HarmBlockThreshold.BLOCK_NONE)
            for category in (
                HarmCategory.HARM_CATEGORY_HARASSMENT,
                HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            )
        ]

        # The SDK call is synchronous; run it off the event loop.
        response = await asyncio.get_event_loop().run_in_executor(
            None,
            lambda: model_instance.generate_content(
                prompt,
                generation_config=generation_config,
                safety_settings=safety_settings,
            ),
        )

        # Handle response content safely, including multi-part responses.
        content = None
        try:
            # First try to get text directly.
            content = response.text
            self.logger.debug(f"Vertex AI response received: {content[:100]}...")
        except (ValueError, AttributeError) as ve:
            # Handle multi-part responses and other issues.
            self.logger.warning(f"Cannot get response text directly: {str(ve)}")

            if hasattr(response, "candidates") and response.candidates:
                candidate = response.candidates[0]
                self.logger.debug(
                    f"Candidate finish_reason: {getattr(candidate, 'finish_reason', 'unknown')}"
                )

                # Try to extract content from candidate parts.
                if hasattr(candidate, "content") and hasattr(candidate.content, "parts"):
                    try:
                        text_parts = [
                            part.text
                            for part in candidate.content.parts
                            if hasattr(part, "text") and part.text
                        ]

                        if text_parts:
                            part_count = len(text_parts)
                            self.logger.info(
                                f"📊 Vertex AI response: {part_count} parts detected"
                            )

                            # Remember the previous count BEFORE overwriting
                            # it, so the variation check below compares two
                            # different responses (the original compared the
                            # count against itself and never fired).
                            previous_count = self._part_count_stats["last_part_count"]
                            self._part_count_stats["total_responses"] += 1
                            self._part_count_stats["part_counts"][part_count] = (
                                self._part_count_stats["part_counts"].get(part_count, 0) + 1
                            )
                            self._part_count_stats["last_part_count"] = part_count

                            if previous_count is not None and part_count != previous_count:
                                self.logger.warning(
                                    f"⚠️ Part count variation detected: {part_count} parts (previous: {previous_count})"
                                )

                            if part_count > 1:
                                # Multi-part response. Minimal fix: only
                                # repair incomplete <thinking> tags and
                                # preserve the original part order; do NOT
                                # reorganize content — downstream code
                                # handles semantics.
                                processed_parts = []
                                fixed_count = 0

                                for i, part in enumerate(text_parts):
                                    needs_thinking_format = False

                                    if "<thinking>" in part and "</thinking>" not in part:
                                        # Incomplete <thinking> tag: add the
                                        # closing tag.
                                        part = part + "\n</thinking>"
                                        needs_thinking_format = True
                                        self.logger.debug(
                                            f"  Part {i+1}: Incomplete <thinking> tag fixed"
                                        )
                                    elif (
                                        part.startswith("thinking")
                                        and "</thinking>" not in part
                                    ):
                                        # Bare "thinking" prefix: wrap the
                                        # remainder in <thinking> tags.
                                        if part.startswith("thinking\n"):
                                            inner = part[8:]  # drop "thinking\n"
                                        else:
                                            inner = part[7:]  # drop "thinking"
                                        part = f"<thinking>\n{inner}\n</thinking>"
                                        needs_thinking_format = True
                                        self.logger.debug(
                                            f"  Part {i+1}: thinking\\n format converted to <thinking> tags"
                                        )

                                    if needs_thinking_format:
                                        fixed_count += 1

                                    processed_parts.append(part)

                                # Merge in original order.
                                content = "\n".join(processed_parts)

                                if fixed_count > 0:
                                    self.logger.info(
                                        f"✅ Multi-part response merged: {len(text_parts)} parts, {fixed_count} incomplete tags fixed, order preserved"
                                    )
                                else:
                                    self.logger.info(
                                        f"✅ Multi-part response merged: {len(text_parts)} parts, order preserved"
                                    )
                            else:
                                # Single part response — use as is.
                                content = text_parts[0]
                                self.logger.info("Successfully extracted single-part response")
                        else:
                            self.logger.warning("No text content found in multi-part response")
                    except Exception as part_error:
                        self.logger.error(
                            f"Failed to extract content from multi-part response: {str(part_error)}"
                        )

                # If still no content, explain via the finish reason.
                # NOTE(review): finish_reason is compared as a string here;
                # the SDK may return an enum — confirm against the installed
                # vertexai version.
                if not content:
                    if hasattr(candidate, "finish_reason"):
                        if candidate.finish_reason == "MAX_TOKENS":
                            content = "[Response truncated due to token limit - consider increasing max_tokens for Gemini 2.5 models]"
                            self.logger.warning(
                                "Response truncated due to MAX_TOKENS - Gemini 2.5 uses thinking tokens"
                            )
                        elif candidate.finish_reason in [
                            "SAFETY",
                            "RECITATION",
                        ]:
                            content = "[Response blocked by safety filters or content policy]"
                            self.logger.warning(
                                f"Response blocked by safety filters: {candidate.finish_reason}"
                            )
                        else:
                            content = f"[Response error: Cannot get response text - {candidate.finish_reason}]"
                    else:
                        content = "[Response error: Cannot get the response text]"
            else:
                content = f"[Response error: No candidates found - {str(ve)}]"

        # Final fallback.
        if not content:
            content = "[Response error: Cannot get the response text. Multiple content parts are not supported.]"

        # Vertex AI doesn't provide detailed token usage in the response;
        # estimate from the prompt and the returned content.
        input_tokens = self._count_tokens_estimate(prompt)
        output_tokens = self._count_tokens_estimate(content)
        tokens_used = input_tokens + output_tokens

        # Use config-based cost estimation.
        cost = self._estimate_cost_from_config(model_name, input_tokens, output_tokens)

        return LLMResponse(
            content=content,
            provider=self.provider_name,
            model=model_name,
            tokens_used=tokens_used,
            cost_estimate=cost,
        )

    except Exception as e:
        if "quota" in str(e).lower() or "limit" in str(e).lower():
            raise RateLimitError(f"Vertex AI quota exceeded: {str(e)}") from e
        # Handle specific Vertex AI response errors by returning a
        # placeholder response instead of failing the call.
        if any(
            keyword in str(e).lower()
            for keyword in [
                "cannot get the response text",
                "safety filters",
                "multiple content parts are not supported",
                "cannot get the candidate text",
            ]
        ):
            self.logger.warning(f"Vertex AI response issue: {str(e)}")
            return LLMResponse(
                content="[Response unavailable due to content processing issues or safety filters]",
                provider=self.provider_name,
                model=model_name,
                tokens_used=self._count_tokens_estimate(prompt),
                cost_estimate=0.0,
            )
        raise
|
|
350
|
+
|
|
351
|
+
async def stream_text(
    self,
    messages: List[LLMMessage],
    model: Optional[str] = None,
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
    **kwargs,
) -> AsyncGenerator[str, None]:
    """Stream text using Vertex AI (simulated streaming)"""
    # True server-side streaming is not wired up yet, so produce the
    # complete response first and then emit it one word at a time.
    full_response = await self.generate_text(
        messages, model, temperature, max_tokens, **kwargs
    )

    for word in full_response.content.split():
        yield word + " "
        # Brief pause so the output resembles an incremental stream.
        await asyncio.sleep(0.05)
|
|
369
|
+
|
|
370
|
+
def get_part_count_stats(self) -> Dict[str, Any]:
    """
    Get statistics about part count variations in Vertex AI responses.

    Returns:
        Dictionary containing part count statistics; when at least one
        response has been recorded it also includes a ``variation_analysis``
        section and tuning ``recommendations``.
    """
    stats = self._part_count_stats.copy()
    # dict.copy() is shallow: also copy the nested counter so callers
    # cannot mutate the client's internal statistics through the result.
    stats["part_counts"] = dict(stats.get("part_counts", {}))

    if stats["total_responses"] > 0:
        # Calculate variation metrics
        part_counts = list(stats["part_counts"].keys())
        stats["variation_analysis"] = {
            "unique_part_counts": len(part_counts),
            "most_common_count": (
                max(stats["part_counts"].items(), key=lambda x: x[1])[0]
                if stats["part_counts"]
                else None
            ),
            "part_count_range": (
                f"{min(part_counts)}-{max(part_counts)}" if part_counts else "N/A"
            ),
            # 0-1, higher is more stable
            "stability_score": 1.0 - (len(part_counts) - 1) / max(stats["total_responses"], 1),
        }

        # Generate recommendations
        if stats["variation_analysis"]["stability_score"] < 0.7:
            stats["recommendations"] = [
                "High part count variation detected",
                "Consider optimizing prompt structure",
                "Monitor input complexity patterns",
                "Review tool calling configuration",
            ]
        else:
            stats["recommendations"] = [
                "Part count variation is within acceptable range",
                "Continue monitoring for patterns",
            ]

    return stats
def log_part_count_summary(self):
    """Emit an informational summary of collected part-count statistics."""
    stats = self.get_part_count_stats()

    # Nothing to report until at least one response has been recorded.
    if stats["total_responses"] <= 0:
        return

    log = self.logger.info
    log("📈 Vertex AI Part Count Summary:")
    log(f" Total responses: {stats['total_responses']}")
    log(f" Part count distribution: {stats['part_counts']}")

    analysis = stats.get("variation_analysis")
    if analysis is not None:
        log(f" Stability score: {analysis['stability_score']:.2f}")
        log(f" Most common count: {analysis['most_common_count']}")
        log(f" Count range: {analysis['part_count_range']}")

    recommendations = stats.get("recommendations")
    if recommendations is not None:
        log(" Recommendations:")
        for rec in recommendations:
            log(f" • {rec}")
async def close(self):
    """Clean up resources"""
    # Emit final part-count statistics before shutting down.
    self.log_part_count_summary()
    # Vertex AI needs no explicit teardown; just flag the client
    # as no longer initialized.
    self._initialized = False
@@ -0,0 +1,184 @@
|
|
|
1
|
+
from openai import AsyncOpenAI
|
|
2
|
+
from aiecs.config.config import get_settings
|
|
3
|
+
from aiecs.llm.clients.base_client import (
|
|
4
|
+
BaseLLMClient,
|
|
5
|
+
LLMMessage,
|
|
6
|
+
LLMResponse,
|
|
7
|
+
ProviderNotAvailableError,
|
|
8
|
+
RateLimitError,
|
|
9
|
+
)
|
|
10
|
+
from tenacity import (
|
|
11
|
+
retry,
|
|
12
|
+
stop_after_attempt,
|
|
13
|
+
wait_exponential,
|
|
14
|
+
retry_if_exception_type,
|
|
15
|
+
)
|
|
16
|
+
import logging
|
|
17
|
+
from typing import Dict, Optional, List, AsyncGenerator
|
|
18
|
+
|
|
19
|
+
# Lazy import to avoid circular dependency
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _get_config_loader():
    """Lazy import of config loader to avoid circular dependency"""
    # Import inside the function body so that importing this module does
    # not pull in aiecs.llm.config at module-load time.
    from aiecs.llm.config import get_llm_config_loader

    loader = get_llm_config_loader()
    return loader
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class XAIClient(BaseLLMClient):
    """xAI (Grok) provider client.

    Uses the OpenAI-compatible xAI endpoint (https://api.x.ai/v1) through
    the AsyncOpenAI client for both one-shot generation and streaming.
    """

    def __init__(self):
        super().__init__("xAI")
        self.settings = get_settings()
        # Lazily created network client and model-name mapping cache.
        self._openai_client: Optional[AsyncOpenAI] = None
        self._model_map: Optional[Dict[str, str]] = None

    def _get_openai_client(self) -> AsyncOpenAI:
        """Lazy initialization of OpenAI client for XAI"""
        if not self._openai_client:
            api_key = self._get_api_key()
            self._openai_client = AsyncOpenAI(
                api_key=api_key,
                base_url="https://api.x.ai/v1",
                timeout=360.0,  # Override default timeout with longer timeout for reasoning models
            )
        return self._openai_client

    def _get_api_key(self) -> str:
        """Get API key with backward compatibility.

        Returns:
            The configured xAI API key.

        Raises:
            ProviderNotAvailableError: If neither key is configured.
        """
        # Support both xai_api_key and grok_api_key for backward compatibility
        api_key = getattr(self.settings, "xai_api_key", None) or getattr(
            self.settings, "grok_api_key", None
        )
        if not api_key:
            raise ProviderNotAvailableError("xAI API key not configured")
        return api_key

    def _get_model_map(self) -> Dict[str, str]:
        """Get model mappings from configuration (cached after first load)."""
        if self._model_map is None:
            try:
                loader = _get_config_loader()
                provider_config = loader.get_provider_config("xAI")
                if provider_config and provider_config.model_mappings:
                    self._model_map = provider_config.model_mappings
                else:
                    self._model_map = {}
            except Exception as e:
                # Best-effort: fall back to an identity mapping on config errors.
                self.logger.warning(f"Failed to load model mappings from config: {e}")
                self._model_map = {}
        return self._model_map

    def _prepare_request(self, messages: List[LLMMessage], model: Optional[str]):
        """Resolve logical/API model names and convert messages to OpenAI wire format."""
        selected_model = model or self._get_default_model() or "grok-4"
        api_model = self._get_model_map().get(selected_model, selected_model)
        openai_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
        return selected_model, api_model, openai_messages

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        # NOTE(review): retrying on bare Exception also retries
        # non-transient errors (including ProviderNotAvailableError) —
        # kept for backward compatibility, consider narrowing.
        retry=retry_if_exception_type((Exception, RateLimitError)),
    )
    async def generate_text(
        self,
        messages: List[LLMMessage],
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> LLMResponse:
        """Generate text using xAI API via OpenAI library (supports all Grok models)"""
        # _get_api_key raises ProviderNotAvailableError when unconfigured,
        # so no separate availability check is needed.
        self._get_api_key()
        client = self._get_openai_client()
        selected_model, api_model, openai_messages = self._prepare_request(messages, model)

        try:
            completion = await client.chat.completions.create(
                model=api_model,
                messages=openai_messages,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs,
            )

            # message.content can be None (e.g. tool-call-only responses);
            # normalize to an empty string for LLMResponse.
            content = completion.choices[0].message.content or ""
            tokens_used = completion.usage.total_tokens if completion.usage else None

            return LLMResponse(
                content=content,
                provider=self.provider_name,
                model=selected_model,
                tokens_used=tokens_used,
                cost_estimate=0.0,  # xAI pricing not available yet
            )

        except Exception as e:
            if "rate limit" in str(e).lower() or "429" in str(e):
                raise RateLimitError(f"xAI rate limit exceeded: {str(e)}")
            logger.error(f"xAI API error: {str(e)}")
            raise

    async def stream_text(
        self,
        messages: List[LLMMessage],
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Stream text using xAI API via OpenAI library (supports all Grok models)"""
        # _get_api_key raises ProviderNotAvailableError when unconfigured.
        self._get_api_key()
        client = self._get_openai_client()
        _, api_model, openai_messages = self._prepare_request(messages, model)

        try:
            stream = await client.chat.completions.create(
                model=api_model,
                messages=openai_messages,
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True,
                **kwargs,
            )

            async for chunk in stream:
                # Guard against chunks with no choices (e.g. a trailing
                # usage chunk) before indexing choices[0].
                if chunk.choices and chunk.choices[0].delta.content is not None:
                    yield chunk.choices[0].delta.content

        except Exception as e:
            if "rate limit" in str(e).lower() or "429" in str(e):
                raise RateLimitError(f"xAI rate limit exceeded: {str(e)}")
            logger.error(f"xAI API streaming error: {str(e)}")
            raise

    async def close(self):
        """Clean up resources"""
        if self._openai_client:
            await self._openai_client.close()
            self._openai_client = None
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LLM Configuration management.
|
|
3
|
+
|
|
4
|
+
This package handles configuration loading, validation, and management
|
|
5
|
+
for all LLM providers and models.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from .model_config import (
|
|
9
|
+
ModelCostConfig,
|
|
10
|
+
ModelCapabilities,
|
|
11
|
+
ModelDefaultParams,
|
|
12
|
+
ModelConfig,
|
|
13
|
+
ProviderConfig,
|
|
14
|
+
LLMModelsConfig,
|
|
15
|
+
)
|
|
16
|
+
from .config_loader import (
|
|
17
|
+
LLMConfigLoader,
|
|
18
|
+
get_llm_config_loader,
|
|
19
|
+
get_llm_config,
|
|
20
|
+
reload_llm_config,
|
|
21
|
+
)
|
|
22
|
+
from .config_validator import (
|
|
23
|
+
ConfigValidationError,
|
|
24
|
+
validate_cost_config,
|
|
25
|
+
validate_model_config,
|
|
26
|
+
validate_provider_config,
|
|
27
|
+
validate_llm_config,
|
|
28
|
+
validate_config_file,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
__all__ = [
|
|
32
|
+
# Configuration models
|
|
33
|
+
"ModelCostConfig",
|
|
34
|
+
"ModelCapabilities",
|
|
35
|
+
"ModelDefaultParams",
|
|
36
|
+
"ModelConfig",
|
|
37
|
+
"ProviderConfig",
|
|
38
|
+
"LLMModelsConfig",
|
|
39
|
+
# Config loader
|
|
40
|
+
"LLMConfigLoader",
|
|
41
|
+
"get_llm_config_loader",
|
|
42
|
+
"get_llm_config",
|
|
43
|
+
"reload_llm_config",
|
|
44
|
+
# Validation
|
|
45
|
+
"ConfigValidationError",
|
|
46
|
+
"validate_cost_config",
|
|
47
|
+
"validate_model_config",
|
|
48
|
+
"validate_provider_config",
|
|
49
|
+
"validate_llm_config",
|
|
50
|
+
"validate_config_file",
|
|
51
|
+
]
|