isa-model 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. isa_model/__init__.py +5 -0
  2. isa_model/core/model_manager.py +143 -0
  3. isa_model/core/model_registry.py +115 -0
  4. isa_model/core/model_router.py +226 -0
  5. isa_model/core/model_storage.py +133 -0
  6. isa_model/core/model_version.py +0 -0
  7. isa_model/core/resource_manager.py +202 -0
  8. isa_model/core/storage/hf_storage.py +0 -0
  9. isa_model/core/storage/local_storage.py +0 -0
  10. isa_model/core/storage/minio_storage.py +0 -0
  11. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
  12. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
  13. isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
  14. isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
  15. isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
  16. isa_model/inference/__init__.py +11 -0
  17. isa_model/inference/adapter/unified_api.py +248 -0
  18. isa_model/inference/ai_factory.py +359 -0
  19. isa_model/inference/base.py +46 -0
  20. isa_model/inference/providers/__init__.py +19 -0
  21. isa_model/inference/providers/base_provider.py +30 -0
  22. isa_model/inference/providers/model_cache_manager.py +341 -0
  23. isa_model/inference/providers/ollama_provider.py +73 -0
  24. isa_model/inference/providers/openai_provider.py +101 -0
  25. isa_model/inference/providers/replicate_provider.py +107 -0
  26. isa_model/inference/providers/triton_provider.py +439 -0
  27. isa_model/inference/services/__init__.py +14 -0
  28. isa_model/inference/services/audio/base_stt_service.py +91 -0
  29. isa_model/inference/services/audio/base_tts_service.py +136 -0
  30. isa_model/inference/services/audio/openai_tts_service.py +71 -0
  31. isa_model/inference/services/base_service.py +106 -0
  32. isa_model/inference/services/embedding/ollama_embed_service.py +97 -0
  33. isa_model/inference/services/embedding/openai_embed_service.py +0 -0
  34. isa_model/inference/services/llm/__init__.py +12 -0
  35. isa_model/inference/services/llm/base_llm_service.py +134 -0
  36. isa_model/inference/services/llm/ollama_llm_service.py +99 -0
  37. isa_model/inference/services/llm/openai_llm_service.py +138 -0
  38. isa_model/inference/services/others/table_transformer_service.py +61 -0
  39. isa_model/inference/services/vision/__init__.py +12 -0
  40. isa_model/inference/services/vision/helpers/image_utils.py +58 -0
  41. isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
  42. isa_model/inference/services/vision/ollama_vision_service.py +60 -0
  43. isa_model/inference/services/vision/openai_vision_service.py +80 -0
  44. isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
  45. isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
  46. isa_model/inference/utils/conversion/onnx_converter.py +0 -0
  47. isa_model/inference/utils/conversion/torch_converter.py +0 -0
  48. isa_model/scripts/inference_tracker.py +283 -0
  49. isa_model/scripts/mlflow_manager.py +379 -0
  50. isa_model/scripts/model_registry.py +465 -0
  51. isa_model/scripts/start_mlflow.py +95 -0
  52. isa_model/scripts/training_tracker.py +257 -0
  53. isa_model/training/engine/llama_factory/__init__.py +39 -0
  54. isa_model/training/engine/llama_factory/config.py +115 -0
  55. isa_model/training/engine/llama_factory/data_adapter.py +284 -0
  56. isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
  57. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
  58. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
  59. isa_model/training/engine/llama_factory/factory.py +331 -0
  60. isa_model/training/engine/llama_factory/rl.py +254 -0
  61. isa_model/training/engine/llama_factory/trainer.py +171 -0
  62. isa_model/training/image_model/configs/create_config.py +37 -0
  63. isa_model/training/image_model/configs/create_flux_config.py +26 -0
  64. isa_model/training/image_model/configs/create_lora_config.py +21 -0
  65. isa_model/training/image_model/prepare_massed_compute.py +97 -0
  66. isa_model/training/image_model/prepare_upload.py +17 -0
  67. isa_model/training/image_model/raw_data/create_captions.py +16 -0
  68. isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
  69. isa_model/training/image_model/raw_data/pre_processing.py +200 -0
  70. isa_model/training/image_model/train/train.py +42 -0
  71. isa_model/training/image_model/train/train_flux.py +41 -0
  72. isa_model/training/image_model/train/train_lora.py +57 -0
  73. isa_model/training/image_model/train_main.py +25 -0
  74. isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
  75. isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
  76. isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
  77. isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
  78. isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
  79. isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
  80. isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
  81. isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
  82. isa_model-0.0.1.dist-info/METADATA +327 -0
  83. isa_model-0.0.1.dist-info/RECORD +86 -0
  84. isa_model-0.0.1.dist-info/WHEEL +5 -0
  85. isa_model-0.0.1.dist-info/licenses/LICENSE +21 -0
  86. isa_model-0.0.1.dist-info/top_level.txt +1 -0
isa_model/inference/services/audio/openai_tts_service.py
@@ -0,0 +1,73 @@
+ from typing import Dict, Any
+ import tempfile
+ import os
+ from openai import AsyncOpenAI
+ from tenacity import retry, stop_after_attempt, wait_exponential
+ from isa_model.inference.services.base_service import BaseService
+ from isa_model.inference.providers.base_provider import BaseProvider
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ class YYDSAudioService(BaseService):
+     """Audio model service wrapper for YYDS"""
+
+     def __init__(self, provider: 'BaseProvider', model_name: str):
+         super().__init__(provider, model_name)
+         # Initialize the AsyncOpenAI client
+         self._client = AsyncOpenAI(
+             api_key=self.config.get('api_key'),
+             base_url=self.config.get('base_url')
+         )
+         self.language = self.config.get('language', None)
+
+     @property
+     def client(self) -> AsyncOpenAI:
+         """Return the underlying OpenAI client"""
+         return self._client
+
+     @retry(
+         stop=stop_after_attempt(3),
+         wait=wait_exponential(multiplier=1, min=4, max=10),
+         reraise=True
+     )
+     async def transcribe(self, audio_data: bytes) -> Dict[str, Any]:
+         """Transcribe audio data
+
+         Args:
+             audio_data: Raw audio bytes
+
+         Returns:
+             Dict[str, Any]: Dictionary containing the transcribed text
+         """
+         temp_path = None
+         try:
+             # Store the audio data in a temporary file
+             with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
+                 temp_file.write(audio_data)
+                 temp_path = temp_file.name
+
+             # Reopen the file in binary mode for the API request
+             with open(temp_path, 'rb') as audio_file:
+                 # Only include the language parameter when a valid ISO-639-1 code is set
+                 params = {
+                     'model': self.model_name,
+                     'file': audio_file,
+                 }
+                 if self.language and isinstance(self.language, str):
+                     params['language'] = self.language
+
+                 response = await self._client.audio.transcriptions.create(**params)
+
+             # Return a dictionary containing the transcribed text
+             return {
+                 "text": response.text
+             }
+
+         except Exception as e:
+             logger.error(f"Error in audio transcription: {e}")
+             raise
+         finally:
+             # Clean up the temporary file even if transcription fails
+             if temp_path and os.path.exists(temp_path):
+                 os.unlink(temp_path)
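A minimal usage sketch for the transcription service above (not part of the package): BaseService only calls provider.get_config(), so a duck-typed stub provider suffices here, and the config keys and "whisper-1" model name are illustrative assumptions.

import asyncio

class StubProvider:
    """Duck-typed stand-in: BaseService only calls provider.get_config()."""
    def get_config(self):
        return {"api_key": "sk-...", "base_url": None, "language": "en"}

async def main():
    service = YYDSAudioService(StubProvider(), model_name="whisper-1")
    with open("sample.wav", "rb") as f:
        result = await service.transcribe(f.read())
    print(result["text"])

asyncio.run(main())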
isa_model/inference/services/base_service.py
@@ -0,0 +1,106 @@
+ from abc import ABC, abstractmethod
+ from typing import Dict, Any, List, Union, AsyncGenerator, TypeVar, Optional
+ from isa_model.inference.providers.base_provider import BaseProvider
+
+ T = TypeVar('T')  # Generic type for responses
+
+ class BaseService(ABC):
+     """Base class for all AI services"""
+
+     def __init__(self, provider: 'BaseProvider', model_name: str):
+         self.provider = provider
+         self.model_name = model_name
+         self.config = provider.get_config()
+
+     def __await__(self):
+         """Make the service awaitable"""
+         yield
+         return self
+
+ class BaseLLMService(BaseService):
+     """Base class for LLM services"""
+
+     @abstractmethod
+     async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> T:
+         """Universal invocation method"""
+         pass
+
+     @abstractmethod
+     async def achat(self, messages: List[Dict[str, str]]) -> T:
+         """Chat completion method"""
+         pass
+
+     @abstractmethod
+     async def acompletion(self, prompt: str) -> T:
+         """Text completion method"""
+         pass
+
+     @abstractmethod
+     async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[T]:
+         """Generate multiple completions"""
+         pass
+
+     @abstractmethod
+     async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
+         """Stream chat responses"""
+         pass
+
+     @abstractmethod
+     def get_token_usage(self) -> Any:
+         """Get total token usage statistics"""
+         pass
+
+     @abstractmethod
+     def get_last_token_usage(self) -> Dict[str, int]:
+         """Get token usage from last request"""
+         pass
+
+ class BaseEmbeddingService(BaseService):
+     """Base class for embedding services"""
+
+     @abstractmethod
+     async def create_text_embedding(self, text: str) -> List[float]:
+         """Create embedding for single text"""
+         pass
+
+     @abstractmethod
+     async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+         """Create embeddings for multiple texts"""
+         pass
+
+     @abstractmethod
+     async def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
+         """Create text chunks with embeddings"""
+         pass
+
+     @abstractmethod
+     async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
+         """Compute similarity between two embeddings"""
+         pass
+
+     @abstractmethod
+     async def close(self):
+         """Cleanup resources"""
+         pass
+
+ class BaseRerankService(BaseService):
+     """Base class for reranking services"""
+
+     @abstractmethod
+     async def rerank(
+         self,
+         query: str,
+         documents: List[Dict],
+         top_k: int = 5
+     ) -> List[Dict]:
+         """Rerank documents based on query relevance"""
+         pass
+
+     @abstractmethod
+     async def rerank_texts(
+         self,
+         query: str,
+         texts: List[str]
+     ) -> List[Dict]:
+         """Rerank raw texts based on query relevance"""
+         pass
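As a sanity check of the rerank contract above, here is a toy BaseRerankService subclass (an illustration, not part of the package) that scores documents by keyword overlap with the query; a real implementation would call a reranking model instead.

from typing import Dict, List

class KeywordRerankService(BaseRerankService):
    """Toy reranker: scores documents by query-keyword overlap (illustration only)."""

    async def rerank(self, query: str, documents: List[Dict], top_k: int = 5) -> List[Dict]:
        terms = set(query.lower().split())
        # Attach an overlap score to each document, then sort descending
        scored = [
            {**doc, "score": len(terms & set(str(doc.get("text", "")).lower().split()))}
            for doc in documents
        ]
        return sorted(scored, key=lambda d: d["score"], reverse=True)[:top_k]

    async def rerank_texts(self, query: str, texts: List[str]) -> List[Dict]:
        # Wrap raw texts in the document shape expected by rerank()
        return await self.rerank(query, [{"text": t} for t in texts], top_k=len(texts))

Like any BaseService subclass, it would be instantiated with a provider and model name, e.g. KeywordRerankService(provider, "keyword-overlap").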
isa_model/inference/services/embedding/ollama_embed_service.py
@@ -0,0 +1,97 @@
+ import logging
+ import httpx
+ import asyncio
+ from typing import List, Dict, Any, Optional
+
+ # Keep the specified imports and framework structure
+ from isa_model.inference.services.base_service import BaseEmbeddingService
+ from isa_model.inference.providers.base_provider import BaseProvider
+
+ logger = logging.getLogger(__name__)
+
+ class OllamaEmbedService(BaseEmbeddingService):
+     """
+     Ollama embedding service.
+     This class follows the base service architecture but talks to the Ollama API
+     with its own HTTP client instead of relying on an injected backend object.
+     """
+
+     def __init__(self, provider: 'BaseProvider', model_name: str = "bge-m3"):
+         # Stay compatible with the base class and the provider
+         super().__init__(provider, model_name)
+
+         # Read configuration from self.config, inherited from the base class
+         host = self.config.get("host", "localhost")
+         port = self.config.get("port", 11434)
+
+         # Create and own a dedicated httpx client instance
+         base_url = f"http://{host}:{port}"
+         self.client = httpx.AsyncClient(base_url=base_url, timeout=30.0)
+
+         logger.info(f"Initialized OllamaEmbedService with model '{self.model_name}' at {base_url}")
+
+     async def create_text_embedding(self, text: str) -> List[float]:
+         """Create an embedding for a single text"""
+         try:
+             payload = {
+                 "model": self.model_name,
+                 "prompt": text
+             }
+             # Use the service's own client instance instead of self.backend
+             response = await self.client.post("/api/embeddings", json=payload)
+             response.raise_for_status()  # Raise if the request failed
+             return response.json()["embedding"]
+
+         except httpx.RequestError as e:
+             logger.error(f"An error occurred while requesting {e.request.url!r}: {e}")
+             raise
+         except Exception as e:
+             logger.error(f"Error creating text embedding: {e}")
+             raise
+
+     async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+         """Create embeddings for multiple texts concurrently"""
+         if not texts:
+             return []
+
+         tasks = [self.create_text_embedding(text) for text in texts]
+         embeddings = await asyncio.gather(*tasks)
+         return embeddings
+
+     async def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
+         """Split the text into chunks and create an embedding for each chunk"""
+         chunk_size = 200  # number of words per chunk
+         words = text.split()
+         chunk_texts = [" ".join(words[i:i+chunk_size]) for i in range(0, len(words), chunk_size)]
+
+         if not chunk_texts:
+             return []
+
+         embeddings = await self.create_text_embeddings(chunk_texts)
+
+         chunks = [
+             {
+                 "text": chunk_text,
+                 "embedding": emb,
+                 "metadata": metadata or {}
+             }
+             for chunk_text, emb in zip(chunk_texts, embeddings)
+         ]
+
+         return chunks
+
+     async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
+         """Compute the cosine similarity between two embedding vectors"""
+         dot_product = sum(a * b for a, b in zip(embedding1, embedding2))
+         norm1 = sum(a * a for a in embedding1) ** 0.5
+         norm2 = sum(b * b for b in embedding2) ** 0.5
+
+         if norm1 * norm2 == 0:
+             return 0.0
+
+         return dot_product / (norm1 * norm2)
+
+     async def close(self):
+         """Close the internal HTTP client"""
+         await self.client.aclose()
+         logger.info("OllamaEmbedService's internal client has been closed.")
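A short usage sketch (illustrative, not part of the package; assumes a local Ollama instance with the bge-m3 model pulled):

import asyncio

class StubProvider:
    def get_config(self):
        return {"host": "localhost", "port": 11434}

async def main():
    svc = OllamaEmbedService(StubProvider(), model_name="bge-m3")
    try:
        # Embed two texts concurrently, then compare them
        vec_a, vec_b = await svc.create_text_embeddings(["hello world", "hi there"])
        print(await svc.compute_similarity(vec_a, vec_b))
    finally:
        await svc.close()

asyncio.run(main())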
isa_model/inference/services/llm/__init__.py
@@ -0,0 +1,12 @@
+ """
+ LLM Services - Business logic services for Language Models
+ """
+
+ # Import LLM services here when created
+ from .ollama_llm_service import OllamaLLMService
+ from .openai_llm_service import OpenAILLMService
+
+ __all__ = [
+     "OllamaLLMService",
+     "OpenAILLMService",
+ ]
isa_model/inference/services/llm/base_llm_service.py
@@ -0,0 +1,134 @@
+ from abc import ABC, abstractmethod
+ from typing import Dict, Any, List, Union, Optional, AsyncGenerator, TypeVar
+ from isa_model.inference.services.base_service import BaseService
+
+ T = TypeVar('T')  # Generic type for responses
+
+ class BaseLLMService(BaseService):
+     """Base class for Large Language Model services"""
+
+     @abstractmethod
+     async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> T:
+         """
+         Universal invocation method that handles different input types
+
+         Args:
+             prompt: Can be a string, list of messages, or other format
+
+         Returns:
+             Model response in the appropriate format
+         """
+         pass
+
+     @abstractmethod
+     async def achat(self, messages: List[Dict[str, str]]) -> T:
+         """
+         Chat completion method using message format
+
+         Args:
+             messages: List of message dictionaries with 'role' and 'content' keys
+                       Example: [{"role": "user", "content": "Hello"}]
+
+         Returns:
+             Chat completion response
+         """
+         pass
+
+     @abstractmethod
+     async def acompletion(self, prompt: str) -> T:
+         """
+         Text completion method for simple prompt completion
+
+         Args:
+             prompt: Input text prompt
+
+         Returns:
+             Text completion response
+         """
+         pass
+
+     @abstractmethod
+     async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[T]:
+         """
+         Generate multiple completions for the same input
+
+         Args:
+             messages: List of message dictionaries
+             n: Number of completions to generate
+
+         Returns:
+             List of completion responses
+         """
+         pass
+
+     @abstractmethod
+     async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
+         """
+         Stream chat responses token by token
+
+         Args:
+             messages: List of message dictionaries
+
+         Yields:
+             Individual tokens or chunks of the response
+         """
+         pass
+
+     @abstractmethod
+     async def astream_completion(self, prompt: str) -> AsyncGenerator[str, None]:
+         """
+         Stream completion responses token by token
+
+         Args:
+             prompt: Input text prompt
+
+         Yields:
+             Individual tokens or chunks of the response
+         """
+         pass
+
+     @abstractmethod
+     def get_token_usage(self) -> Dict[str, Any]:
+         """
+         Get cumulative token usage statistics for this service instance
+
+         Returns:
+             Dict containing token usage information:
+             - total_tokens: Total tokens used
+             - prompt_tokens: Tokens used for prompts
+             - completion_tokens: Tokens used for completions
+             - requests_count: Number of requests made
+         """
+         pass
+
+     @abstractmethod
+     def get_last_token_usage(self) -> Dict[str, int]:
+         """
+         Get token usage from the last request
+
+         Returns:
+             Dict containing last request token usage:
+             - prompt_tokens: Tokens in last prompt
+             - completion_tokens: Tokens in last completion
+             - total_tokens: Total tokens in last request
+         """
+         pass
+
+     @abstractmethod
+     def get_model_info(self) -> Dict[str, Any]:
+         """
+         Get information about the current model
+
+         Returns:
+             Dict containing model information:
+             - name: Model name
+             - max_tokens: Maximum context length
+             - supports_streaming: Whether streaming is supported
+             - supports_functions: Whether function calling is supported
+         """
+         pass
+
+     @abstractmethod
+     async def close(self):
+         """Cleanup resources and close connections"""
+         pass
isa_model/inference/services/llm/ollama_llm_service.py
@@ -0,0 +1,108 @@
+ import logging
+ import httpx
+ from typing import Dict, Any, List, Union, AsyncGenerator, Optional
+ from isa_model.inference.services.base_service import BaseLLMService
+ from isa_model.inference.providers.base_provider import BaseProvider
+
+ logger = logging.getLogger(__name__)
+
+ class OllamaLLMService(BaseLLMService):
+     """Ollama LLM service using its own HTTP backend client"""
+
+     def __init__(self, provider: 'BaseProvider', model_name: str = "llama3.1"):
+         super().__init__(provider, model_name)
+
+         # Own httpx backend client for the Ollama API (same pattern as OllamaEmbedService)
+         host = self.config.get("host", "localhost")
+         port = self.config.get("port", 11434)
+         self.backend = httpx.AsyncClient(base_url=f"http://{host}:{port}", timeout=60.0)
+
+         self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+         logger.info(f"Initialized OllamaLLMService with model {model_name}")
+
+     async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]):
+         """Universal invocation method"""
+         if isinstance(prompt, str):
+             return await self.acompletion(prompt)
+         elif isinstance(prompt, list):
+             return await self.achat(prompt)
+         else:
+             raise ValueError("Prompt must be string or list of messages")
+
+     async def achat(self, messages: List[Dict[str, str]]):
+         """Chat completion method"""
+         try:
+             payload = {
+                 "model": self.model_name,
+                 "messages": messages,
+                 "stream": False
+             }
+             resp = await self.backend.post("/api/chat", json=payload)
+             resp.raise_for_status()
+             response = resp.json()
+
+             # Update token usage if available
+             if "eval_count" in response:
+                 self.last_token_usage = {
+                     "prompt_tokens": response.get("prompt_eval_count", 0),
+                     "completion_tokens": response.get("eval_count", 0),
+                     "total_tokens": response.get("prompt_eval_count", 0) + response.get("eval_count", 0)
+                 }
+
+             return response["message"]["content"]
+
+         except Exception as e:
+             logger.error(f"Error in chat completion: {e}")
+             raise
+
+     async def acompletion(self, prompt: str):
+         """Text completion method"""
+         try:
+             payload = {
+                 "model": self.model_name,
+                 "prompt": prompt,
+                 "stream": False
+             }
+             resp = await self.backend.post("/api/generate", json=payload)
+             resp.raise_for_status()
+             response = resp.json()
+
+             # Update token usage if available
+             if "eval_count" in response:
+                 self.last_token_usage = {
+                     "prompt_tokens": response.get("prompt_eval_count", 0),
+                     "completion_tokens": response.get("eval_count", 0),
+                     "total_tokens": response.get("prompt_eval_count", 0) + response.get("eval_count", 0)
+                 }
+
+             return response["response"]
+
+         except Exception as e:
+             logger.error(f"Error in text completion: {e}")
+             raise
+
+     async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
+         """Generate multiple completions"""
+         results = []
+         for _ in range(n):
+             result = await self.achat(messages)
+             results.append(result)
+         return results
+
+     async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
+         """Stream chat responses"""
+         # Streaming is not wired up yet; yield the full response as a single chunk
+         response = await self.achat(messages)
+         yield response
+
+     def get_token_usage(self):
+         """Get total token usage statistics"""
+         return self.last_token_usage
+
+     def get_last_token_usage(self) -> Dict[str, int]:
+         """Get token usage from last request"""
+         return self.last_token_usage
+
+     async def close(self):
+         """Close the backend client"""
+         await self.backend.aclose()
1
+ import logging
2
+ import os
3
+ from typing import Dict, Any, List, Union, AsyncGenerator, Optional
4
+
5
+ # 使用官方 OpenAI 库和 dotenv
6
+ from openai import AsyncOpenAI
7
+ from dotenv import load_dotenv
8
+
9
+ from isa_model.inference.services.base_service import BaseLLMService
10
+ from isa_model.inference.providers.base_provider import BaseProvider
11
+
12
+ # 加载 .env.local 文件中的环境变量
13
+ load_dotenv(dotenv_path='.env.local')
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ class OpenAILLMService(BaseLLMService):
18
+ """OpenAI LLM service implementation"""
19
+
20
+ def __init__(self, provider: 'BaseProvider', model_name: str = "gpt-3.5-turbo"):
21
+ super().__init__(provider, model_name)
22
+
23
+ # 从provider配置初始化 AsyncOpenAI 客户端
24
+ try:
25
+ api_key = provider.config.get("api_key") or os.getenv("OPENAI_API_KEY")
26
+ base_url = provider.config.get("api_base") or os.getenv("OPENAI_API_BASE")
27
+
28
+ self.client = AsyncOpenAI(
29
+ api_key=api_key,
30
+ base_url=base_url
31
+ )
32
+ except TypeError as e:
33
+ logger.error("初始化 OpenAI 客户端失败。请检查您的 .env.local 文件中是否正确设置了 OPENAI_API_KEY。")
34
+ raise ValueError("OPENAI_API_KEY 未设置。") from e
35
+
36
+ self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
37
+ logger.info(f"Initialized OpenAILLMService with model {self.model_name} and endpoint {self.client.base_url}")
38
+
39
+ async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> str:
40
+ """Universal invocation method"""
41
+ if isinstance(prompt, str):
42
+ return await self.acompletion(prompt)
43
+ elif isinstance(prompt, list):
44
+ return await self.achat(prompt)
45
+ else:
46
+ raise ValueError("Prompt must be a string or a list of messages")
47
+
48
+ async def achat(self, messages: List[Dict[str, str]]) -> str:
49
+ """Chat completion method"""
50
+ try:
51
+ temperature = self.config.get("temperature", 0.7)
52
+ max_tokens = self.config.get("max_tokens", 1024)
53
+
54
+ response = await self.client.chat.completions.create(
55
+ model=self.model_name,
56
+ messages=messages,
57
+ temperature=temperature,
58
+ max_tokens=max_tokens
59
+ )
60
+
61
+ if response.usage:
62
+ self.last_token_usage = {
63
+ "prompt_tokens": response.usage.prompt_tokens,
64
+ "completion_tokens": response.usage.completion_tokens,
65
+ "total_tokens": response.usage.total_tokens
66
+ }
67
+
68
+ return response.choices[0].message.content or ""
69
+
70
+ except Exception as e:
71
+ logger.error(f"Error in chat completion: {e}")
72
+ raise
73
+
74
+ async def acompletion(self, prompt: str) -> str:
75
+ """Text completion method (using chat API)"""
76
+ messages = [{"role": "user", "content": prompt}]
77
+ return await self.achat(messages)
78
+
79
+ async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
80
+ """Generate multiple completions"""
81
+ try:
82
+ temperature = self.config.get("temperature", 0.7)
83
+ max_tokens = self.config.get("max_tokens", 1024)
84
+
85
+ response = await self.client.chat.completions.create(
86
+ model=self.model_name,
87
+ messages=messages,
88
+ temperature=temperature,
89
+ max_tokens=max_tokens,
90
+ n=n
91
+ )
92
+
93
+ if response.usage:
94
+ self.last_token_usage = {
95
+ "prompt_tokens": response.usage.prompt_tokens,
96
+ "completion_tokens": response.usage.completion_tokens,
97
+ "total_tokens": response.usage.total_tokens
98
+ }
99
+
100
+ return [choice.message.content or "" for choice in response.choices]
101
+ except Exception as e:
102
+ logger.error(f"Error in generate: {e}")
103
+ raise
104
+
105
+ async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
106
+ """Stream chat responses"""
107
+ try:
108
+ temperature = self.config.get("temperature", 0.7)
109
+ max_tokens = self.config.get("max_tokens", 1024)
110
+
111
+ stream = await self.client.chat.completions.create(
112
+ model=self.model_name,
113
+ messages=messages,
114
+ temperature=temperature,
115
+ max_tokens=max_tokens,
116
+ stream=True
117
+ )
118
+
119
+ async for chunk in stream:
120
+ content = chunk.choices[0].delta.content
121
+ if content:
122
+ yield content
123
+
124
+ except Exception as e:
125
+ logger.error(f"Error in stream chat: {e}")
126
+ raise
127
+
128
+ def get_token_usage(self) -> Dict[str, int]:
129
+ """Get total token usage statistics"""
130
+ return self.last_token_usage
131
+
132
+ def get_last_token_usage(self) -> Dict[str, int]:
133
+ """Get token usage from last request"""
134
+ return self.last_token_usage
135
+
136
+ async def close(self):
137
+ """Close the backend client"""
138
+ await self.client.aclose()
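A streaming usage sketch (illustrative, not part of the package): the stub provider satisfies both the provider.config attribute and get_config() as read above, and the API key falls back to the OPENAI_API_KEY environment variable.

import asyncio

class StubProvider:
    config = {}  # api_key/api_base fall back to environment variables
    def get_config(self):
        return {"temperature": 0.2, "max_tokens": 256}

async def main():
    llm = OpenAILLMService(StubProvider(), model_name="gpt-3.5-turbo")
    try:
        # Print tokens as they arrive from the stream
        async for token in llm.astream_chat([{"role": "user", "content": "Count to five."}]):
            print(token, end="", flush=True)
    finally:
        await llm.close()

asyncio.run(main())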