isa-model 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/model_registry.py +273 -46
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
- isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
- isa_model/eval/__init__.py +56 -0
- isa_model/eval/benchmarks.py +469 -0
- isa_model/eval/factory.py +582 -0
- isa_model/eval/metrics.py +628 -0
- isa_model/inference/ai_factory.py +98 -93
- isa_model/inference/providers/openai_provider.py +21 -7
- isa_model/inference/providers/replicate_provider.py +18 -5
- isa_model/inference/providers/triton_provider.py +1 -1
- isa_model/inference/services/audio/base_stt_service.py +91 -0
- isa_model/inference/services/audio/base_tts_service.py +136 -0
- isa_model/inference/services/audio/{yyds_audio_service.py → openai_tts_service.py} +4 -4
- isa_model/inference/services/embedding/ollama_embed_service.py +48 -36
- isa_model/inference/services/llm/__init__.py +0 -4
- isa_model/inference/services/llm/base_llm_service.py +134 -0
- isa_model/inference/services/llm/ollama_llm_service.py +1 -10
- isa_model/inference/services/llm/openai_llm_service.py +70 -61
- isa_model/inference/services/vision/__init__.py +1 -1
- isa_model/inference/services/vision/ollama_vision_service.py +4 -4
- isa_model/inference/services/vision/{yyds_vision_service.py → openai_vision_service.py} +5 -5
- isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
- isa_model/training/__init__.py +44 -0
- isa_model/training/factory.py +393 -0
- isa_model-0.1.1.dist-info/METADATA +327 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/RECORD +35 -60
- isa_model/deployment/mlflow_gateway/__init__.py +0 -8
- isa_model/deployment/mlflow_gateway/start_gateway.py +0 -65
- isa_model/deployment/unified_multimodal_client.py +0 -341
- isa_model/inference/adapter/triton_adapter.py +0 -453
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +0 -188
- isa_model/inference/backends/Pytorch/gemma_backend.py +0 -167
- isa_model/inference/backends/Pytorch/llama_backend.py +0 -166
- isa_model/inference/backends/Pytorch/whisper_backend.py +0 -194
- isa_model/inference/backends/__init__.py +0 -53
- isa_model/inference/backends/base_backend_client.py +0 -26
- isa_model/inference/backends/container_services.py +0 -104
- isa_model/inference/backends/local_services.py +0 -72
- isa_model/inference/backends/openai_client.py +0 -130
- isa_model/inference/backends/replicate_client.py +0 -197
- isa_model/inference/backends/third_party_services.py +0 -239
- isa_model/inference/backends/triton_client.py +0 -97
- isa_model/inference/client_sdk/client.py +0 -134
- isa_model/inference/client_sdk/client_data_std.py +0 -34
- isa_model/inference/client_sdk/client_sdk_schema.py +0 -16
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +0 -174
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +0 -250
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +0 -76
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +0 -195
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +0 -83
- isa_model/inference/services/audio/fish_speech/handler.py +0 -215
- isa_model/inference/services/audio/runpod_tts_fish_service.py +0 -212
- isa_model/inference/services/audio/triton_speech_service.py +0 -138
- isa_model/inference/services/audio/whisper_service.py +0 -186
- isa_model/inference/services/base_tts_service.py +0 -66
- isa_model/inference/services/embedding/bge_service.py +0 -183
- isa_model/inference/services/embedding/ollama_rerank_service.py +0 -118
- isa_model/inference/services/embedding/onnx_rerank_service.py +0 -73
- isa_model/inference/services/llm/gemma_service.py +0 -143
- isa_model/inference/services/llm/llama_service.py +0 -143
- isa_model/inference/services/llm/replicate_llm_service.py +0 -179
- isa_model/inference/services/llm/triton_llm_service.py +0 -230
- isa_model/inference/services/vision/replicate_vision_service.py +0 -241
- isa_model/inference/services/vision/triton_vision_service.py +0 -199
- isa_model-0.1.0.dist-info/METADATA +0 -116
- /isa_model/inference/{client_sdk/__init__.py → services/embedding/openai_embed_service.py} +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/WHEEL +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/top_level.txt +0 -0
isa_model/inference/services/embedding/ollama_embed_service.py

@@ -1,75 +1,87 @@
  import logging
+ import httpx
+ import asyncio
  from typing import List, Dict, Any, Optional
+
+ # Keep the imports and framework structure you specified
  from isa_model.inference.services.base_service import BaseEmbeddingService
  from isa_model.inference.providers.base_provider import BaseProvider
- from isa_model.inference.backends.local_services import OllamaBackendClient

  logger = logging.getLogger(__name__)

  class OllamaEmbedService(BaseEmbeddingService):
-     """
+     """
+     Ollama embedding service.
+     This class follows the base service architecture, but uses its own HTTP client to talk to the Ollama API
+     instead of relying on an injected backend object.
+     """

-     def __init__(self, provider: 'BaseProvider', model_name: str = "bge-m3"
+     def __init__(self, provider: 'BaseProvider', model_name: str = "bge-m3"):
+         # Stay compatible with the base class and the provider
          super().__init__(provider, model_name)

-         #
-
-
-
-
-
-
+         # Read configuration from self.config inherited from the base class
+         host = self.config.get("host", "localhost")
+         port = self.config.get("port", 11434)
+
+         # Create and hold our own httpx client instance
+         base_url = f"http://{host}:{port}"
+         self.client = httpx.AsyncClient(base_url=base_url, timeout=30.0)

-         logger.info(f"Initialized OllamaEmbedService with model {model_name}")
+         logger.info(f"Initialized OllamaEmbedService with model '{self.model_name}' at {base_url}")

      async def create_text_embedding(self, text: str) -> List[float]:
-         """
+         """Create an embedding for a single text"""
          try:
              payload = {
                  "model": self.model_name,
                  "prompt": text
              }
-
-
+             # Use our own client instance, not self.backend
+             response = await self.client.post("/api/embeddings", json=payload)
+             response.raise_for_status()  # Check that the request succeeded
+             return response.json()["embedding"]

+         except httpx.RequestError as e:
+             logger.error(f"An error occurred while requesting {e.request.url!r}: {e}")
+             raise
          except Exception as e:
              logger.error(f"Error creating text embedding: {e}")
              raise

      async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
-         """
-
-
-
-
+         """Create embeddings for multiple texts concurrently"""
+         if not texts:
+             return []
+
+         tasks = [self.create_text_embedding(text) for text in texts]
+         embeddings = await asyncio.gather(*tasks)
          return embeddings

      async def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
-         """
-         # Simple implementation: split the text into fixed-size chunks
+         """Split the text into chunks and create an embedding for each chunk"""
          chunk_size = 200  # number of words
-         chunks = []
-
-         # Split by words
          words = text.split()
+         chunk_texts = [" ".join(words[i:i+chunk_size]) for i in range(0, len(words), chunk_size)]

-
-
-
-
-
-
+         if not chunk_texts:
+             return []
+
+         embeddings = await self.create_text_embeddings(chunk_texts)
+
+         chunks = [
+             {
                  "text": chunk_text,
-                 "embedding":
+                 "embedding": emb,
                  "metadata": metadata or {}
              }
-
+             for chunk_text, emb in zip(chunk_texts, embeddings)
+         ]

          return chunks

      async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
          """Compute the cosine similarity between two embedding vectors"""
-         # Simple cosine similarity implementation
          dot_product = sum(a * b for a, b in zip(embedding1, embedding2))
          norm1 = sum(a * a for a in embedding1) ** 0.5
          norm2 = sum(b * b for b in embedding2) ** 0.5

@@ -80,6 +92,6 @@ class OllamaEmbedService(BaseEmbeddingService):
          return dot_product / (norm1 * norm2)

      async def close(self):
-         """
-         await self.
-
+         """Close the built-in HTTP client"""
+         await self.client.aclose()
+         logger.info("OllamaEmbedService's internal client has been closed.")

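The new embedding service above manages its own `httpx.AsyncClient` instead of an injected backend. The following is an illustrative sketch only (not part of the package diff) of how it might be driven; it assumes a running Ollama server on localhost:11434 with the `bge-m3` model pulled, and the `OllamaProvider` import path and constructor arguments shown here are assumptions, not taken from the diff.

```python
# Hypothetical usage sketch for the rewritten OllamaEmbedService.
import asyncio
from isa_model.inference.providers.ollama_provider import OllamaProvider  # assumed module path
from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService

async def main():
    provider = OllamaProvider(config={"host": "localhost", "port": 11434})  # assumed constructor
    service = OllamaEmbedService(provider, model_name="bge-m3")
    try:
        # Single embedding, then a concurrent batch (asyncio.gather under the hood)
        vec = await service.create_text_embedding("hello world")
        batch = await service.create_text_embeddings(["first text", "second text"])
        # Cosine similarity between the two batch embeddings
        score = await service.compute_similarity(batch[0], batch[1])
        print(len(vec), score)
    finally:
        await service.close()  # closes the internal httpx.AsyncClient

asyncio.run(main())
```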
isa_model/inference/services/llm/__init__.py

@@ -4,13 +4,9 @@ LLM Services - Business logic services for Language Models

  # Import LLM services here when created
  from .ollama_llm_service import OllamaLLMService
- from .triton_llm_service import TritonLLMService
  from .openai_llm_service import OpenAILLMService
- from .replicate_llm_service import ReplicateLLMService

  __all__ = [
      "OllamaLLMService",
-     "TritonLLMService",
      "OpenAILLMService",
-     "ReplicateLLMService",
  ]

isa_model/inference/services/llm/base_llm_service.py (new file)

@@ -0,0 +1,134 @@
+ from abc import ABC, abstractmethod
+ from typing import Dict, Any, List, Union, Optional, AsyncGenerator, TypeVar
+ from isa_model.inference.services.base_service import BaseService
+
+ T = TypeVar('T')  # Generic type for responses
+
+ class BaseLLMService(BaseService):
+     """Base class for Large Language Model services"""
+
+     @abstractmethod
+     async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> T:
+         """
+         Universal invocation method that handles different input types
+
+         Args:
+             prompt: Can be a string, list of messages, or other format
+
+         Returns:
+             Model response in the appropriate format
+         """
+         pass
+
+     @abstractmethod
+     async def achat(self, messages: List[Dict[str, str]]) -> T:
+         """
+         Chat completion method using message format
+
+         Args:
+             messages: List of message dictionaries with 'role' and 'content' keys
+                 Example: [{"role": "user", "content": "Hello"}]
+
+         Returns:
+             Chat completion response
+         """
+         pass
+
+     @abstractmethod
+     async def acompletion(self, prompt: str) -> T:
+         """
+         Text completion method for simple prompt completion
+
+         Args:
+             prompt: Input text prompt
+
+         Returns:
+             Text completion response
+         """
+         pass
+
+     @abstractmethod
+     async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[T]:
+         """
+         Generate multiple completions for the same input
+
+         Args:
+             messages: List of message dictionaries
+             n: Number of completions to generate
+
+         Returns:
+             List of completion responses
+         """
+         pass
+
+     @abstractmethod
+     async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
+         """
+         Stream chat responses token by token
+
+         Args:
+             messages: List of message dictionaries
+
+         Yields:
+             Individual tokens or chunks of the response
+         """
+         pass
+
+     @abstractmethod
+     async def astream_completion(self, prompt: str) -> AsyncGenerator[str, None]:
+         """
+         Stream completion responses token by token
+
+         Args:
+             prompt: Input text prompt
+
+         Yields:
+             Individual tokens or chunks of the response
+         """
+         pass
+
+     @abstractmethod
+     def get_token_usage(self) -> Dict[str, Any]:
+         """
+         Get cumulative token usage statistics for this service instance
+
+         Returns:
+             Dict containing token usage information:
+             - total_tokens: Total tokens used
+             - prompt_tokens: Tokens used for prompts
+             - completion_tokens: Tokens used for completions
+             - requests_count: Number of requests made
+         """
+         pass
+
+     @abstractmethod
+     def get_last_token_usage(self) -> Dict[str, int]:
+         """
+         Get token usage from the last request
+
+         Returns:
+             Dict containing last request token usage:
+             - prompt_tokens: Tokens in last prompt
+             - completion_tokens: Tokens in last completion
+             - total_tokens: Total tokens in last request
+         """
+         pass
+
+     @abstractmethod
+     def get_model_info(self) -> Dict[str, Any]:
+         """
+         Get information about the current model
+
+         Returns:
+             Dict containing model information:
+             - name: Model name
+             - max_tokens: Maximum context length
+             - supports_streaming: Whether streaming is supported
+             - supports_functions: Whether function calling is supported
+         """
+         pass
+
+     @abstractmethod
+     async def close(self):
+         """Cleanup resources and close connections"""
+         pass

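The new `BaseLLMService` is an abstract contract: every concrete backend must implement invocation, chat, completion, multi-sample generation, both streaming variants, token accounting, model info, and cleanup. The sketch below is illustrative only, a minimal in-memory stub that satisfies the contract (useful, for example, in tests); it assumes `BaseService.__init__` takes `(provider, model_name)` as the other services in this diff do, and `EchoLLMService` is not part of the package.

```python
# Hypothetical stub implementing every abstract method of BaseLLMService.
from typing import Dict, Any, List, Union, AsyncGenerator
from isa_model.inference.services.llm.base_llm_service import BaseLLMService

class EchoLLMService(BaseLLMService):
    """Echoes prompts back; concrete enough to be instantiated."""

    def __init__(self, provider, model_name: str = "echo"):
        super().__init__(provider, model_name)  # assumed base signature
        self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}

    async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> str:
        # Dispatch: list of messages -> achat, plain string -> acompletion
        return await (self.achat(prompt) if isinstance(prompt, list) else self.acompletion(prompt))

    async def achat(self, messages: List[Dict[str, str]]) -> str:
        return messages[-1]["content"]

    async def acompletion(self, prompt: str) -> str:
        return prompt

    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
        return [await self.achat(messages) for _ in range(n)]

    async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
        for token in (await self.achat(messages)).split():
            yield token

    async def astream_completion(self, prompt: str) -> AsyncGenerator[str, None]:
        for token in prompt.split():
            yield token

    def get_token_usage(self) -> Dict[str, Any]:
        return self.last_token_usage

    def get_last_token_usage(self) -> Dict[str, int]:
        return self.last_token_usage

    def get_model_info(self) -> Dict[str, Any]:
        return {"name": self.model_name, "max_tokens": 0,
                "supports_streaming": True, "supports_functions": False}

    async def close(self):
        pass  # nothing to release in this stub
```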
isa_model/inference/services/llm/ollama_llm_service.py

@@ -2,23 +2,14 @@ import logging
  from typing import Dict, Any, List, Union, AsyncGenerator, Optional
  from isa_model.inference.services.base_service import BaseLLMService
  from isa_model.inference.providers.base_provider import BaseProvider
- from isa_model.inference.backends.local_services import OllamaBackendClient

  logger = logging.getLogger(__name__)

  class OllamaLLMService(BaseLLMService):
      """Ollama LLM service using backend client"""

-     def __init__(self, provider: 'BaseProvider', model_name: str = "llama3.1"
+     def __init__(self, provider: 'BaseProvider', model_name: str = "llama3.1"):
          super().__init__(provider, model_name)
-
-         # Use provided backend or create new one
-         if backend:
-             self.backend = backend
-         else:
-             host = self.config.get("host", "localhost")
-             port = self.config.get("port", 11434)
-             self.backend = OllamaBackendClient(host, port)

          self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
          logger.info(f"Initialized OllamaLLMService with model {model_name}")

isa_model/inference/services/llm/openai_llm_service.py

@@ -1,72 +1,80 @@
  import logging
+ import os
  from typing import Dict, Any, List, Union, AsyncGenerator, Optional
+
+ # Use the official OpenAI library and dotenv
+ from openai import AsyncOpenAI
+ from dotenv import load_dotenv
+
  from isa_model.inference.services.base_service import BaseLLMService
  from isa_model.inference.providers.base_provider import BaseProvider
-
+
+ # Load environment variables from the .env.local file
+ load_dotenv(dotenv_path='.env.local')

  logger = logging.getLogger(__name__)

  class OpenAILLMService(BaseLLMService):
      """OpenAI LLM service implementation"""

-     def __init__(self, provider: 'BaseProvider', model_name: str = "gpt-3.5-turbo"
+     def __init__(self, provider: 'BaseProvider', model_name: str = "gpt-3.5-turbo"):
          super().__init__(provider, model_name)

-         #
-
-
-
-
-
-
+         # Initialize the AsyncOpenAI client from the provider configuration
+         try:
+             api_key = provider.config.get("api_key") or os.getenv("OPENAI_API_KEY")
+             base_url = provider.config.get("api_base") or os.getenv("OPENAI_API_BASE")
+
+             self.client = AsyncOpenAI(
+                 api_key=api_key,
+                 base_url=base_url
+             )
+         except TypeError as e:
+             logger.error("Failed to initialize the OpenAI client. Please check that OPENAI_API_KEY is set correctly in your .env.local file.")
+             raise ValueError("OPENAI_API_KEY is not set.") from e

          self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
-         logger.info(f"Initialized OpenAILLMService with model {model_name}")
+         logger.info(f"Initialized OpenAILLMService with model {self.model_name} and endpoint {self.client.base_url}")

-     async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]):
+     async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> str:
          """Universal invocation method"""
          if isinstance(prompt, str):
              return await self.acompletion(prompt)
          elif isinstance(prompt, list):
              return await self.achat(prompt)
          else:
-             raise ValueError("Prompt must be string or list of messages")
+             raise ValueError("Prompt must be a string or a list of messages")

-     async def achat(self, messages: List[Dict[str, str]]):
+     async def achat(self, messages: List[Dict[str, str]]) -> str:
          """Chat completion method"""
          try:
              temperature = self.config.get("temperature", 0.7)
              max_tokens = self.config.get("max_tokens", 1024)

-
-
-
-
-
-
-             response = await self.backend.post("/chat/completions", payload)
+             response = await self.client.chat.completions.create(
+                 model=self.model_name,
+                 messages=messages,
+                 temperature=temperature,
+                 max_tokens=max_tokens
+             )

-
-
-
-
-
-
+             if response.usage:
+                 self.last_token_usage = {
+                     "prompt_tokens": response.usage.prompt_tokens,
+                     "completion_tokens": response.usage.completion_tokens,
+                     "total_tokens": response.usage.total_tokens
+                 }

-             return response
+             return response.choices[0].message.content or ""

          except Exception as e:
              logger.error(f"Error in chat completion: {e}")
              raise

-     async def acompletion(self, prompt: str):
-         """Text completion method (using chat API
-
-
-             return await self.achat(messages)
-         except Exception as e:
-             logger.error(f"Error in text completion: {e}")
-             raise
+     async def acompletion(self, prompt: str) -> str:
+         """Text completion method (using chat API)"""
+         messages = [{"role": "user", "content": prompt}]
+         return await self.achat(messages)

      async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
          """Generate multiple completions"""

@@ -74,23 +82,22 @@ class OpenAILLMService(BaseLLMService):
              temperature = self.config.get("temperature", 0.7)
              max_tokens = self.config.get("max_tokens", 1024)

-
-
-
-
-
-
-
-             response = await self.backend.post("/chat/completions", payload)
+             response = await self.client.chat.completions.create(
+                 model=self.model_name,
+                 messages=messages,
+                 temperature=temperature,
+                 max_tokens=max_tokens,
+                 n=n
+             )

-
-
-
-
-
-
+             if response.usage:
+                 self.last_token_usage = {
+                     "prompt_tokens": response.usage.prompt_tokens,
+                     "completion_tokens": response.usage.completion_tokens,
+                     "total_tokens": response.usage.total_tokens
+                 }

-             return [choice
+             return [choice.message.content or "" for choice in response.choices]
          except Exception as e:
              logger.error(f"Error in generate: {e}")
              raise

@@ -101,22 +108,24 @@
              temperature = self.config.get("temperature", 0.7)
              max_tokens = self.config.get("max_tokens", 1024)

-
-
-
-
-
-
-
+             stream = await self.client.chat.completions.create(
+                 model=self.model_name,
+                 messages=messages,
+                 temperature=temperature,
+                 max_tokens=max_tokens,
+                 stream=True
+             )

-             async for chunk in
-
+             async for chunk in stream:
+                 content = chunk.choices[0].delta.content
+                 if content:
+                     yield content

          except Exception as e:
              logger.error(f"Error in stream chat: {e}")
              raise

-     def get_token_usage(self):
+     def get_token_usage(self) -> Dict[str, int]:
          """Get total token usage statistics"""
          return self.last_token_usage

@@ -126,4 +135,4 @@ class OpenAILLMService(BaseLLMService):

      async def close(self):
          """Close the backend client"""
-         await self.
+         await self.client.aclose()

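The rewritten OpenAI service now calls `client.chat.completions.create` directly instead of posting through a backend object. Below is an illustrative sketch only of how it might be exercised; it assumes `OPENAI_API_KEY` is available in `.env.local` or the environment, and the `OpenAIProvider` constructor arguments shown here are assumptions, not taken from the diff.

```python
# Hypothetical usage sketch for the rewritten OpenAILLMService.
import asyncio
from isa_model.inference.providers.openai_provider import OpenAIProvider
from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService

async def main():
    provider = OpenAIProvider(config={"temperature": 0.2, "max_tokens": 256})  # assumed constructor
    service = OpenAILLMService(provider, model_name="gpt-3.5-turbo")
    try:
        # ainvoke dispatches: str -> acompletion, list of messages -> achat
        print(await service.ainvoke("Say hello in one word."))

        # Token-by-token streaming via the new stream=True path
        async for chunk in service.astream_chat([{"role": "user", "content": "Count to three."}]):
            print(chunk, end="", flush=True)

        print("\n", service.get_token_usage())  # usage captured from the last non-streaming call
    finally:
        await service.close()  # delegates to self.client.aclose() as written in the diff

asyncio.run(main())
```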
isa_model/inference/services/vision/__init__.py

@@ -7,6 +7,6 @@ Vision service package
  """

  # Export ReplicateVisionService
- from isa_model.inference.services.vision.
+ from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateVisionService

  __all__ = ["ReplicateVisionService"]

isa_model/inference/services/vision/ollama_vision_service.py

@@ -4,11 +4,11 @@ import base64
  import ollama
  from typing import Dict, Any, Union
  from tenacity import retry, stop_after_attempt, wait_exponential
- from
- from
-
+ from isa_model.inference.services.base_service import BaseService
+ from isa_model.inference.providers.base_provider import BaseProvider
+ import logging

- logger =
+ logger = logging.getLogger(__name__)

  class OllamaVisionService(BaseService):
      """Vision model service wrapper for Ollama using base64 encoded images"""

isa_model/inference/services/vision/{yyds_vision_service.py → openai_vision_service.py}

@@ -1,14 +1,14 @@
  from typing import Dict, Any, Union
  from openai import AsyncOpenAI
  from tenacity import retry, stop_after_attempt, wait_exponential
- from
- from
+ from isa_model.inference.services.base_service import BaseService
+ from isa_model.inference.providers.base_provider import BaseProvider
  from .helpers.image_utils import compress_image, encode_image_to_base64
-
+ import logging

- logger =
+ logger = logging.getLogger(__name__)

- class
+ class OpenAIVisionService(BaseService):
      """Vision model service wrapper for YYDS"""

      def __init__(self, provider: 'BaseProvider', model_name: str):