aiecs-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiecs might be problematic.
- aiecs/__init__.py +75 -0
- aiecs/__main__.py +41 -0
- aiecs/aiecs_client.py +295 -0
- aiecs/application/__init__.py +10 -0
- aiecs/application/executors/__init__.py +10 -0
- aiecs/application/executors/operation_executor.py +341 -0
- aiecs/config/__init__.py +15 -0
- aiecs/config/config.py +117 -0
- aiecs/config/registry.py +19 -0
- aiecs/core/__init__.py +46 -0
- aiecs/core/interface/__init__.py +34 -0
- aiecs/core/interface/execution_interface.py +150 -0
- aiecs/core/interface/storage_interface.py +214 -0
- aiecs/domain/__init__.py +20 -0
- aiecs/domain/context/__init__.py +28 -0
- aiecs/domain/context/content_engine.py +982 -0
- aiecs/domain/context/conversation_models.py +306 -0
- aiecs/domain/execution/__init__.py +12 -0
- aiecs/domain/execution/model.py +49 -0
- aiecs/domain/task/__init__.py +13 -0
- aiecs/domain/task/dsl_processor.py +460 -0
- aiecs/domain/task/model.py +50 -0
- aiecs/domain/task/task_context.py +257 -0
- aiecs/infrastructure/__init__.py +26 -0
- aiecs/infrastructure/messaging/__init__.py +13 -0
- aiecs/infrastructure/messaging/celery_task_manager.py +341 -0
- aiecs/infrastructure/messaging/websocket_manager.py +289 -0
- aiecs/infrastructure/monitoring/__init__.py +12 -0
- aiecs/infrastructure/monitoring/executor_metrics.py +138 -0
- aiecs/infrastructure/monitoring/structured_logger.py +50 -0
- aiecs/infrastructure/monitoring/tracing_manager.py +376 -0
- aiecs/infrastructure/persistence/__init__.py +12 -0
- aiecs/infrastructure/persistence/database_manager.py +286 -0
- aiecs/infrastructure/persistence/file_storage.py +671 -0
- aiecs/infrastructure/persistence/redis_client.py +162 -0
- aiecs/llm/__init__.py +54 -0
- aiecs/llm/base_client.py +99 -0
- aiecs/llm/client_factory.py +339 -0
- aiecs/llm/custom_callbacks.py +228 -0
- aiecs/llm/openai_client.py +125 -0
- aiecs/llm/vertex_client.py +186 -0
- aiecs/llm/xai_client.py +184 -0
- aiecs/main.py +351 -0
- aiecs/scripts/DEPENDENCY_SYSTEM_SUMMARY.md +241 -0
- aiecs/scripts/README_DEPENDENCY_CHECKER.md +309 -0
- aiecs/scripts/README_WEASEL_PATCH.md +126 -0
- aiecs/scripts/__init__.py +3 -0
- aiecs/scripts/dependency_checker.py +825 -0
- aiecs/scripts/dependency_fixer.py +348 -0
- aiecs/scripts/download_nlp_data.py +348 -0
- aiecs/scripts/fix_weasel_validator.py +121 -0
- aiecs/scripts/fix_weasel_validator.sh +82 -0
- aiecs/scripts/patch_weasel_library.sh +188 -0
- aiecs/scripts/quick_dependency_check.py +269 -0
- aiecs/scripts/run_weasel_patch.sh +41 -0
- aiecs/scripts/setup_nlp_data.sh +217 -0
- aiecs/tasks/__init__.py +2 -0
- aiecs/tasks/worker.py +111 -0
- aiecs/tools/__init__.py +196 -0
- aiecs/tools/base_tool.py +202 -0
- aiecs/tools/langchain_adapter.py +361 -0
- aiecs/tools/task_tools/__init__.py +82 -0
- aiecs/tools/task_tools/chart_tool.py +704 -0
- aiecs/tools/task_tools/classfire_tool.py +901 -0
- aiecs/tools/task_tools/image_tool.py +397 -0
- aiecs/tools/task_tools/office_tool.py +600 -0
- aiecs/tools/task_tools/pandas_tool.py +565 -0
- aiecs/tools/task_tools/report_tool.py +499 -0
- aiecs/tools/task_tools/research_tool.py +363 -0
- aiecs/tools/task_tools/scraper_tool.py +548 -0
- aiecs/tools/task_tools/search_api.py +7 -0
- aiecs/tools/task_tools/stats_tool.py +513 -0
- aiecs/tools/temp_file_manager.py +126 -0
- aiecs/tools/tool_executor/__init__.py +35 -0
- aiecs/tools/tool_executor/tool_executor.py +518 -0
- aiecs/utils/LLM_output_structor.py +409 -0
- aiecs/utils/__init__.py +23 -0
- aiecs/utils/base_callback.py +50 -0
- aiecs/utils/execution_utils.py +158 -0
- aiecs/utils/logging.py +1 -0
- aiecs/utils/prompt_loader.py +13 -0
- aiecs/utils/token_usage_repository.py +279 -0
- aiecs/ws/__init__.py +0 -0
- aiecs/ws/socket_server.py +41 -0
- aiecs-1.0.0.dist-info/METADATA +610 -0
- aiecs-1.0.0.dist-info/RECORD +90 -0
- aiecs-1.0.0.dist-info/WHEEL +5 -0
- aiecs-1.0.0.dist-info/entry_points.txt +7 -0
- aiecs-1.0.0.dist-info/licenses/LICENSE +225 -0
- aiecs-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,162 @@
import redis.asyncio as redis
import logging
from typing import Optional
import os

logger = logging.getLogger(__name__)

class RedisClient:
    """Redis client singleton for sharing across different caching strategies"""

    def __init__(self):
        self._client: Optional[redis.Redis] = None
        self._connection_pool: Optional[redis.ConnectionPool] = None

    async def initialize(self):
        """Initialize Redis client"""
        try:
            # Get Redis configuration from environment variables
            redis_host = os.getenv('REDIS_HOST', 'localhost')
            redis_port = int(os.getenv('REDIS_PORT', 6379))
            redis_db = int(os.getenv('REDIS_DB', 0))
            redis_password = os.getenv('REDIS_PASSWORD')

            # Create connection pool
            self._connection_pool = redis.ConnectionPool(
                host=redis_host,
                port=redis_port,
                db=redis_db,
                password=redis_password,
                decode_responses=True,
                max_connections=20,
                retry_on_timeout=True
            )

            # Create Redis client
            self._client = redis.Redis(connection_pool=self._connection_pool)

            # Test connection
            await self._client.ping()
            logger.info(f"Redis client initialized successfully: {redis_host}:{redis_port}/{redis_db}")

        except Exception as e:
            logger.error(f"Failed to initialize Redis client: {e}")
            raise

    async def get_client(self) -> redis.Redis:
        """Get Redis client instance"""
        if self._client is None:
            raise RuntimeError("Redis client not initialized. Call initialize() first.")
        return self._client

    async def close(self):
        """Close Redis connection"""
        if self._client:
            await self._client.close()
            self._client = None
        if self._connection_pool:
            await self._connection_pool.disconnect()
            self._connection_pool = None
        logger.info("Redis client closed")

    async def hincrby(self, name: str, key: str, amount: int = 1) -> int:
        """Atomically increment a hash field"""
        client = await self.get_client()
        return await client.hincrby(name, key, amount)

    async def hget(self, name: str, key: str) -> Optional[str]:
        """Get a hash field value"""
        client = await self.get_client()
        return await client.hget(name, key)

    async def hgetall(self, name: str) -> dict:
        """Get all hash fields"""
        client = await self.get_client()
        return await client.hgetall(name)

    async def hset(self, name: str, mapping: dict) -> int:
        """Set hash fields"""
        client = await self.get_client()
        return await client.hset(name, mapping=mapping)

    async def expire(self, name: str, time: int) -> bool:
        """Set expiration time"""
        client = await self.get_client()
        return await client.expire(name, time)

    async def exists(self, name: str) -> bool:
        """Check if key exists"""
        client = await self.get_client()
        return bool(await client.exists(name))

    async def ping(self) -> bool:
        """Test Redis connection"""
        try:
            client = await self.get_client()
            result = await client.ping()
            return result
        except Exception as e:
            logger.error(f"Redis ping failed: {e}")
            return False

    async def info(self, section: str = None) -> dict:
        """Get Redis server information"""
        try:
            client = await self.get_client()
            return await client.info(section)
        except Exception as e:
            logger.error(f"Redis info failed: {e}")
            return {}

    async def delete(self, *keys) -> int:
        """Delete one or more keys"""
        try:
            client = await self.get_client()
            return await client.delete(*keys)
        except Exception as e:
            logger.error(f"Redis delete failed: {e}")
            return 0

    async def set(self, key: str, value: str, ex: int = None) -> bool:
        """Set a key-value pair with optional expiration"""
        try:
            client = await self.get_client()
            return await client.set(key, value, ex=ex)
        except Exception as e:
            logger.error(f"Redis set failed for key {key}: {e}")
            return False

    async def get(self, key: str) -> Optional[str]:
        """Get value by key"""
        try:
            client = await self.get_client()
            return await client.get(key)
        except Exception as e:
            logger.error(f"Redis get failed for key {key}: {e}")
            return None

# ✅ Key changes:
# 1. No longer create an instance immediately.
# 2. Define a global variable with an initial value of None; it is populated by the application lifespan.
redis_client: Optional[RedisClient] = None

# 3. Provide an initialization function for the lifespan to call
async def initialize_redis_client():
    """Create and initialize the global Redis client instance at application startup."""
    global redis_client
    if redis_client is None:
        redis_client = RedisClient()
        await redis_client.initialize()

# 4. Provide a close function for the lifespan to call
async def close_redis_client():
    """Close the global Redis client instance at application shutdown."""
    if redis_client:
        await redis_client.close()

# For backward compatibility, keep the get_redis_client function
async def get_redis_client() -> RedisClient:
    """Get the global Redis client instance"""
    if redis_client is None:
        raise RuntimeError("Redis client not initialized. Call initialize_redis_client() first.")
    return redis_client
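The "Key changes" comments above assume a lifespan hook that creates and tears down the global client. A minimal sketch of how that wiring could look with a FastAPI lifespan is shown below; the health route and the exact integration in aiecs/main.py are assumptions for illustration, not part of this diff.

# Hypothetical wiring, for illustration only; aiecs/main.py is not shown in
# this section, so this is a sketch of the intended lifespan pattern.
from contextlib import asynccontextmanager

from fastapi import FastAPI

from aiecs.infrastructure.persistence.redis_client import (
    initialize_redis_client,
    close_redis_client,
    get_redis_client,
)

@asynccontextmanager
async def lifespan(app: FastAPI):
    await initialize_redis_client()   # populates the module-level redis_client
    yield
    await close_redis_client()        # releases the connection pool on shutdown

app = FastAPI(lifespan=lifespan)

@app.get("/health/redis")
async def redis_health():
    client = await get_redis_client()
    return {"ok": await client.ping()}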
aiecs/llm/__init__.py
ADDED
@@ -0,0 +1,54 @@
"""
LLM Package - Modular AI Provider Architecture

This package provides a unified interface to multiple AI providers through
individual client implementations and a factory pattern.
"""

# Import all main components
from .base_client import (
    BaseLLMClient,
    LLMMessage,
    LLMResponse,
    LLMClientError,
    ProviderNotAvailableError,
    RateLimitError
)

from .client_factory import (
    AIProvider,
    LLMClientFactory,
    LLMClientManager,
    get_llm_manager,
    generate_text,
    stream_text
)

from .openai_client import OpenAIClient
from .vertex_client import VertexAIClient
from .xai_client import XAIClient

__all__ = [
    # Base classes and types
    'BaseLLMClient',
    'LLMMessage',
    'LLMResponse',
    'LLMClientError',
    'ProviderNotAvailableError',
    'RateLimitError',
    'AIProvider',

    # Factory and manager
    'LLMClientFactory',
    'LLMClientManager',
    'get_llm_manager',

    # Individual clients
    'OpenAIClient',
    'VertexAIClient',
    'XAIClient',

    # Convenience functions
    'generate_text',
    'stream_text',
]
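Because the subpackage re-exports both the clients and the convenience functions, callers can work against aiecs.llm directly. A minimal usage sketch follows; the prompt and the presence of a configured OpenAI credential are assumptions, not part of this diff.

# Usage sketch based on the exports above.
import asyncio

from aiecs.llm import LLMMessage, generate_text

async def main():
    response = await generate_text(
        [LLMMessage(role="user", content="Summarize what a connection pool is.")],
        provider="OpenAI",  # must match an AIProvider enum value
    )
    print(response.content)
    print(response.tokens_used, response.cost_estimate)

asyncio.run(main())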
aiecs/llm/base_client.py
ADDED
@@ -0,0 +1,99 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List, AsyncGenerator
from dataclasses import dataclass
import time
import logging

logger = logging.getLogger(__name__)

@dataclass
class LLMMessage:
    role: str  # "system", "user", "assistant"
    content: str

@dataclass
class LLMResponse:
    content: str
    provider: str
    model: str
    tokens_used: Optional[int] = None
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None
    cost_estimate: Optional[float] = None
    response_time: Optional[float] = None

    def __post_init__(self):
        """Ensure consistency of token data"""
        # If detailed token counts are present but no total, compute the total
        if self.prompt_tokens is not None and self.completion_tokens is not None and self.tokens_used is None:
            self.tokens_used = self.prompt_tokens + self.completion_tokens

        # If only the total is available, the split cannot be recovered accurately; keep as is
        elif self.tokens_used is not None and self.prompt_tokens is None and self.completion_tokens is None:
            pass

class LLMClientError(Exception):
    """Base exception for LLM client errors"""
    pass

class ProviderNotAvailableError(LLMClientError):
    """Raised when a provider is not available or misconfigured"""
    pass

class RateLimitError(LLMClientError):
    """Raised when a rate limit is exceeded"""
    pass

class BaseLLMClient(ABC):
    """Abstract base class for all LLM provider clients"""

    def __init__(self, provider_name: str):
        self.provider_name = provider_name
        self.logger = logging.getLogger(f"{__name__}.{provider_name}")

    @abstractmethod
    async def generate_text(
        self,
        messages: List[LLMMessage],
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Generate text using the provider's API"""
        pass

    @abstractmethod
    async def stream_text(
        self,
        messages: List[LLMMessage],
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Stream text generation using the provider's API"""
        pass

    @abstractmethod
    async def close(self):
        """Clean up resources"""
        pass

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    def _count_tokens_estimate(self, text: str) -> int:
        """Rough token count estimation (4 chars ≈ 1 token for English)"""
        return len(text) // 4

    def _estimate_cost(self, model: str, input_tokens: int, output_tokens: int, token_costs: Dict) -> float:
        """Estimate the cost of the API call"""
        if model in token_costs:
            costs = token_costs[model]
            return (input_tokens * costs["input"] + output_tokens * costs["output"]) / 1000
        return 0.0
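To show the contract BaseLLMClient imposes on the concrete clients (openai_client.py, vertex_client.py, xai_client.py, not reproduced in this section), a toy subclass satisfying the three abstract methods might look like the sketch below; the "echo" provider is purely hypothetical and is not part of the package.

# Toy provider used only to illustrate the abstract interface above; it is
# not shipped in this package and is not registered with the client factory.
import time
from typing import AsyncGenerator, List, Optional

from aiecs.llm.base_client import BaseLLMClient, LLMMessage, LLMResponse

class EchoClient(BaseLLMClient):
    def __init__(self):
        super().__init__(provider_name="echo")

    async def generate_text(self, messages: List[LLMMessage], model: Optional[str] = None,
                            temperature: float = 0.7, max_tokens: Optional[int] = None,
                            **kwargs) -> LLMResponse:
        start = time.time()
        text = messages[-1].content if messages else ""
        # __post_init__ fills tokens_used from the prompt/completion counts
        return LLMResponse(
            content=text,
            provider=self.provider_name,
            model=model or "echo-1",
            prompt_tokens=self._count_tokens_estimate(text),
            completion_tokens=self._count_tokens_estimate(text),
            response_time=time.time() - start,
        )

    async def stream_text(self, messages: List[LLMMessage], model: Optional[str] = None,
                          temperature: float = 0.7, max_tokens: Optional[int] = None,
                          **kwargs) -> AsyncGenerator[str, None]:
        # Echo the prompt back word by word to mimic incremental streaming
        for word in (messages[-1].content.split() if messages else []):
            yield word + " "

    async def close(self):
        pass  # no network resources to release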
aiecs/llm/client_factory.py
ADDED
@@ -0,0 +1,339 @@
import logging
from typing import Dict, Any, Optional, Union, List
from enum import Enum

from .base_client import BaseLLMClient, LLMMessage, LLMResponse
from .openai_client import OpenAIClient
from .vertex_client import VertexAIClient
from .xai_client import XAIClient
from ..utils.base_callback import CustomAsyncCallbackHandler

logger = logging.getLogger(__name__)

class AIProvider(str, Enum):
    OPENAI = "OpenAI"
    VERTEX = "Vertex"
    XAI = "xAI"

class LLMClientFactory:
    """Factory for creating and managing LLM provider clients"""

    _clients: Dict[AIProvider, BaseLLMClient] = {}

    @classmethod
    def get_client(cls, provider: Union[str, AIProvider]) -> BaseLLMClient:
        """Get or create a client for the specified provider"""
        if isinstance(provider, str):
            try:
                provider = AIProvider(provider)
            except ValueError:
                raise ValueError(f"Unsupported provider: {provider}")

        if provider not in cls._clients:
            cls._clients[provider] = cls._create_client(provider)

        return cls._clients[provider]

    @classmethod
    def _create_client(cls, provider: AIProvider) -> BaseLLMClient:
        """Create a new client instance for the provider"""
        if provider == AIProvider.OPENAI:
            return OpenAIClient()
        elif provider == AIProvider.VERTEX:
            return VertexAIClient()
        elif provider == AIProvider.XAI:
            return XAIClient()
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    @classmethod
    async def close_all(cls):
        """Close all active clients"""
        for client in cls._clients.values():
            try:
                await client.close()
            except Exception as e:
                logger.error(f"Error closing client {client.provider_name}: {e}")
        cls._clients.clear()

    @classmethod
    async def close_client(cls, provider: Union[str, AIProvider]):
        """Close a specific client"""
        if isinstance(provider, str):
            provider = AIProvider(provider)

        if provider in cls._clients:
            try:
                await cls._clients[provider].close()
                del cls._clients[provider]
            except Exception as e:
                logger.error(f"Error closing client {provider}: {e}")

class LLMClientManager:
    """High-level manager for LLM operations with context-aware provider selection"""

    def __init__(self):
        self.factory = LLMClientFactory()

    def _extract_ai_preference(self, context: Optional[Dict[str, Any]]) -> tuple[Optional[str], Optional[str]]:
        """Extract AI provider and model from context"""
        if not context:
            return None, None

        metadata = context.get('metadata', {})

        # First, check for aiPreference in metadata
        ai_preference = metadata.get('aiPreference', {})
        if isinstance(ai_preference, dict):
            provider = ai_preference.get('provider')
            model = ai_preference.get('model')
            if provider is not None:
                return provider, model

        # Fall back to direct provider/model keys in metadata
        provider = metadata.get('provider')
        model = metadata.get('model')
        return provider, model

    async def generate_text(
        self,
        messages: Union[str, list[LLMMessage]],
        provider: Optional[Union[str, AIProvider]] = None,
        model: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        callbacks: Optional[List[CustomAsyncCallbackHandler]] = None,
        **kwargs
    ) -> LLMResponse:
        """
        Generate text using context-aware provider selection

        Args:
            messages: Either a string prompt or a list of LLMMessage objects
            provider: AI provider to use (can be overridden by context)
            model: Specific model to use (can be overridden by context)
            context: TaskContext or dict containing aiPreference
            temperature: Sampling temperature (0.0 to 2.0)
            max_tokens: Maximum tokens to generate
            callbacks: List of callback handlers to execute during LLM calls
            **kwargs: Additional provider-specific parameters

        Returns:
            LLMResponse object with generated text and metadata
        """
        # Extract provider/model from context if available
        context_provider, context_model = self._extract_ai_preference(context)

        # Use context preferences if available, otherwise use provided values
        final_provider = context_provider or provider or AIProvider.OPENAI
        final_model = context_model or model

        # Convert string prompt to messages format
        if isinstance(messages, str):
            messages = [LLMMessage(role="user", content=messages)]

        # Execute on_llm_start callbacks
        if callbacks:
            # Convert LLMMessage objects to dictionaries for callbacks
            messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages]
            for callback in callbacks:
                try:
                    await callback.on_llm_start(messages_dict, provider=final_provider, model=final_model, **kwargs)
                except Exception as e:
                    logger.error(f"Error in callback on_llm_start: {e}")

        try:
            # Get the appropriate client
            client = self.factory.get_client(final_provider)

            # Generate text
            response = await client.generate_text(
                messages=messages,
                model=final_model,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )

            # Execute on_llm_end callbacks
            if callbacks:
                # Convert the LLMResponse object to a dictionary for callbacks
                response_dict = {
                    "content": response.content,
                    "provider": response.provider,
                    "model": response.model,
                    "tokens_used": response.tokens_used,
                    "prompt_tokens": response.prompt_tokens,
                    "completion_tokens": response.completion_tokens,
                    "cost_estimate": response.cost_estimate,
                    "response_time": response.response_time
                }
                for callback in callbacks:
                    try:
                        await callback.on_llm_end(response_dict, provider=final_provider, model=final_model, **kwargs)
                    except Exception as e:
                        logger.error(f"Error in callback on_llm_end: {e}")

            logger.info(f"Generated text using {final_provider}/{response.model}")
            return response

        except Exception as e:
            # Execute on_llm_error callbacks
            if callbacks:
                for callback in callbacks:
                    try:
                        await callback.on_llm_error(e, provider=final_provider, model=final_model, **kwargs)
                    except Exception as callback_error:
                        logger.error(f"Error in callback on_llm_error: {callback_error}")

            # Re-raise the original exception
            raise

    async def stream_text(
        self,
        messages: Union[str, list[LLMMessage]],
        provider: Optional[Union[str, AIProvider]] = None,
        model: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        callbacks: Optional[List[CustomAsyncCallbackHandler]] = None,
        **kwargs
    ):
        """
        Stream text generation using context-aware provider selection

        Args:
            messages: Either a string prompt or a list of LLMMessage objects
            provider: AI provider to use (can be overridden by context)
            model: Specific model to use (can be overridden by context)
            context: TaskContext or dict containing aiPreference
            temperature: Sampling temperature (0.0 to 2.0)
            max_tokens: Maximum tokens to generate
            callbacks: List of callback handlers to execute during LLM calls
            **kwargs: Additional provider-specific parameters

        Yields:
            str: Incremental text chunks
        """
        # Extract provider/model from context if available
        context_provider, context_model = self._extract_ai_preference(context)

        # Use context preferences if available, otherwise use provided values
        final_provider = context_provider or provider or AIProvider.OPENAI
        final_model = context_model or model

        # Convert string prompt to messages format
        if isinstance(messages, str):
            messages = [LLMMessage(role="user", content=messages)]

        # Execute on_llm_start callbacks
        if callbacks:
            # Convert LLMMessage objects to dictionaries for callbacks
            messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages]
            for callback in callbacks:
                try:
                    await callback.on_llm_start(messages_dict, provider=final_provider, model=final_model, **kwargs)
                except Exception as e:
                    logger.error(f"Error in callback on_llm_start: {e}")

        try:
            # Get the appropriate client
            client = self.factory.get_client(final_provider)

            # Collect streamed content for token counting
            collected_content = ""

            # Stream text
            async for chunk in await client.stream_text(
                messages=messages,
                model=final_model,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            ):
                collected_content += chunk
                yield chunk

            # Build a response object for callbacks; streaming does not return an
            # LLMResponse directly, so token usage is estimated for streamed output
            estimated_tokens = len(collected_content) // 4  # Rough estimation
            stream_response = LLMResponse(
                content=collected_content,
                provider=str(final_provider),
                model=final_model or "unknown",
                tokens_used=estimated_tokens
            )

            # Execute on_llm_end callbacks
            if callbacks:
                # Convert the LLMResponse object to a dictionary for callbacks
                response_dict = {
                    "content": stream_response.content,
                    "provider": stream_response.provider,
                    "model": stream_response.model,
                    "tokens_used": stream_response.tokens_used,
                    "prompt_tokens": stream_response.prompt_tokens,
                    "completion_tokens": stream_response.completion_tokens,
                    "cost_estimate": stream_response.cost_estimate,
                    "response_time": stream_response.response_time
                }
                for callback in callbacks:
                    try:
                        await callback.on_llm_end(response_dict, provider=final_provider, model=final_model, **kwargs)
                    except Exception as e:
                        logger.error(f"Error in callback on_llm_end: {e}")

        except Exception as e:
            # Execute on_llm_error callbacks
            if callbacks:
                for callback in callbacks:
                    try:
                        await callback.on_llm_error(e, provider=final_provider, model=final_model, **kwargs)
                    except Exception as callback_error:
                        logger.error(f"Error in callback on_llm_error: {callback_error}")

            # Re-raise the original exception
            raise

    async def close(self):
        """Close all clients"""
        await self.factory.close_all()

# Global instance for easy access
_llm_manager = LLMClientManager()

async def get_llm_manager() -> LLMClientManager:
    """Get the global LLM manager instance"""
    return _llm_manager

# Convenience functions for backward compatibility
async def generate_text(
    messages: Union[str, list[LLMMessage]],
    provider: Optional[Union[str, AIProvider]] = None,
    model: Optional[str] = None,
    context: Optional[Dict[str, Any]] = None,
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
    callbacks: Optional[List[CustomAsyncCallbackHandler]] = None,
    **kwargs
) -> LLMResponse:
    """Generate text using the global LLM manager"""
    manager = await get_llm_manager()
    return await manager.generate_text(messages, provider, model, context, temperature, max_tokens, callbacks, **kwargs)

async def stream_text(
    messages: Union[str, list[LLMMessage]],
    provider: Optional[Union[str, AIProvider]] = None,
    model: Optional[str] = None,
    context: Optional[Dict[str, Any]] = None,
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
    callbacks: Optional[List[CustomAsyncCallbackHandler]] = None,
    **kwargs
):
    """Stream text using the global LLM manager"""
    manager = await get_llm_manager()
    async for chunk in manager.stream_text(messages, provider, model, context, temperature, max_tokens, callbacks, **kwargs):
        yield chunk
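The manager resolves the provider in three steps: context metadata first, then the explicit argument, then the AIProvider.OPENAI default. Below is a short sketch of that context-driven override path; the model name and the availability of Vertex credentials are assumptions for illustration, not part of this diff.

# Sketch of the aiPreference override handled by _extract_ai_preference;
# the metadata values and provider credentials here are placeholders.
import asyncio

from aiecs.llm import get_llm_manager

async def main():
    manager = await get_llm_manager()
    context = {"metadata": {"aiPreference": {"provider": "Vertex", "model": "gemini-1.5-pro"}}}

    # provider= is omitted, so the context forces Vertex with the given model;
    # with no context at all the manager falls back to AIProvider.OPENAI.
    async for chunk in manager.stream_text(
        "Explain connection pooling in one sentence.",
        context=context,
    ):
        print(chunk, end="", flush=True)

asyncio.run(main())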