isa_model-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. isa_model/__init__.py +5 -0
  2. isa_model/core/model_manager.py +143 -0
  3. isa_model/core/model_registry.py +115 -0
  4. isa_model/core/model_router.py +226 -0
  5. isa_model/core/model_storage.py +133 -0
  6. isa_model/core/model_version.py +0 -0
  7. isa_model/core/resource_manager.py +202 -0
  8. isa_model/core/storage/hf_storage.py +0 -0
  9. isa_model/core/storage/local_storage.py +0 -0
  10. isa_model/core/storage/minio_storage.py +0 -0
  11. isa_model/deployment/mlflow_gateway/__init__.py +8 -0
  12. isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
  13. isa_model/deployment/unified_multimodal_client.py +341 -0
  14. isa_model/inference/__init__.py +11 -0
  15. isa_model/inference/adapter/triton_adapter.py +453 -0
  16. isa_model/inference/adapter/unified_api.py +248 -0
  17. isa_model/inference/ai_factory.py +354 -0
  18. isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
  19. isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
  20. isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
  21. isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
  22. isa_model/inference/backends/__init__.py +53 -0
  23. isa_model/inference/backends/base_backend_client.py +26 -0
  24. isa_model/inference/backends/container_services.py +104 -0
  25. isa_model/inference/backends/local_services.py +72 -0
  26. isa_model/inference/backends/openai_client.py +130 -0
  27. isa_model/inference/backends/replicate_client.py +197 -0
  28. isa_model/inference/backends/third_party_services.py +239 -0
  29. isa_model/inference/backends/triton_client.py +97 -0
  30. isa_model/inference/base.py +46 -0
  31. isa_model/inference/client_sdk/__init__.py +0 -0
  32. isa_model/inference/client_sdk/client.py +134 -0
  33. isa_model/inference/client_sdk/client_data_std.py +34 -0
  34. isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
  35. isa_model/inference/client_sdk/exceptions.py +0 -0
  36. isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
  37. isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
  38. isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
  39. isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
  40. isa_model/inference/providers/__init__.py +19 -0
  41. isa_model/inference/providers/base_provider.py +30 -0
  42. isa_model/inference/providers/model_cache_manager.py +341 -0
  43. isa_model/inference/providers/ollama_provider.py +73 -0
  44. isa_model/inference/providers/openai_provider.py +87 -0
  45. isa_model/inference/providers/replicate_provider.py +94 -0
  46. isa_model/inference/providers/triton_provider.py +439 -0
  47. isa_model/inference/providers/vllm_provider.py +0 -0
  48. isa_model/inference/providers/yyds_provider.py +83 -0
  49. isa_model/inference/services/__init__.py +14 -0
  50. isa_model/inference/services/audio/fish_speech/handler.py +215 -0
  51. isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
  52. isa_model/inference/services/audio/triton_speech_service.py +138 -0
  53. isa_model/inference/services/audio/whisper_service.py +186 -0
  54. isa_model/inference/services/audio/yyds_audio_service.py +71 -0
  55. isa_model/inference/services/base_service.py +106 -0
  56. isa_model/inference/services/base_tts_service.py +66 -0
  57. isa_model/inference/services/embedding/bge_service.py +183 -0
  58. isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
  59. isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
  60. isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
  61. isa_model/inference/services/llm/__init__.py +16 -0
  62. isa_model/inference/services/llm/gemma_service.py +143 -0
  63. isa_model/inference/services/llm/llama_service.py +143 -0
  64. isa_model/inference/services/llm/ollama_llm_service.py +108 -0
  65. isa_model/inference/services/llm/openai_llm_service.py +129 -0
  66. isa_model/inference/services/llm/replicate_llm_service.py +179 -0
  67. isa_model/inference/services/llm/triton_llm_service.py +230 -0
  68. isa_model/inference/services/others/table_transformer_service.py +61 -0
  69. isa_model/inference/services/vision/__init__.py +12 -0
  70. isa_model/inference/services/vision/helpers/image_utils.py +58 -0
  71. isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
  72. isa_model/inference/services/vision/ollama_vision_service.py +60 -0
  73. isa_model/inference/services/vision/replicate_vision_service.py +241 -0
  74. isa_model/inference/services/vision/triton_vision_service.py +199 -0
  75. isa_model/inference/services/vision/yyds_vision_service.py +80 -0
  76. isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
  77. isa_model/inference/utils/conversion/onnx_converter.py +0 -0
  78. isa_model/inference/utils/conversion/torch_converter.py +0 -0
  79. isa_model/scripts/inference_tracker.py +283 -0
  80. isa_model/scripts/mlflow_manager.py +379 -0
  81. isa_model/scripts/model_registry.py +465 -0
  82. isa_model/scripts/start_mlflow.py +95 -0
  83. isa_model/scripts/training_tracker.py +257 -0
  84. isa_model/training/engine/llama_factory/__init__.py +39 -0
  85. isa_model/training/engine/llama_factory/config.py +115 -0
  86. isa_model/training/engine/llama_factory/data_adapter.py +284 -0
  87. isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
  88. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
  89. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
  90. isa_model/training/engine/llama_factory/factory.py +331 -0
  91. isa_model/training/engine/llama_factory/rl.py +254 -0
  92. isa_model/training/engine/llama_factory/trainer.py +171 -0
  93. isa_model/training/image_model/configs/create_config.py +37 -0
  94. isa_model/training/image_model/configs/create_flux_config.py +26 -0
  95. isa_model/training/image_model/configs/create_lora_config.py +21 -0
  96. isa_model/training/image_model/prepare_massed_compute.py +97 -0
  97. isa_model/training/image_model/prepare_upload.py +17 -0
  98. isa_model/training/image_model/raw_data/create_captions.py +16 -0
  99. isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
  100. isa_model/training/image_model/raw_data/pre_processing.py +200 -0
  101. isa_model/training/image_model/train/train.py +42 -0
  102. isa_model/training/image_model/train/train_flux.py +41 -0
  103. isa_model/training/image_model/train/train_lora.py +57 -0
  104. isa_model/training/image_model/train_main.py +25 -0
  105. isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
  106. isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
  107. isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
  108. isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
  109. isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
  110. isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
  111. isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
  112. isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
  113. isa_model-0.1.0.dist-info/METADATA +116 -0
  114. isa_model-0.1.0.dist-info/RECORD +117 -0
  115. isa_model-0.1.0.dist-info/WHEEL +5 -0
  116. isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
  117. isa_model-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,46 @@
+ from typing import Dict, List, Optional
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from uuid import uuid4
+
+ class TextChunkHelper:
+     """Text splitting and chunking helper"""
+
+     def __init__(self,
+                  chunk_size: int = 512,
+                  chunk_overlap: int = 50,
+                  min_chunk_size: int = 50):
+         self.text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=chunk_size,
+             chunk_overlap=chunk_overlap,
+             length_function=len,
+             separators=["\n\n", "\n", ". ", ", ", " "]
+         )
+         self.min_chunk_size = min_chunk_size
+
+     def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
+         """Create text chunks with metadata"""
+         if not text or not isinstance(text, str):
+             raise ValueError("Text must be a non-empty string")
+
+         chunks = self.text_splitter.split_text(text)
+         valid_chunks = [
+             chunk for chunk in chunks
+             if len(chunk) >= self.min_chunk_size
+         ]
+
+         results = []
+         for i, chunk in enumerate(valid_chunks):
+             chunk_data = {
+                 "chunk_id": f"chunk_{uuid4().hex[:8]}",
+                 "content": chunk,
+                 "token_count": len(chunk),
+                 "metadata": {
+                     **(metadata or {}),
+                     "position": i,
+                     "start_idx": text.find(chunk),
+                     "end_idx": text.find(chunk) + len(chunk)
+                 }
+             }
+             results.append(chunk_data)
+
+         return results
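For orientation, a minimal usage sketch of the helper above. It assumes langchain is installed; the sample text and metadata are invented for illustration, and note that "token_count" is simply the chunk's character length, not a real tokenizer count:

    from isa_model.inference.services.vision.helpers.text_splitter import TextChunkHelper

    helper = TextChunkHelper(chunk_size=128, chunk_overlap=16, min_chunk_size=20)
    chunks = helper.create_chunks(
        "Long-form document text goes here. " * 40,
        metadata={"source": "example.txt"},  # hypothetical metadata
    )
    # Each chunk carries an id, content, a length-based token estimate, and position metadata
    print(chunks[0]["chunk_id"], chunks[0]["metadata"]["position"], chunks[0]["token_count"])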
@@ -0,0 +1,60 @@
+ import os
+ import json
+ import base64
+ import ollama
+ from typing import Dict, Any, Union
+ from tenacity import retry, stop_after_attempt, wait_exponential
+ from ...base_service import BaseService
+ from ...base_provider import BaseProvider
+ from app.config.config_manager import config_manager
+
+ logger = config_manager.get_logger(__name__)
+
+ class OllamaVisionService(BaseService):
+     """Vision model service wrapper for Ollama using base64-encoded images"""
+
+     def __init__(self, provider: 'BaseProvider', model_name: str = 'gemma3:4b'):
+         super().__init__(provider, model_name)
+         self.max_tokens = self.config.get('max_tokens', 1000)
+         self.temperature = self.config.get('temperature', 0.7)
+
+     @retry(
+         stop=stop_after_attempt(3),
+         wait=wait_exponential(multiplier=1, min=4, max=10),
+         reraise=True
+     )
+     async def analyze_image(self, image_data: Union[bytes, str], query: str) -> str:
+         """Analyze an image and return the result.
+
+         Args:
+             image_data: Image data, either raw bytes or a file path string
+             query: Query text
+
+         Returns:
+             str: Analysis result
+         """
+         try:
+             # If a file path was given, read the file contents
+             if isinstance(image_data, str):
+                 with open(image_data, 'rb') as f:
+                     image_data = f.read()
+
+             # Encode as base64
+             image_base64 = base64.b64encode(image_data).decode('utf-8')
+
+             # Call the ollama library directly
+             response = ollama.chat(
+                 model=self.model_name,
+                 messages=[{
+                     'role': 'user',
+                     'content': query,
+                     'images': [image_base64]
+                 }]
+             )
+
+             return response['message']['content']
+
+         except Exception as e:
+             logger.error(f"Error in image analysis: {e}")
+             raise
+
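A hedged usage sketch of the service above, assuming a local Ollama daemon with gemma3:4b pulled. The OllamaProvider default constructor is an assumption, since that file's diff is not shown here:

    import asyncio
    from isa_model.inference.providers.ollama_provider import OllamaProvider
    from isa_model.inference.services.vision.ollama_vision_service import OllamaVisionService

    async def main():
        provider = OllamaProvider()  # assumed default constructor
        service = OllamaVisionService(provider, model_name="gemma3:4b")
        # image_data may be a file path or raw bytes; the query is free-form text
        print(await service.analyze_image("photo.jpg", "What is in this picture?"))

    asyncio.run(main())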
@@ -0,0 +1,241 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+
+ """
+ Replicate Vision service
+ Interacts with the Replicate API; supports image generation and image analysis
+ """
+
+ import os
+ import time
+ import uuid
+ import logging
+ from typing import Dict, Any, List, Optional, Union
+ import asyncio
+ import aiohttp
+ import replicate
+ import requests
+ from PIL import Image
+ from io import BytesIO
+
+ from isa_model.inference.services.base_service import BaseService
+ from isa_model.inference.providers.base_provider import BaseProvider
+ from isa_model.inference.base import ModelType
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class ReplicateVisionService(BaseService):
+     """
+     Replicate Vision service for image generation and analysis
+     """
+
+     def __init__(self, provider: BaseProvider, model_name: str):
+         """
+         Initialize the Replicate Vision service.
+
+         Args:
+             provider: Replicate provider instance
+             model_name: Replicate model ID (format: 'username/model_name:version')
+         """
+         super().__init__(provider, model_name)
+         self.api_token = provider.config.get("api_token", os.environ.get("REPLICATE_API_TOKEN"))
+         self.client = replicate.Client(api_token=self.api_token)
+         self.model_type = ModelType.VISION
+
+         # Optional default settings
+         self.guidance_scale = provider.config.get("guidance_scale", 7.5)
+         self.num_inference_steps = provider.config.get("num_inference_steps", 30)
+
+         # Directory for generated images
+         self.output_dir = "generated_images"
+         os.makedirs(self.output_dir, exist_ok=True)
+
+     async def load(self) -> None:
+         """
+         Load the model (for Replicate this just validates the API token).
+         """
+         if not self.api_token:
+             raise ValueError("Missing Replicate API token; set the REPLICATE_API_TOKEN environment variable")
+
+         # Validate the token
+         try:
+             self.client.api_token = self.api_token
+             logger.info(f"Replicate Vision service initialized with model: {self.model_name}")
+         except Exception as e:
+             logger.error(f"Replicate initialization failed: {e}")
+             raise
+
+     async def generate_image(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Generate an image with a Replicate model.
+
+         Args:
+             input_data: Dictionary of generation parameters
+
+         Returns:
+             Result dictionary containing the generated image URLs
+         """
+         try:
+             # Apply default parameters
+             if "guidance_scale" not in input_data and self.guidance_scale:
+                 input_data["guidance_scale"] = self.guidance_scale
+
+             if "num_inference_steps" not in input_data and self.num_inference_steps:
+                 input_data["num_inference_steps"] = self.num_inference_steps
+
+             # Run the model (synchronous API call)
+             logger.info(f"Generating image with model {self.model_name}")
+
+             # Run it asynchronously via an executor
+             loop = asyncio.get_event_loop()
+             output = await loop.run_in_executor(
+                 None,
+                 lambda: replicate.run(self.model_name, input=input_data)
+             )
+
+             # Normalize the result to a standard format,
+             # handling Replicate object outputs
+             if hasattr(output, 'url'):
+                 urls = [output.url]
+             elif isinstance(output, list) and all(hasattr(item, 'url') for item in output if item is not None):
+                 urls = [item.url for item in output if item is not None]
+             else:
+                 # Fall back to outputs that are already URL strings
+                 urls = output if isinstance(output, list) else [output]
+
+             result = {
+                 "urls": urls,
+                 "metadata": {
+                     "model": self.model_name,
+                     "input": input_data
+                 }
+             }
+
+             logger.info(f"Image generation finished: {result['urls']}")
+             return result
+
+         except Exception as e:
+             logger.error(f"Image generation failed: {e}")
+             raise
+
+     async def analyze_image(self, image_path: str, prompt: str) -> Dict[str, Any]:
+         """
+         Analyze an image (for vision-analysis models).
+
+         Args:
+             image_path: Image path or URL
+             prompt: Analysis prompt
+
+         Returns:
+             Analysis result dictionary
+         """
+         try:
+             # Build the input payload
+             input_data = {
+                 "image": self._get_image_url(image_path),
+                 "prompt": prompt
+             }
+
+             # Run the model
+             logger.info(f"Analyzing image with model {self.model_name}")
+
+             # Run it asynchronously via an executor
+             loop = asyncio.get_event_loop()
+             output = await loop.run_in_executor(
+                 None,
+                 lambda: replicate.run(self.model_name, input=input_data)
+             )
+
+             result = {
+                 "text": output,
+                 "metadata": {
+                     "model": self.model_name,
+                     "input": input_data
+                 }
+             }
+
+             logger.info("Image analysis finished")
+             return result
+
+         except Exception as e:
+             logger.error(f"Image analysis failed: {e}")
+             raise
+
+     async def generate_and_save(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Generate images and save them locally.
+
+         Args:
+             input_data: Dictionary of generation parameters
+
+         Returns:
+             Result dictionary containing the image URLs and saved paths
+         """
+         # Generate the images first
+         result = await self.generate_image(input_data)
+
+         # Then download and save them
+         saved_paths = []
+         for i, url in enumerate(result["urls"]):
+             # Build a unique file name
+             timestamp = int(time.time())
+             file_name = f"{self.output_dir}/{timestamp}_{uuid.uuid4().hex[:8]}_{i+1}.png"
+
+             # Download the image asynchronously
+             try:
+                 await self._download_image(url, file_name)
+                 saved_paths.append(file_name)
+                 logger.info(f"Image saved to: {file_name}")
+             except Exception as e:
+                 logger.error(f"Failed to save image: {e}")
+
+         # Attach the saved paths to the result
+         result["saved_paths"] = saved_paths
+         return result
+
+     async def _download_image(self, url: str, save_path: str) -> None:
+         """
+         Download an image asynchronously and save it.
+
+         Args:
+             url: Image URL
+             save_path: Destination path
+         """
+         try:
+             async with aiohttp.ClientSession() as session:
+                 async with session.get(url) as response:
+                     if response.status == 200:
+                         content = await response.read()
+                         img = Image.open(BytesIO(content))
+                         img.save(save_path)
+                     else:
+                         logger.error(f"Image download failed: HTTP {response.status}")
+                         raise Exception(f"Image download failed: HTTP {response.status}")
+         except Exception as e:
+             logger.error(f"Error while downloading image: {e}")
+             raise
+
+     def _get_image_url(self, image_path: str) -> str:
+         """
+         Resolve an image URL (a local path would need to be uploaded to temporary storage).
+
+         Args:
+             image_path: Image path or URL
+
+         Returns:
+             Image URL
+         """
+         # If it is already a URL, return it unchanged
+         if image_path.startswith(("http://", "https://")):
+             return image_path
+
+         # Otherwise this is a local file that would need uploading.
+         # NOTE: upload logic could be implemented here; for simplicity only URLs are supported.
+         raise NotImplementedError("Only image URLs are currently supported; local file upload is not implemented")
+
+     async def unload(self) -> None:
+         """Unload the model (a no-op for the Replicate API)."""
+         logger.info(f"Unloading Replicate Vision service: {self.model_name}")
+         # Nothing to clean up
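A sketch of the generate-and-save flow above, assuming REPLICATE_API_TOKEN is set in the environment. The ReplicateProvider default constructor is assumed, and "owner/model:version" is a placeholder model pin, not a real model ID:

    import asyncio
    from isa_model.inference.providers.replicate_provider import ReplicateProvider
    from isa_model.inference.services.vision.replicate_vision_service import ReplicateVisionService

    async def main():
        provider = ReplicateProvider()  # assumed default constructor; reads REPLICATE_API_TOKEN
        service = ReplicateVisionService(provider, "owner/model:version")  # placeholder model pin
        await service.load()
        result = await service.generate_and_save({"prompt": "a lighthouse at dawn"})
        print(result["urls"], result["saved_paths"])  # remote URLs plus local PNG paths

    asyncio.run(main())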
@@ -0,0 +1,199 @@
+ import json
+ import logging
+ import asyncio
+ import base64
+ import io
+ from PIL import Image
+ import numpy as np
+ from typing import Dict, List, Any, AsyncGenerator, Optional, Union
+
+ from isa_model.inference.services.base_service import BaseService
+ from isa_model.inference.providers.triton_provider import TritonProvider
+
+ logger = logging.getLogger(__name__)
+
+
+ class TritonVisionService(BaseService):
+     """
+     Vision service that uses Triton Inference Server to run inference.
+     """
+
+     def __init__(self, provider: TritonProvider, model_name: str):
+         """
+         Initialize the Triton Vision service.
+
+         Args:
+             provider: The Triton provider
+             model_name: Name of the model in Triton (e.g., "Gemma3-4B")
+         """
+         super().__init__(provider, model_name)
+         self.client = None
+         self.token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+         self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+
+     async def _initialize_client(self):
+         """Initialize the Triton client"""
+         if self.client is None:
+             self.client = self.provider.create_client()
+
+             # Check if model is ready
+             if not self.provider.is_model_ready(self.model_name):
+                 logger.error(f"Model {self.model_name} is not ready on Triton server")
+                 raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
+
+             logger.info(f"Initialized Triton client for vision model: {self.model_name}")
+
+     async def process_image(self,
+                             image: Union[str, Image.Image, bytes],
+                             prompt: Optional[str] = None,
+                             params: Optional[Dict[str, Any]] = None) -> str:
+         """
+         Process an image and generate a description.
+
+         Args:
+             image: Input image (PIL Image, base64 string, or bytes)
+             prompt: Optional text prompt to guide the model
+             params: Generation parameters
+
+         Returns:
+             Generated text description
+         """
+         await self._initialize_client()
+
+         try:
+             import tritonclient.http as httpclient
+
+             # Process the image to get numpy array
+             image_array = self._prepare_image_input(image)
+
+             # Create input tensors for the image
+             image_input = httpclient.InferInput("IMAGE", image_array.shape, "UINT8")
+             image_input.set_data_from_numpy(image_array)
+             inputs = [image_input]
+
+             # Add text prompt if provided
+             if prompt:
+                 text_data = np.array([prompt], dtype=np.object_)
+                 text_input = httpclient.InferInput("TEXT", text_data.shape, "BYTES")
+                 text_input.set_data_from_numpy(text_data)
+                 inputs.append(text_input)
+
+             # Add parameters if provided
+             if params:
+                 default_params = {
+                     "max_new_tokens": 512,
+                     "temperature": 0.7,
+                     "top_p": 0.9,
+                     "do_sample": True
+                 }
+                 generation_params = {**default_params, **params}
+
+                 param_json = json.dumps(generation_params)
+                 param_data = np.array([param_json], dtype=np.object_)
+                 param_input = httpclient.InferInput("PARAMETERS", param_data.shape, "BYTES")
+                 param_input.set_data_from_numpy(param_data)
+                 inputs.append(param_input)
+
+             # Create output tensor
+             outputs = [httpclient.InferRequestedOutput("TEXT")]
+
+             # Send the request
+             response = await asyncio.to_thread(
+                 self.client.infer,
+                 self.model_name,
+                 inputs,
+                 outputs=outputs
+             )
+
+             # Process the response
+             output = response.as_numpy("TEXT")
+             response_text = output[0].decode('utf-8')
+
+             # Update token usage (estimated since we don't have actual token counts)
+             prompt_tokens = len(prompt) // 4 if prompt else 100  # Rough estimate
+             completion_tokens = len(response_text) // 4  # Rough estimate
+             total_tokens = prompt_tokens + completion_tokens
+
+             self.last_token_usage = {
+                 "prompt_tokens": prompt_tokens,
+                 "completion_tokens": completion_tokens,
+                 "total_tokens": total_tokens
+             }
+
+             # Update total token usage
+             self.token_usage["prompt_tokens"] += prompt_tokens
+             self.token_usage["completion_tokens"] += completion_tokens
+             self.token_usage["total_tokens"] += total_tokens
+
+             return response_text
+
+         except Exception as e:
+             logger.error(f"Error during Triton vision inference: {str(e)}")
+             raise
+
+     def get_token_usage(self) -> Dict[str, int]:
+         """
+         Get total token usage statistics.
+
+         Returns:
+             Dictionary with token usage statistics
+         """
+         return self.token_usage
+
+     def get_last_token_usage(self) -> Dict[str, int]:
+         """
+         Get token usage from last request.
+
+         Returns:
+             Dictionary with token usage statistics from last request
+         """
+         return self.last_token_usage
+
+     def _prepare_image_input(self, image: Union[str, Image.Image, bytes]) -> np.ndarray:
+         """
+         Process different types of image inputs into a numpy array.
+
+         Args:
+             image: Image input (PIL Image, base64 string, or bytes)
+
+         Returns:
+             Numpy array of the image
+         """
+         # Convert to PIL image first
+         pil_image = self._to_pil_image(image)
+
+         # Convert PIL image to numpy array
+         return np.array(pil_image)
+
+     def _to_pil_image(self, image: Union[str, Image.Image, bytes]) -> Image.Image:
+         """
+         Convert different image inputs to PIL Image.
+
+         Args:
+             image: Image input (PIL Image, base64 string, or bytes)
+
+         Returns:
+             PIL Image
+         """
+         if isinstance(image, Image.Image):
+             return image
+
+         elif isinstance(image, str):
+             # Check if it's a base64 string
+             if image.startswith("data:image"):
+                 # Extract the base64 part
+                 image = image.split(",")[1]
+
+             try:
+                 # Try to decode as base64
+                 image_bytes = base64.b64decode(image)
+                 return Image.open(io.BytesIO(image_bytes))
+             except Exception:
+                 # Try to open as a file path
+                 return Image.open(image)
+
+         elif isinstance(image, bytes):
+             return Image.open(io.BytesIO(image))
+
+         else:
+             raise ValueError(f"Unsupported image type: {type(image)}")
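A hedged sketch of driving the Triton service above. It assumes a reachable Triton server with the model loaded; the TritonProvider default constructor is a guess, since its diff is not shown in this section:

    import asyncio
    from PIL import Image
    from isa_model.inference.providers.triton_provider import TritonProvider
    from isa_model.inference.services.vision.triton_vision_service import TritonVisionService

    async def main():
        provider = TritonProvider()  # assumed default constructor pointing at a local server
        service = TritonVisionService(provider, "Gemma3-4B")
        text = await service.process_image(
            Image.open("photo.jpg"),         # also accepts bytes or a base64 string
            prompt="Describe this image.",
            params={"max_new_tokens": 256},  # merged over the service defaults
        )
        print(text, service.get_last_token_usage())  # usage numbers are length-based estimates

    asyncio.run(main())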
@@ -0,0 +1,80 @@
+ from typing import Dict, Any, Union
+ from openai import AsyncOpenAI
+ from tenacity import retry, stop_after_attempt, wait_exponential
+ from ...base_service import BaseService
+ from ...base_provider import BaseProvider
+ from .helpers.image_utils import compress_image, encode_image_to_base64
+ from app.config.config_manager import config_manager
+
+ logger = config_manager.get_logger(__name__)
+
+ class YYDSVisionService(BaseService):
+     """Vision model service wrapper for YYDS"""
+
+     def __init__(self, provider: 'BaseProvider', model_name: str):
+         super().__init__(provider, model_name)
+         # Initialize the AsyncOpenAI client
+         self._client = AsyncOpenAI(
+             api_key=self.config.get('api_key'),
+             base_url=self.config.get('base_url')
+         )
+         self.max_tokens = self.config.get('max_tokens', 1000)
+         self.temperature = self.config.get('temperature', 0.7)
+
+     @property
+     def client(self) -> AsyncOpenAI:
+         """Return the underlying OpenAI client"""
+         return self._client
+
+     @retry(
+         stop=stop_after_attempt(3),
+         wait=wait_exponential(multiplier=1, min=4, max=10),
+         reraise=True
+     )
+     async def analyze_image(self, image_data: Union[bytes, str], query: str) -> str:
+         """Analyze an image and return the result.
+
+         Args:
+             image_data: Image data, either raw bytes or an already-encoded base64 string
+             query: Query text
+
+         Returns:
+             str: Analysis result
+         """
+         try:
+             # Prepare the image data
+             if isinstance(image_data, bytes):
+                 # Compress and encode the image
+                 compressed_image = compress_image(image_data)
+                 image_b64 = encode_image_to_base64(compressed_image)
+             else:
+                 image_b64 = image_data
+
+             # Strip a data-URL base64 prefix if present
+             if 'base64,' in image_b64:
+                 image_b64 = image_b64.split('base64,')[1]
+
+             # Issue the request through the AsyncOpenAI client
+             response = await self._client.chat.completions.create(
+                 model=self.model_name,
+                 messages=[{
+                     "role": "user",
+                     "content": [
+                         {"type": "text", "text": query},
+                         {
+                             "type": "image_url",
+                             "image_url": {
+                                 "url": f"data:image/jpeg;base64,{image_b64}"
+                             }
+                         }
+                     ]
+                 }],
+                 max_tokens=self.max_tokens,
+                 temperature=self.temperature
+             )
+
+             return response.choices[0].message.content
+
+         except Exception as e:
+             logger.error(f"Error in image analysis: {e}")
+             raise
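Finally, a sketch of the YYDS service, assuming a provider whose config supplies api_key and base_url (the real wiring presumably goes through the package's ai_factory module, whose API is not shown in this diff). The stub provider and model name below are purely illustrative:

    import asyncio
    from isa_model.inference.services.vision.yyds_vision_service import YYDSVisionService

    class StubProvider:  # illustrative stand-in; not part of the package
        config = {"api_key": "sk-...", "base_url": "https://api.example.com/v1"}

    async def main():
        service = YYDSVisionService(StubProvider(), model_name="gpt-4o")  # model name assumed
        with open("photo.jpg", "rb") as f:
            print(await service.analyze_image(f.read(), "What is in this picture?"))

    asyncio.run(main())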