isa-model 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +5 -0
- isa_model/core/model_manager.py +143 -0
- isa_model/core/model_registry.py +115 -0
- isa_model/core/model_router.py +226 -0
- isa_model/core/model_storage.py +133 -0
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +202 -0
- isa_model/core/storage/hf_storage.py +0 -0
- isa_model/core/storage/local_storage.py +0 -0
- isa_model/core/storage/minio_storage.py +0 -0
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
- isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
- isa_model/inference/__init__.py +11 -0
- isa_model/inference/adapter/unified_api.py +248 -0
- isa_model/inference/ai_factory.py +359 -0
- isa_model/inference/base.py +46 -0
- isa_model/inference/providers/__init__.py +19 -0
- isa_model/inference/providers/base_provider.py +30 -0
- isa_model/inference/providers/model_cache_manager.py +341 -0
- isa_model/inference/providers/ollama_provider.py +73 -0
- isa_model/inference/providers/openai_provider.py +101 -0
- isa_model/inference/providers/replicate_provider.py +107 -0
- isa_model/inference/providers/triton_provider.py +439 -0
- isa_model/inference/services/__init__.py +14 -0
- isa_model/inference/services/audio/base_stt_service.py +91 -0
- isa_model/inference/services/audio/base_tts_service.py +136 -0
- isa_model/inference/services/audio/openai_tts_service.py +71 -0
- isa_model/inference/services/base_service.py +106 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +97 -0
- isa_model/inference/services/embedding/openai_embed_service.py +0 -0
- isa_model/inference/services/llm/__init__.py +12 -0
- isa_model/inference/services/llm/base_llm_service.py +134 -0
- isa_model/inference/services/llm/ollama_llm_service.py +99 -0
- isa_model/inference/services/llm/openai_llm_service.py +138 -0
- isa_model/inference/services/others/table_transformer_service.py +61 -0
- isa_model/inference/services/vision/__init__.py +12 -0
- isa_model/inference/services/vision/helpers/image_utils.py +58 -0
- isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
- isa_model/inference/services/vision/ollama_vision_service.py +60 -0
- isa_model/inference/services/vision/openai_vision_service.py +80 -0
- isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
- isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
- isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- isa_model/inference/utils/conversion/torch_converter.py +0 -0
- isa_model/scripts/inference_tracker.py +283 -0
- isa_model/scripts/mlflow_manager.py +379 -0
- isa_model/scripts/model_registry.py +465 -0
- isa_model/scripts/start_mlflow.py +95 -0
- isa_model/scripts/training_tracker.py +257 -0
- isa_model/training/engine/llama_factory/__init__.py +39 -0
- isa_model/training/engine/llama_factory/config.py +115 -0
- isa_model/training/engine/llama_factory/data_adapter.py +284 -0
- isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
- isa_model/training/engine/llama_factory/factory.py +331 -0
- isa_model/training/engine/llama_factory/rl.py +254 -0
- isa_model/training/engine/llama_factory/trainer.py +171 -0
- isa_model/training/image_model/configs/create_config.py +37 -0
- isa_model/training/image_model/configs/create_flux_config.py +26 -0
- isa_model/training/image_model/configs/create_lora_config.py +21 -0
- isa_model/training/image_model/prepare_massed_compute.py +97 -0
- isa_model/training/image_model/prepare_upload.py +17 -0
- isa_model/training/image_model/raw_data/create_captions.py +16 -0
- isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
- isa_model/training/image_model/raw_data/pre_processing.py +200 -0
- isa_model/training/image_model/train/train.py +42 -0
- isa_model/training/image_model/train/train_flux.py +41 -0
- isa_model/training/image_model/train/train_lora.py +57 -0
- isa_model/training/image_model/train_main.py +25 -0
- isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
- isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
- isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
- isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
- isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
- isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
- isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
- isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
- isa_model-0.0.1.dist-info/METADATA +327 -0
- isa_model-0.0.1.dist-info/RECORD +86 -0
- isa_model-0.0.1.dist-info/WHEEL +5 -0
- isa_model-0.0.1.dist-info/licenses/LICENSE +21 -0
- isa_model-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,71 @@
|
|
1
|
+
from typing import Dict, Any
|
2
|
+
import tempfile
|
3
|
+
import os
|
4
|
+
from openai import AsyncOpenAI
|
5
|
+
from tenacity import retry, stop_after_attempt, wait_exponential
|
6
|
+
from isa_model.inference.services.base_service import BaseService
|
7
|
+
from isa_model.inference.providers.base_provider import BaseProvider
|
8
|
+
import logging
|
9
|
+
|
10
|
+
logger = logging.getLogger(__name__)


class YYDSAudioService(BaseService):
    """Audio (speech-to-text) model service wrapper for YYDS.

    Uses an ``AsyncOpenAI`` client built from the provider config keys
    ``api_key`` and ``base_url``. An optional ``language`` config value is
    forwarded to the transcription endpoint when it is a non-empty string.
    """

    def __init__(self, provider: 'BaseProvider', model_name: str):
        super().__init__(provider, model_name)
        # Initialize the AsyncOpenAI client from the provider configuration.
        self._client = AsyncOpenAI(
            api_key=self.config.get('api_key'),
            base_url=self.config.get('base_url')
        )
        self.language = self.config.get('language', None)

    @property
    def client(self) -> AsyncOpenAI:
        """Return the underlying AsyncOpenAI client."""
        return self._client

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True
    )
    async def transcribe(self, audio_data: bytes) -> Dict[str, Any]:
        """Transcribe raw audio bytes.

        The bytes are written to a temporary ``.wav`` file that is uploaded to
        the ``audio.transcriptions`` endpoint. Tenacity retries the whole call
        up to 3 attempts with exponential backoff and re-raises the last error.

        Args:
            audio_data: Binary audio data.

        Returns:
            Dict[str, Any]: ``{"text": <transcribed text>}``.
        """
        temp_path = None
        try:
            # Write the audio to a temp file and close it before reopening it
            # by name (reopening a still-open NamedTemporaryFile fails on
            # Windows). delete=False so the file survives the close.
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
                temp_path = temp_file.name
                temp_file.write(audio_data)
                temp_file.flush()

            # Open the file in binary mode for the API request.
            with open(temp_path, 'rb') as audio_file:
                params = {
                    'model': self.model_name,
                    'file': audio_file,
                }
                # Only include `language` when a string code is configured
                # (the endpoint expects an ISO-639-1 code).
                if self.language and isinstance(self.language, str):
                    params['language'] = self.language

                response = await self._client.audio.transcriptions.create(**params)

            # Return a dict containing the transcribed text.
            return {
                "text": response.text
            }

        except Exception as e:
            logger.error(f"Error in audio transcription: {e}")
            raise
        finally:
            # Remove the temp file on success AND on failure, so failed
            # (and retried) attempts do not leave files behind.
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except OSError:
                    logger.warning(f"Failed to remove temporary audio file {temp_path}")
|
@@ -0,0 +1,106 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from typing import Dict, Any, List, Union, AsyncGenerator, TypeVar, Optional
|
3
|
+
from isa_model.inference.providers.base_provider import BaseProvider
|
4
|
+
|
5
|
+
T = TypeVar('T')  # generic placeholder for service response types


class BaseService(ABC):
    """Shared root class for every AI service.

    Keeps a reference to the provider, the model name, and the configuration
    returned by ``provider.get_config()``.
    """

    def __init__(self, provider: 'BaseProvider', model_name: str):
        self.provider, self.model_name = provider, model_name
        self.config = self.provider.get_config()

    def __await__(self):
        """Let ``await service`` evaluate to the service instance itself.

        One bare yield hands control back to the event loop a single time
        before the instance is returned.
        """
        yield None
        return self
|
19
|
+
|
20
|
+
class BaseLLMService(BaseService):
    """Abstract contract for large-language-model services.

    Every method is abstract; concrete services must implement all of them.
    """

    @abstractmethod
    async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> T:
        """Send ``prompt`` (text, a message list, or another form) to the model."""
        ...

    @abstractmethod
    async def achat(self, messages: List[Dict[str, str]]) -> T:
        """Run a chat completion over ``messages``."""
        ...

    @abstractmethod
    async def acompletion(self, prompt: str) -> T:
        """Run a plain text completion for ``prompt``."""
        ...

    @abstractmethod
    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[T]:
        """Produce ``n`` completions for the same ``messages``."""
        ...

    @abstractmethod
    async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
        """Stream the chat reply for ``messages`` piece by piece."""
        ...

    @abstractmethod
    def get_token_usage(self) -> Any:
        """Return aggregate token-usage statistics."""
        ...

    @abstractmethod
    def get_last_token_usage(self) -> Dict[str, int]:
        """Return the token usage recorded for the most recent request."""
        ...
|
57
|
+
|
58
|
+
class BaseEmbeddingService(BaseService):
    """Abstract contract for text-embedding services."""

    @abstractmethod
    async def create_text_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for one piece of text."""
        ...

    @abstractmethod
    async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per input text."""
        ...

    @abstractmethod
    async def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
        """Split ``text`` into chunks and attach an embedding to each."""
        ...

    @abstractmethod
    async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """Return a similarity score for two embedding vectors."""
        ...

    @abstractmethod
    async def close(self):
        """Release any resources held by the service."""
        ...
|
85
|
+
|
86
|
+
class BaseRerankService(BaseService):
    """Abstract contract for services that reorder results by query relevance."""

    @abstractmethod
    async def rerank(self, query: str, documents: List[Dict], top_k: int = 5) -> List[Dict]:
        """Reorder ``documents`` by relevance to ``query``."""
        ...

    @abstractmethod
    async def rerank_texts(self, query: str, texts: List[str]) -> List[Dict]:
        """Reorder plain ``texts`` by relevance to ``query``."""
        ...
|
@@ -0,0 +1,97 @@
|
|
1
|
+
import logging
|
2
|
+
import httpx
|
3
|
+
import asyncio
|
4
|
+
from typing import List, Dict, Any, Optional
|
5
|
+
|
6
|
+
# 保留您指定的导入和框架结构
|
7
|
+
from isa_model.inference.services.base_service import BaseEmbeddingService
|
8
|
+
from isa_model.inference.providers.base_provider import BaseProvider
|
9
|
+
|
10
|
+
logger = logging.getLogger(__name__)


class OllamaEmbedService(BaseEmbeddingService):
    """
    Ollama embedding service.

    Follows the base service architecture, but talks to the Ollama HTTP API
    through its own ``httpx.AsyncClient`` instead of an injected backend
    object.

    Provider config keys:
        host: Ollama host (default ``"localhost"``).
        port: Ollama port (default ``11434``).
        max_concurrency: maximum number of embedding requests that
            ``create_text_embeddings`` keeps in flight at once (default 8).
    """

    def __init__(self, provider: 'BaseProvider', model_name: str = "bge-m3"):
        # Stay compatible with the base class and provider.
        super().__init__(provider, model_name)

        # Read connection settings from the config set up by the base class.
        host = self.config.get("host", "localhost")
        port = self.config.get("port", 11434)
        # Bound fan-out so large batches queue here rather than flooding the
        # server and the HTTP connection pool (which can end in timeouts).
        self.max_concurrency = max(1, int(self.config.get("max_concurrency", 8)))

        # Create and own a dedicated httpx client.
        base_url = f"http://{host}:{port}"
        self.client = httpx.AsyncClient(base_url=base_url, timeout=30.0)

        logger.info(f"Initialized OllamaEmbedService with model '{self.model_name}' at {base_url}")

    async def create_text_embedding(self, text: str) -> List[float]:
        """Create the embedding for a single text via ``POST /api/embeddings``.

        Raises:
            httpx.RequestError: on transport failures.
            httpx.HTTPStatusError: on a non-2xx response.
        """
        try:
            payload = {
                "model": self.model_name,
                "prompt": text
            }
            # Use this service's own client rather than a shared backend.
            response = await self.client.post("/api/embeddings", json=payload)
            response.raise_for_status()  # fail on non-2xx status codes
            return response.json()["embedding"]

        except httpx.RequestError as e:
            logger.error(f"An error occurred while requesting {e.request.url!r}: {e}")
            raise
        except Exception as e:
            logger.error(f"Error creating text embedding: {e}")
            raise

    async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Create embeddings for many texts concurrently, preserving input order.

        At most ``self.max_concurrency`` requests run at the same time.
        """
        if not texts:
            return []

        semaphore = asyncio.Semaphore(self.max_concurrency)

        async def _embed(text: str) -> List[float]:
            # Hold a semaphore slot for the duration of one request.
            async with semaphore:
                return await self.create_text_embedding(text)

        # gather() returns results in the same order as its inputs.
        return await asyncio.gather(*(_embed(text) for text in texts))

    async def create_chunks(
        self,
        text: str,
        metadata: Optional[Dict] = None,
        chunk_size: int = 200,
    ) -> List[Dict]:
        """Split ``text`` into word chunks and embed each chunk.

        Args:
            text: Input text; split on whitespace.
            metadata: Optional metadata; each chunk receives its own shallow copy.
            chunk_size: Number of words per chunk (default 200).

        Returns:
            A list of ``{"text", "embedding", "metadata"}`` dicts, or ``[]`` for
            text with no words.

        Raises:
            ValueError: if ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")

        words = text.split()
        chunk_texts = [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]

        if not chunk_texts:
            return []

        embeddings = await self.create_text_embeddings(chunk_texts)

        return [
            {
                "text": chunk_text,
                "embedding": emb,
                # Copy per chunk so chunks don't share one mutable dict.
                "metadata": dict(metadata) if metadata else {},
            }
            for chunk_text, emb in zip(chunk_texts, embeddings)
        ]

    async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """Return the cosine similarity of two embedding vectors.

        Returns 0.0 when either vector has zero norm.

        Raises:
            ValueError: if the vectors have different lengths (``zip`` would
                otherwise silently truncate the longer one).
        """
        if len(embedding1) != len(embedding2):
            raise ValueError(
                f"Embedding dimensions differ: {len(embedding1)} != {len(embedding2)}"
            )

        dot_product = sum(a * b for a, b in zip(embedding1, embedding2))
        norm1 = sum(a * a for a in embedding1) ** 0.5
        norm2 = sum(b * b for b in embedding2) ** 0.5

        if norm1 * norm2 == 0:
            return 0.0

        return dot_product / (norm1 * norm2)

    async def close(self):
        """Close the service's own HTTP client."""
        await self.client.aclose()
        logger.info("OllamaEmbedService's internal client has been closed.")
|
File without changes
|
@@ -0,0 +1,12 @@
|
|
1
|
+
"""
|
2
|
+
LLM Services - Business logic services for Language Models
|
3
|
+
"""
|
4
|
+
|
5
|
+
# Import LLM services here when created
|
6
|
+
from .ollama_llm_service import OllamaLLMService
|
7
|
+
from .openai_llm_service import OpenAILLMService
|
8
|
+
|
9
|
+
# Names re-exported as the public API of the ``llm`` services package.
__all__ = ["OllamaLLMService", "OpenAILLMService"]
|
@@ -0,0 +1,134 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from typing import Dict, Any, List, Union, Optional, AsyncGenerator, TypeVar
|
3
|
+
from isa_model.inference.services.base_service import BaseService
|
4
|
+
|
5
|
+
T = TypeVar('T')  # generic placeholder for model response types


class BaseLLMService(BaseService):
    """Abstract contract for Large Language Model services.

    All methods are abstract; concrete services must implement each one.
    """

    @abstractmethod
    async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> T:
        """
        Single entry point that accepts several prompt shapes.

        Args:
            prompt: A plain string, a list of chat messages, or another format.

        Returns:
            The model's response, shaped appropriately for the input.
        """
        ...

    @abstractmethod
    async def achat(self, messages: List[Dict[str, str]]) -> T:
        """
        Chat completion over role/content messages.

        Args:
            messages: Message dicts with ``role`` and ``content`` keys,
                e.g. ``[{"role": "user", "content": "Hello"}]``.

        Returns:
            The chat completion response.
        """
        ...

    @abstractmethod
    async def acompletion(self, prompt: str) -> T:
        """
        Plain text completion for a single prompt.

        Args:
            prompt: The input text.

        Returns:
            The completion response.
        """
        ...

    @abstractmethod
    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[T]:
        """
        Produce several completions for one input.

        Args:
            messages: Message dicts.
            n: How many completions to produce.

        Returns:
            One response per requested completion.
        """
        ...

    @abstractmethod
    async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
        """
        Stream a chat reply incrementally.

        Args:
            messages: Message dicts.

        Yields:
            Successive tokens or chunks of the reply.
        """
        ...

    @abstractmethod
    async def astream_completion(self, prompt: str) -> AsyncGenerator[str, None]:
        """
        Stream a text completion incrementally.

        Args:
            prompt: The input text.

        Yields:
            Successive tokens or chunks of the reply.
        """
        ...

    @abstractmethod
    def get_token_usage(self) -> Dict[str, Any]:
        """
        Cumulative token statistics for this service instance.

        Returns:
            A dict with ``total_tokens``, ``prompt_tokens``,
            ``completion_tokens`` and ``requests_count``.
        """
        ...

    @abstractmethod
    def get_last_token_usage(self) -> Dict[str, int]:
        """
        Token statistics for the most recent request only.

        Returns:
            A dict with ``prompt_tokens``, ``completion_tokens`` and
            ``total_tokens`` for that request.
        """
        ...

    @abstractmethod
    def get_model_info(self) -> Dict[str, Any]:
        """
        Describe the configured model.

        Returns:
            A dict with ``name``, ``max_tokens`` (context length),
            ``supports_streaming`` and ``supports_functions``.
        """
        ...

    @abstractmethod
    async def close(self):
        """Release resources and close any open connections."""
        ...
|
@@ -0,0 +1,99 @@
|
|
1
|
+
import logging
from typing import Dict, Any, List, Union, AsyncGenerator, Optional

import httpx

from isa_model.inference.services.base_service import BaseLLMService
from isa_model.inference.providers.base_provider import BaseProvider
|
5
|
+
|
6
|
+
logger = logging.getLogger(__name__)


class OllamaLLMService(BaseLLMService):
    """Ollama LLM service that talks to the Ollama HTTP API directly.

    ``BaseService`` provides no backend client, so this service owns an
    ``httpx.AsyncClient`` pointed at ``http://{host}:{port}``.

    Provider config keys:
        host: Ollama host (default ``"localhost"``).
        port: Ollama port (default ``11434``).
        timeout: HTTP timeout in seconds (default ``120.0``; generation can
            be slow).
    """

    def __init__(self, provider: 'BaseProvider', model_name: str = "llama3.1"):
        super().__init__(provider, model_name)

        host = self.config.get("host", "localhost")
        port = self.config.get("port", 11434)
        timeout = self.config.get("timeout", 120.0)
        self.client = httpx.AsyncClient(base_url=f"http://{host}:{port}", timeout=timeout)

        self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        # Running totals across every successful request on this instance.
        self.total_token_usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "requests_count": 0,
        }
        logger.info(f"Initialized OllamaLLMService with model {model_name}")

    async def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` as JSON to ``path`` and return the decoded JSON body.

        Raises:
            httpx.RequestError: on transport failures.
            httpx.HTTPStatusError: on a non-2xx response.
        """
        response = await self.client.post(path, json=payload)
        response.raise_for_status()
        return response.json()

    def _record_usage(self, response: Dict[str, Any]) -> None:
        """Count one request and fold Ollama's eval counters into the usage stats.

        Token counts are only updated when the response contains ``eval_count``.
        """
        self.total_token_usage["requests_count"] += 1
        if "eval_count" not in response:
            return
        prompt_tokens = response.get("prompt_eval_count", 0)
        completion_tokens = response.get("eval_count", 0)
        self.last_token_usage = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        for key, value in self.last_token_usage.items():
            self.total_token_usage[key] += value

    async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]):
        """Route a string to ``acompletion`` and a message list to ``achat``.

        Raises:
            ValueError: for any other prompt type.
        """
        if isinstance(prompt, str):
            return await self.acompletion(prompt)
        elif isinstance(prompt, list):
            return await self.achat(prompt)
        else:
            raise ValueError("Prompt must be string or list of messages")

    async def achat(self, messages: List[Dict[str, str]]):
        """Non-streaming chat via ``POST /api/chat``; returns the reply text."""
        try:
            payload = {
                "model": self.model_name,
                "messages": messages,
                "stream": False
            }
            response = await self._post("/api/chat", payload)
            self._record_usage(response)
            return response["message"]["content"]

        except Exception as e:
            logger.error(f"Error in chat completion: {e}")
            raise

    async def acompletion(self, prompt: str):
        """Non-streaming completion via ``POST /api/generate``; returns the text."""
        try:
            payload = {
                "model": self.model_name,
                "prompt": prompt,
                "stream": False
            }
            response = await self._post("/api/generate", payload)
            self._record_usage(response)
            return response["response"]

        except Exception as e:
            logger.error(f"Error in text completion: {e}")
            raise

    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
        """Generate ``n`` completions by issuing ``n`` sequential chat requests."""
        results = []
        for _ in range(n):
            result = await self.achat(messages)
            results.append(result)
        return results

    async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
        """Pseudo-stream: yields the complete chat reply as a single chunk."""
        # True token streaming would need `"stream": True` and incremental
        # parsing of the NDJSON response; for now yield the full reply.
        response = await self.achat(messages)
        yield response

    def get_token_usage(self):
        """Return cumulative token usage (plus ``requests_count``) as a copy."""
        return dict(self.total_token_usage)

    def get_last_token_usage(self) -> Dict[str, int]:
        """Return token usage from the last request that reported it."""
        return self.last_token_usage

    async def close(self):
        """Close the service's own HTTP client."""
        await self.client.aclose()
|
@@ -0,0 +1,138 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
from typing import Dict, Any, List, Union, AsyncGenerator, Optional
|
4
|
+
|
5
|
+
# 使用官方 OpenAI 库和 dotenv
|
6
|
+
from openai import AsyncOpenAI
|
7
|
+
from dotenv import load_dotenv
|
8
|
+
|
9
|
+
from isa_model.inference.services.base_service import BaseLLMService
|
10
|
+
from isa_model.inference.providers.base_provider import BaseProvider
|
11
|
+
|
12
|
+
# Load environment variables from the .env.local file. The path is relative,
# so it is resolved against the current working directory, once, at import time.
load_dotenv(dotenv_path='.env.local')

logger = logging.getLogger(__name__)


class OpenAILLMService(BaseLLMService):
    """OpenAI LLM service implementation backed by ``AsyncOpenAI``.

    Config keys (from the provider config stored by ``BaseService``):
    ``api_key`` (falls back to ``$OPENAI_API_KEY``), ``api_base`` (falls back
    to ``$OPENAI_API_BASE``), ``temperature`` (default 0.7) and
    ``max_tokens`` (default 1024).
    """

    def __init__(self, provider: 'BaseProvider', model_name: str = "gpt-3.5-turbo"):
        super().__init__(provider, model_name)

        # Read from self.config (set by BaseService from provider.get_config())
        # instead of assuming the provider exposes a `.config` attribute.
        api_key = self.config.get("api_key") or os.getenv("OPENAI_API_KEY")
        base_url = self.config.get("api_base") or os.getenv("OPENAI_API_BASE")

        # Validate explicitly: AsyncOpenAI signals a missing key with its own
        # OpenAIError, not TypeError, so catching TypeError would never fire.
        if not api_key:
            logger.error("初始化 OpenAI 客户端失败。请检查您的 .env.local 文件中是否正确设置了 OPENAI_API_KEY。")
            raise ValueError("OPENAI_API_KEY 未设置。")

        self.client = AsyncOpenAI(
            api_key=api_key,
            base_url=base_url
        )

        self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        # Running totals for this instance; exposed by get_token_usage().
        self.total_token_usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "requests_count": 0,
        }
        logger.info(f"Initialized OpenAILLMService with model {self.model_name} and endpoint {self.client.base_url}")

    def _generation_params(self) -> Dict[str, Any]:
        """Sampling parameters shared by every completion request."""
        return {
            "temperature": self.config.get("temperature", 0.7),
            "max_tokens": self.config.get("max_tokens", 1024),
        }

    def _track_request(self, usage: Optional[Any]) -> None:
        """Count one request and, if ``usage`` is given, update token stats.

        Args:
            usage: The ``usage`` object of an OpenAI response, or None when the
                response carried no usage (e.g. a stream).
        """
        self.total_token_usage["requests_count"] += 1
        if not usage:
            return
        self.last_token_usage = {
            "prompt_tokens": usage.prompt_tokens or 0,
            "completion_tokens": usage.completion_tokens or 0,
            "total_tokens": usage.total_tokens or 0,
        }
        for key, value in self.last_token_usage.items():
            self.total_token_usage[key] += value

    async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> str:
        """Route a string to ``acompletion`` and a message list to ``achat``.

        Raises:
            ValueError: for any other prompt type.
        """
        if isinstance(prompt, str):
            return await self.acompletion(prompt)
        elif isinstance(prompt, list):
            return await self.achat(prompt)
        else:
            raise ValueError("Prompt must be a string or a list of messages")

    async def achat(self, messages: List[Dict[str, str]]) -> str:
        """Run a chat completion; return the first choice's text ("" if None)."""
        try:
            response = await self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                **self._generation_params()
            )
            self._track_request(response.usage)
            return response.choices[0].message.content or ""

        except Exception as e:
            logger.error(f"Error in chat completion: {e}")
            raise

    async def acompletion(self, prompt: str) -> str:
        """Text completion implemented as a single-user-message chat request."""
        messages = [{"role": "user", "content": prompt}]
        return await self.achat(messages)

    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
        """Request ``n`` choices in one call; return each choice's text."""
        try:
            response = await self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                n=n,
                **self._generation_params()
            )
            self._track_request(response.usage)
            return [choice.message.content or "" for choice in response.choices]
        except Exception as e:
            logger.error(f"Error in generate: {e}")
            raise

    async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
        """Stream chat content deltas as they arrive, skipping empty ones."""
        try:
            stream = await self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                stream=True,
                **self._generation_params()
            )
            # Streamed responses carry no usage block here; count the request only.
            self._track_request(None)

            async for chunk in stream:
                # Some chunks (e.g. usage-only or filtered ones) have no choices.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content:
                    yield content

        except Exception as e:
            logger.error(f"Error in stream chat: {e}")
            raise

    def get_token_usage(self) -> Dict[str, int]:
        """Return cumulative token usage (plus ``requests_count``) as a copy."""
        return dict(self.total_token_usage)

    def get_last_token_usage(self) -> Dict[str, int]:
        """Return token usage from the last request that reported it."""
        return self.last_token_usage

    async def close(self):
        """Close the underlying AsyncOpenAI client (the SDK method is ``close``)."""
        await self.client.close()
|